diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index f584d2888d4e..2cd987c166b4 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -104,8 +104,6 @@ def __init__( connector_config: Mapping[str, Any], token_refresh_endpoint: str, scopes: List[str] = None, - token_expiry_date: pendulum.DateTime = None, - token_expiry_date_format: str = None, access_token_name: str = "access_token", expires_in_name: str = "expires_in", refresh_token_name: str = "refresh_token", @@ -115,6 +113,7 @@ def __init__( client_secret_config_path: Sequence[str] = ("credentials", "client_secret"), access_token_config_path: Sequence[str] = ("credentials", "access_token"), refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"), + access_token_expiration_datetime_config_path: Sequence[str] = ("credentials", "access_token_expiration_datetime"), ): """ @@ -122,7 +121,6 @@ def __init__( connector_config (Mapping[str, Any]): The full connector configuration token_refresh_endpoint (str): Full URL to the token refresh endpoint scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None. - token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None. access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token". expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". @@ -132,11 +130,13 @@ def __init__( client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret"). access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token"). refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token"). + access_token_expiration_datetime_config_path (Sequence[str]): Dpath to the access_token_expiration_datetime field in the connector configuration. Defaults to ("credentials", "access_token_expiration_datetime"). """ self._client_id_config_path = client_id_config_path self._client_secret_config_path = client_secret_config_path self._access_token_config_path = access_token_config_path self._refresh_token_config_path = refresh_token_config_path + self._access_token_expiration_datetime_config_path = access_token_expiration_datetime_config_path self._refresh_token_name = refresh_token_name self._connector_config = observe_connector_config(connector_config) self._validate_connector_config() @@ -145,13 +145,12 @@ def __init__( self.get_client_id(), self.get_client_secret(), self.get_refresh_token(), - scopes, - token_expiry_date, - token_expiry_date_format, - access_token_name, - expires_in_name, - refresh_request_body, - grant_type, + scopes=scopes, + token_expiry_date=self.get_access_token_expiration_datetime(), + access_token_name=access_token_name, + expires_in_name=expires_in_name, + refresh_request_body=refresh_request_body, + grant_type=grant_type ) def _validate_connector_config(self): @@ -164,6 +163,7 @@ def _validate_connector_config(self): (self._client_id_config_path, self.get_client_id, "client_id_config_path"), (self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"), (self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"), + (self._access_token_expiration_datetime_config_path, self.get_access_token_expiration_datetime, "access_token_expiration_datetime_config_path"), ]: try: assert getter() @@ -184,19 +184,24 @@ def get_client_secret(self) -> str: def get_refresh_token(self) -> str: return dpath.util.get(self._connector_config, self._refresh_token_config_path) + def get_access_token_expiration_datetime(self) -> pendulum.DateTime: + return pendulum.parse(dpath.util.get(self._connector_config, self._access_token_expiration_datetime_config_path)) + - def _update_config_with_access_and_refresh_tokens(self, new_access_token: str, new_refresh_token: str): + def _update_config_with_access_and_refresh_tokens(self, new_access_token: str, new_refresh_token: str, new_access_token_expiration_datetime: pendulum.DateTime): """Update the connector configuration with new access and refresh token values. The mutation of the connector_config object will emit Airbyte control messages. Args: new_access_token (str): The new access token value. new_refresh_token (str): The new refresh token value. + new_access_token_expiration_datetime (pendulum.DateTime): The new access token expiration date. """ - # TODO alafanechere this will sequentially emit two control messages. + # TODO alafanechere this will sequentially emit three control messages. # We should rework the observer/config mutation logic if we want to have atomic config updates in a single control message. dpath.util.set(self._connector_config, self._access_token_config_path, new_access_token) dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token) + dpath.util.set(self._connector_config, self._access_token_expiration_datetime_config_path, new_access_token_expiration_datetime) def get_access_token(self) -> str: @@ -206,11 +211,10 @@ def get_access_token(self) -> str: str: The current access_token, updated if it was previously expired. """ if self.token_has_expired(): - t0 = pendulum.now() new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token() self.access_token = new_access_token - self.set_token_expiry_date(t0, access_token_expires_in) - self._update_config_with_access_and_refresh_tokens(new_access_token, new_refresh_token) + self.set_token_expiry_date(pendulum.now("UTC"), access_token_expires_in) + self._update_config_with_access_and_refresh_tokens(new_access_token, new_refresh_token, self.get_token_expiry_date()) return self.access_token def refresh_access_token(self) -> Tuple[str, int, str]: diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 093dd205a3cc..fa898b452ad8 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -5,6 +5,7 @@ import json import logging +import freezegun import pendulum import pytest import requests @@ -188,6 +189,7 @@ def connector_config(self): "refresh_token": "my_refresh_token", "client_id": "my_client_id", "client_secret": "my_client_secret", + "access_token_expiration_datetime": "2022-12-31T00:00:00+00:00" } } @@ -209,6 +211,7 @@ def test_init_with_invalid_config(self, invalid_connector_config): token_refresh_endpoint="foobar", ) + @freezegun.freeze_time("2022-12-31") def test_get_access_token(self, capsys, mocker, connector_config): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, @@ -222,7 +225,7 @@ def test_get_access_token(self, capsys, mocker, connector_config): expected_new_config = connector_config.copy() expected_new_config["credentials"]["access_token"] = "new_access_token" expected_new_config["credentials"]["refresh_token"] = "new_refresh_token" - + expected_new_config["credentials"]["access_token_expiration_datetime"] = "2022-12-31T00:00:42+00:00" assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config assert authenticator.access_token == access_token == "new_access_token" assert authenticator.get_refresh_token() == "new_refresh_token"