diff --git a/airbyte-cdk/python/CHANGELOG.md b/airbyte-cdk/python/CHANGELOG.md index b25d00e39b23..29074f4a791d 100644 --- a/airbyte-cdk/python/CHANGELOG.md +++ b/airbyte-cdk/python/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 0.1.22 +Allow passing custom headers to request in `OAuth2Authenticator.refresh_access_token()`: https://github.com/airbytehq/airbyte/pull/6219 + ## 0.1.21 Resolve nested schema references and move external references to single schema definitions. diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py index b76cf962ffb9..8e541c259dfd 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py @@ -39,12 +39,21 @@ class Oauth2Authenticator(HttpAuthenticator): The generated access token is attached to each request via the Authorization header. """ - def __init__(self, token_refresh_endpoint: str, client_id: str, client_secret: str, refresh_token: str, scopes: List[str] = None): + def __init__( + self, + token_refresh_endpoint: str, + client_id: str, + client_secret: str, + refresh_token: str, + scopes: List[str] = None, + refresh_access_token_headers: Mapping[str, Any] = None, + ): self.token_refresh_endpoint = token_refresh_endpoint self.client_secret = client_secret self.client_id = client_id self.refresh_token = refresh_token self.scopes = scopes + self.refresh_access_token_headers = refresh_access_token_headers self._token_expiry_date = pendulum.now().subtract(days=1) self._access_token = None @@ -83,7 +92,12 @@ def refresh_access_token(self) -> Tuple[str, int]: returns a tuple of (access_token, token_lifespan_in_seconds) """ try: - response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body()) + response = requests.request( + method="POST", + url=self.token_refresh_endpoint, + data=self.get_refresh_request_body(), + headers=self.refresh_access_token_headers, + ) response.raise_for_status() response_json = response.json() return response_json["access_token"], response_json["expires_in"] diff --git a/airbyte-cdk/python/setup.py b/airbyte-cdk/python/setup.py index 8be2e3ce70e6..104c67d4b94b 100644 --- a/airbyte-cdk/python/setup.py +++ b/airbyte-cdk/python/setup.py @@ -35,7 +35,7 @@ setup( name="airbyte-cdk", - version="0.1.21", + version="0.1.22", description="A framework for writing Airbyte Connectors.", long_description=README, long_description_content_type="text/markdown", diff --git a/airbyte-cdk/python/unit_tests/destinations/__init__.py b/airbyte-cdk/python/unit_tests/destinations/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/singer/__init__.py b/airbyte-cdk/python/unit_tests/singer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/__init__.py b/airbyte-cdk/python/unit_tests/sources/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/streams/__init__.py b/airbyte-cdk/python/unit_tests/sources/streams/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/__init__.py b/airbyte-cdk/python/unit_tests/sources/streams/http/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/auth/__init__.py b/airbyte-cdk/python/unit_tests/sources/streams/http/auth/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/auth/test_auth.py b/airbyte-cdk/python/unit_tests/sources/streams/http/auth/test_auth.py index 3e561af92acb..b9961346ac86 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/auth/test_auth.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/auth/test_auth.py @@ -25,9 +25,7 @@ import logging -import requests from airbyte_cdk.sources.streams.http.auth import MultipleTokenAuthenticator, NoAuth, Oauth2Authenticator, TokenAuthenticator -from requests import Response LOGGER = logging.getLogger(__name__) @@ -68,10 +66,11 @@ class TestOauth2Authenticator: Test class for OAuth2Authenticator. """ - refresh_endpoint = "refresh_end" + refresh_endpoint = "https://some_url.com/v1" client_id = "client_id" client_secret = "client_secret" refresh_token = "refresh_token" + refresh_access_token_headers = {"Header_1": "value 1", "Header_2": "value 2"} def test_get_auth_header_fresh(self, mocker): """ @@ -130,18 +129,22 @@ def test_refresh_request_body(self): } assert body == expected - def test_refresh_access_token(self, mocker): + def test_refresh_access_token(self, requests_mock): + mock_refresh_token_call = requests_mock.post(TestOauth2Authenticator.refresh_endpoint, + json={"access_token": "token", "expires_in": 10}) + oauth = Oauth2Authenticator( TestOauth2Authenticator.refresh_endpoint, TestOauth2Authenticator.client_id, TestOauth2Authenticator.client_secret, TestOauth2Authenticator.refresh_token, + refresh_access_token_headers=TestOauth2Authenticator.refresh_access_token_headers, ) - resp = Response() - resp.status_code = 200 - mocker.patch.object(requests, "request", return_value=resp) - mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}) token = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert ("token", 10) == token + for header in self.refresh_access_token_headers: + assert header in mock_refresh_token_call.last_request.headers + assert self.refresh_access_token_headers[header] == mock_refresh_token_call.last_request.headers[header] + assert mock_refresh_token_call.called diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/__init__.py b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1