Skip to content

Commit

Permalink
CDK: SingleUseRefreshTokenOauth2Authenticator update config with ac…
Browse files Browse the repository at this point in the history
…cess tokens and expiration date (#20923)
  • Loading branch information
alafanechere authored Jan 3, 2023
1 parent cae6396 commit 254714b
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 32 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.16.3
Do not eagerly refresh access token in `SingleUseRefreshTokenOauth2Authenticator` [#20923](https://github.com/airbytehq/airbyte/pull/20923)

## 0.16.2
Fix the naming of OAuthAuthenticator

Expand Down
21 changes: 11 additions & 10 deletions airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,7 @@ def set_config(self, config: ObservedDict) -> None:
self.config = config

def update(self) -> None:
self._emit_airbyte_control_message()

def _emit_airbyte_control_message(self) -> None:
control_message = AirbyteControlMessage(
type=OrchestratorType.CONNECTOR_CONFIG,
emitted_at=time.time() * 1000,
connectorConfig=AirbyteControlConnectorConfigMessage(config=self.config),
)
airbyte_message = AirbyteMessage(type=Type.CONTROL, control=control_message)
print(airbyte_message.json(exclude_unset=True))
emit_configuration_as_airbyte_control_message(self.config)


def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]):
Expand All @@ -74,3 +65,13 @@ def observe_connector_config(non_observed_connector_config: MutableMapping[str,
observed_connector_config = ObservedDict(non_observed_connector_config, connector_config_observer)
connector_config_observer.set_config(observed_connector_config)
return observed_connector_config


def emit_configuration_as_airbyte_control_message(config: MutableMapping):
control_message = AirbyteControlMessage(
type=OrchestratorType.CONNECTOR_CONFIG,
emitted_at=time.time() * 1000,
connectorConfig=AirbyteControlConnectorConfigMessage(config=config),
)
airbyte_message = AirbyteMessage(type=Type.CONTROL, control=control_message)
print(airbyte_message.json(exclude_unset=True))
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dpath
import pendulum
from airbyte_cdk.config_observation import observe_connector_config
from airbyte_cdk.config_observation import emit_configuration_as_airbyte_control_message
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator


Expand Down Expand Up @@ -104,51 +104,53 @@ def __init__(
connector_config: Mapping[str, Any],
token_refresh_endpoint: str,
scopes: List[str] = None,
token_expiry_date: pendulum.DateTime = None,
token_expiry_date_format: str = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
client_id_config_path: Sequence[str] = ("credentials", "client_id"),
client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
):
"""
Args:
connector_config (Mapping[str, Any]): The full connector configuration
token_refresh_endpoint (str): Full URL to the token refresh endpoint
scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None.
token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None.
access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token".
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id").
client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token").
refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
token_expiry_date_config_path (Sequence[str]): Dpath to the token_expiry_date field in the connector configuration. Defaults to ("credentials", "token_expiry_date").
"""
self._client_id_config_path = client_id_config_path
self._client_secret_config_path = client_secret_config_path
self._access_token_config_path = access_token_config_path
self._refresh_token_config_path = refresh_token_config_path
self._token_expiry_date_config_path = token_expiry_date_config_path
self._refresh_token_name = refresh_token_name
self._connector_config = observe_connector_config(connector_config)
self._connector_config = connector_config
self._validate_connector_config()
super().__init__(
token_refresh_endpoint,
self.get_client_id(),
self.get_client_secret(),
self.get_refresh_token(),
scopes,
token_expiry_date,
token_expiry_date_format,
access_token_name,
expires_in_name,
refresh_request_body,
grant_type,
scopes=scopes,
token_expiry_date=self.get_token_expiry_date(),
access_token_name=access_token_name,
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
grant_type=grant_type,
)

def _validate_connector_config(self):
Expand All @@ -157,10 +159,17 @@ def _validate_connector_config(self):
Raises:
ValueError: Raised if the defined getters are not returning a value.
"""
try:
assert self.access_token
except KeyError:
raise ValueError(
f"This authenticator expects a value under the {self._access_token_config_path} field path. Please check your configuration structure or change the access_token_config_path value at initialization of this authenticator."
)
for field_path, getter, parameter_name in [
(self._client_id_config_path, self.get_client_id, "client_id_config_path"),
(self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
(self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
(self._token_expiry_date_config_path, self.get_token_expiry_date, "token_expiry_date_config_path"),
]:
try:
assert getter()
Expand All @@ -178,29 +187,47 @@ def get_client_id(self) -> str:
def get_client_secret(self) -> str:
return dpath.util.get(self._connector_config, self._client_secret_config_path)

@property
def access_token(self) -> str:
return dpath.util.get(self._connector_config, self._access_token_config_path)

@access_token.setter
def access_token(self, new_access_token: str):
dpath.util.set(self._connector_config, self._access_token_config_path, new_access_token)

def get_refresh_token(self) -> str:
return dpath.util.get(self._connector_config, self._refresh_token_config_path)

def set_refresh_token(self, new_refresh_token: str):
"""Set the new refresh token value. The mutation of the connector_config object will emit an Airbyte control message.
Args:
new_refresh_token (str): The new refresh token value.
"""
dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)

def get_token_expiry_date(self) -> pendulum.DateTime:
return pendulum.parse(dpath.util.get(self._connector_config, self._token_expiry_date_config_path))

def set_token_expiry_date(self, new_token_expiry_date):
dpath.util.set(self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date))

def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return pendulum.now("UTC") > self.get_token_expiry_date()

@staticmethod
def get_new_token_expiry_date(access_token_expires_in: int):
return pendulum.now("UTC").add(seconds=access_token_expires_in)

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
The new refresh token is persisted with the set_refresh_token function
Returns:
str: The current access_token, updated if it was previously expired.
"""
if self.token_has_expired():
t0 = pendulum.now()
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
new_token_expiry_date = self.get_new_token_expiry_date(access_token_expires_in)
self.access_token = new_access_token
self.set_token_expiry_date(t0, access_token_expires_in)
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
emit_configuration_as_airbyte_control_message(self._connector_config)
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.16.2",
version="0.16.3",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import json
import logging

import freezegun
import pendulum
import pytest
import requests
from airbyte_cdk.config_observation import ObservedDict
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Expand Down Expand Up @@ -188,6 +188,7 @@ def connector_config(self):
"refresh_token": "my_refresh_token",
"client_id": "my_client_id",
"client_secret": "my_client_secret",
"token_expiry_date": "2022-12-31T00:00:00+00:00"
}
}

Expand All @@ -200,7 +201,9 @@ def test_init(self, connector_config):
connector_config,
token_refresh_endpoint="foobar",
)
assert isinstance(authenticator._connector_config, ObservedDict)
assert authenticator.access_token == connector_config["credentials"]["access_token"]
assert authenticator.get_refresh_token() == connector_config["credentials"]["refresh_token"]
assert authenticator.get_token_expiry_date() == pendulum.parse(connector_config["credentials"]["token_expiry_date"])

def test_init_with_invalid_config(self, invalid_connector_config):
with pytest.raises(ValueError):
Expand All @@ -209,6 +212,7 @@ def test_init_with_invalid_config(self, invalid_connector_config):
token_refresh_endpoint="foobar",
)

@freezegun.freeze_time("2022-12-31")
def test_get_access_token(self, capsys, mocker, connector_config):
authenticator = SingleUseRefreshTokenOauth2Authenticator(
connector_config,
Expand All @@ -220,7 +224,9 @@ def test_get_access_token(self, capsys, mocker, connector_config):
captured = capsys.readouterr()
airbyte_message = json.loads(captured.out)
expected_new_config = connector_config.copy()
expected_new_config["credentials"]["access_token"] = "new_access_token"
expected_new_config["credentials"]["refresh_token"] = "new_refresh_token"
expected_new_config["credentials"]["token_expiry_date"] = "2022-12-31T00:00:42+00:00"
assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
assert authenticator.access_token == access_token == "new_access_token"
assert authenticator.get_refresh_token() == "new_refresh_token"
Expand Down

0 comments on commit 254714b

Please sign in to comment.