Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDK: SingleUseRefreshTokenOauth2Authenticator update config with access tokens and expiration date #20923

Merged
merged 11 commits into from
Jan 3, 2023
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.16.3
Do not eagerly refresh access token in `SingleUseRefreshTokenOauth2Authenticator` [#20923](https://github.com/airbytehq/airbyte/pull/20923)

## 0.16.2
Fix the naming of OAuthAuthenticator

Expand Down
21 changes: 11 additions & 10 deletions airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,7 @@ def set_config(self, config: ObservedDict) -> None:
self.config = config

def update(self) -> None:
self._emit_airbyte_control_message()

def _emit_airbyte_control_message(self) -> None:
control_message = AirbyteControlMessage(
type=OrchestratorType.CONNECTOR_CONFIG,
emitted_at=time.time() * 1000,
connectorConfig=AirbyteControlConnectorConfigMessage(config=self.config),
)
airbyte_message = AirbyteMessage(type=Type.CONTROL, control=control_message)
print(airbyte_message.json(exclude_unset=True))
emit_configuration_as_airbyte_control_message(self.config)


def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]):
Expand All @@ -74,3 +65,13 @@ def observe_connector_config(non_observed_connector_config: MutableMapping[str,
observed_connector_config = ObservedDict(non_observed_connector_config, connector_config_observer)
connector_config_observer.set_config(observed_connector_config)
return observed_connector_config


def emit_configuration_as_airbyte_control_message(config: MutableMapping):
control_message = AirbyteControlMessage(
type=OrchestratorType.CONNECTOR_CONFIG,
emitted_at=time.time() * 1000,
connectorConfig=AirbyteControlConnectorConfigMessage(config=config),
)
airbyte_message = AirbyteMessage(type=Type.CONTROL, control=control_message)
print(airbyte_message.json(exclude_unset=True))
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dpath
import pendulum
from airbyte_cdk.config_observation import observe_connector_config
from airbyte_cdk.config_observation import emit_configuration_as_airbyte_control_message
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator


Expand Down Expand Up @@ -104,51 +104,53 @@ def __init__(
connector_config: Mapping[str, Any],
token_refresh_endpoint: str,
scopes: List[str] = None,
token_expiry_date: pendulum.DateTime = None,
token_expiry_date_format: str = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
client_id_config_path: Sequence[str] = ("credentials", "client_id"),
client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
):
"""

Args:
connector_config (Mapping[str, Any]): The full connector configuration
token_refresh_endpoint (str): Full URL to the token refresh endpoint
scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None.
token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None.
access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token".
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id").
client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token").
refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
token_expiry_date_config_path (Sequence[str]): Dpath to the token_expiry_date field in the connector configuration. Defaults to ("credentials", "token_expiry_date").
"""
self._client_id_config_path = client_id_config_path
self._client_secret_config_path = client_secret_config_path
self._access_token_config_path = access_token_config_path
self._refresh_token_config_path = refresh_token_config_path
self._token_expiry_date_config_path = token_expiry_date_config_path
self._refresh_token_name = refresh_token_name
self._connector_config = observe_connector_config(connector_config)
self._connector_config = connector_config
self._validate_connector_config()
super().__init__(
token_refresh_endpoint,
self.get_client_id(),
self.get_client_secret(),
self.get_refresh_token(),
scopes,
token_expiry_date,
token_expiry_date_format,
access_token_name,
expires_in_name,
refresh_request_body,
grant_type,
scopes=scopes,
token_expiry_date=self.get_token_expiry_date(),
access_token_name=access_token_name,
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
grant_type=grant_type,
)

def _validate_connector_config(self):
Expand All @@ -157,10 +159,17 @@ def _validate_connector_config(self):
Raises:
ValueError: Raised if the defined getters are not returning a value.
"""
try:
assert self.access_token
except KeyError:
raise ValueError(
f"This authenticator expects a value under the {self._access_token_config_path} field path. Please check your configuration structure or change the access_token_config_path value at initialization of this authenticator."
)
for field_path, getter, parameter_name in [
(self._client_id_config_path, self.get_client_id, "client_id_config_path"),
(self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
(self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
(self._token_expiry_date_config_path, self.get_token_expiry_date, "token_expiry_date_config_path"),
]:
try:
assert getter()
Expand All @@ -178,29 +187,47 @@ def get_client_id(self) -> str:
def get_client_secret(self) -> str:
return dpath.util.get(self._connector_config, self._client_secret_config_path)

@property
def access_token(self) -> str:
return dpath.util.get(self._connector_config, self._access_token_config_path)

@access_token.setter
def access_token(self, new_access_token: str):
dpath.util.set(self._connector_config, self._access_token_config_path, new_access_token)

def get_refresh_token(self) -> str:
return dpath.util.get(self._connector_config, self._refresh_token_config_path)

def set_refresh_token(self, new_refresh_token: str):
"""Set the new refresh token value. The mutation of the connector_config object will emit an Airbyte control message.

Args:
new_refresh_token (str): The new refresh token value.
"""
dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)

def get_token_expiry_date(self) -> pendulum.DateTime:
return pendulum.parse(dpath.util.get(self._connector_config, self._token_expiry_date_config_path))

def set_token_expiry_date(self, new_token_expiry_date):
dpath.util.set(self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date))

def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return pendulum.now("UTC") > self.get_token_expiry_date()

@staticmethod
def get_new_token_expiry_date(access_token_expires_in: int):
return pendulum.now("UTC").add(seconds=access_token_expires_in)

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
The new refresh token is persisted with the set_refresh_token function
Returns:
str: The current access_token, updated if it was previously expired.
"""
if self.token_has_expired():
t0 = pendulum.now()
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
new_token_expiry_date = self.get_new_token_expiry_date(access_token_expires_in)
self.access_token = new_access_token
self.set_token_expiry_date(t0, access_token_expires_in)
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
emit_configuration_as_airbyte_control_message(self._connector_config)
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.16.2",
version="0.16.3",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import json
import logging

import freezegun
import pendulum
import pytest
import requests
from airbyte_cdk.config_observation import ObservedDict
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Expand Down Expand Up @@ -188,6 +188,7 @@ def connector_config(self):
"refresh_token": "my_refresh_token",
"client_id": "my_client_id",
"client_secret": "my_client_secret",
"token_expiry_date": "2022-12-31T00:00:00+00:00"
}
}

Expand All @@ -200,7 +201,9 @@ def test_init(self, connector_config):
connector_config,
token_refresh_endpoint="foobar",
)
assert isinstance(authenticator._connector_config, ObservedDict)
assert authenticator.access_token == connector_config["credentials"]["access_token"]
assert authenticator.get_refresh_token() == connector_config["credentials"]["refresh_token"]
assert authenticator.get_token_expiry_date() == pendulum.parse(connector_config["credentials"]["token_expiry_date"])

def test_init_with_invalid_config(self, invalid_connector_config):
with pytest.raises(ValueError):
Expand All @@ -209,6 +212,7 @@ def test_init_with_invalid_config(self, invalid_connector_config):
token_refresh_endpoint="foobar",
)

@freezegun.freeze_time("2022-12-31")
def test_get_access_token(self, capsys, mocker, connector_config):
authenticator = SingleUseRefreshTokenOauth2Authenticator(
connector_config,
Expand All @@ -220,7 +224,9 @@ def test_get_access_token(self, capsys, mocker, connector_config):
captured = capsys.readouterr()
airbyte_message = json.loads(captured.out)
expected_new_config = connector_config.copy()
expected_new_config["credentials"]["access_token"] = "new_access_token"
expected_new_config["credentials"]["refresh_token"] = "new_refresh_token"
expected_new_config["credentials"]["token_expiry_date"] = "2022-12-31T00:00:42+00:00"
assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
assert authenticator.access_token == access_token == "new_access_token"
assert authenticator.get_refresh_token() == "new_refresh_token"
Expand Down