Skip to content

Commit

Permalink
read access token expiration date from config and update it
Browse files Browse the repository at this point in the history
  • Loading branch information
alafanechere committed Dec 29, 2022
1 parent 71ab084 commit 833d978
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ def __init__(
connector_config: Mapping[str, Any],
token_refresh_endpoint: str,
scopes: List[str] = None,
token_expiry_date: pendulum.DateTime = None,
token_expiry_date_format: str = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
Expand All @@ -115,14 +113,14 @@ def __init__(
client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
access_token_expiration_datetime_config_path: Sequence[str] = ("credentials", "access_token_expiration_datetime"),
):
"""
Args:
connector_config (Mapping[str, Any]): The full connector configuration
token_refresh_endpoint (str): Full URL to the token refresh endpoint
scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None.
token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None.
access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token".
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
Expand All @@ -132,11 +130,13 @@ def __init__(
client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token").
refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
access_token_expiration_datetime_config_path (Sequence[str]): Dpath to the access_token_expiration_datetime field in the connector configuration. Defaults to ("credentials", "access_token_expiration_datetime").
"""
self._client_id_config_path = client_id_config_path
self._client_secret_config_path = client_secret_config_path
self._access_token_config_path = access_token_config_path
self._refresh_token_config_path = refresh_token_config_path
self._access_token_expiration_datetime_config_path = access_token_expiration_datetime_config_path
self._refresh_token_name = refresh_token_name
self._connector_config = observe_connector_config(connector_config)
self._validate_connector_config()
Expand All @@ -145,13 +145,12 @@ def __init__(
self.get_client_id(),
self.get_client_secret(),
self.get_refresh_token(),
scopes,
token_expiry_date,
token_expiry_date_format,
access_token_name,
expires_in_name,
refresh_request_body,
grant_type,
scopes=scopes,
token_expiry_date=self.get_access_token_expiration_datetime(),
access_token_name=access_token_name,
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
grant_type=grant_type
)

def _validate_connector_config(self):
Expand All @@ -164,6 +163,7 @@ def _validate_connector_config(self):
(self._client_id_config_path, self.get_client_id, "client_id_config_path"),
(self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
(self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
(self._access_token_expiration_datetime_config_path, self.get_access_token_expiration_datetime, "access_token_expiration_datetime_config_path"),
]:
try:
assert getter()
Expand All @@ -184,19 +184,24 @@ def get_client_secret(self) -> str:
def get_refresh_token(self) -> str:
return dpath.util.get(self._connector_config, self._refresh_token_config_path)

def get_access_token_expiration_datetime(self) -> pendulum.DateTime:
return pendulum.parse(dpath.util.get(self._connector_config, self._access_token_expiration_datetime_config_path))


def _update_config_with_access_and_refresh_tokens(self, new_access_token: str, new_refresh_token: str):
def _update_config_with_access_and_refresh_tokens(self, new_access_token: str, new_refresh_token: str, new_access_token_expiration_datetime: pendulum.DateTime):
"""Update the connector configuration with new access and refresh token values.
The mutation of the connector_config object will emit Airbyte control messages.
Args:
new_access_token (str): The new access token value.
new_refresh_token (str): The new refresh token value.
new_access_token_expiration_datetime (pendulum.DateTime): The new access token expiration date.
"""
# TODO alafanechere this will sequentially emit two control messages.
# TODO alafanechere this will sequentially emit three control messages.
# We should rework the observer/config mutation logic if we want to have atomic config updates in a single control message.
dpath.util.set(self._connector_config, self._access_token_config_path, new_access_token)
dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)
dpath.util.set(self._connector_config, self._access_token_expiration_datetime_config_path, new_access_token_expiration_datetime)


def get_access_token(self) -> str:
Expand All @@ -206,11 +211,10 @@ def get_access_token(self) -> str:
str: The current access_token, updated if it was previously expired.
"""
if self.token_has_expired():
t0 = pendulum.now()
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
self.access_token = new_access_token
self.set_token_expiry_date(t0, access_token_expires_in)
self._update_config_with_access_and_refresh_tokens(new_access_token, new_refresh_token)
self.set_token_expiry_date(pendulum.now("UTC"), access_token_expires_in)
self._update_config_with_access_and_refresh_tokens(new_access_token, new_refresh_token, self.get_token_expiry_date())
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging

import freezegun
import pendulum
import pytest
import requests
Expand Down Expand Up @@ -188,6 +189,7 @@ def connector_config(self):
"refresh_token": "my_refresh_token",
"client_id": "my_client_id",
"client_secret": "my_client_secret",
"access_token_expiration_datetime": "2022-12-31T00:00:00+00:00"
}
}

Expand All @@ -209,6 +211,7 @@ def test_init_with_invalid_config(self, invalid_connector_config):
token_refresh_endpoint="foobar",
)

@freezegun.freeze_time("2022-12-31")
def test_get_access_token(self, capsys, mocker, connector_config):
authenticator = SingleUseRefreshTokenOauth2Authenticator(
connector_config,
Expand All @@ -222,7 +225,7 @@ def test_get_access_token(self, capsys, mocker, connector_config):
expected_new_config = connector_config.copy()
expected_new_config["credentials"]["access_token"] = "new_access_token"
expected_new_config["credentials"]["refresh_token"] = "new_refresh_token"

expected_new_config["credentials"]["access_token_expiration_datetime"] = "2022-12-31T00:00:42+00:00"
assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
assert authenticator.access_token == access_token == "new_access_token"
assert authenticator.get_refresh_token() == "new_refresh_token"
Expand Down

0 comments on commit 833d978

Please sign in to comment.