Skip to content

Commit

Permalink
🎉 CDK: Add support for custom headers passing to the request in `OAut…
Browse files Browse the repository at this point in the history
…h2Authenticator. refresh_access_token` (#6219)

* Add support for headers to OAuth2Authenticator

Send custom headers in `refresh_access_token()`.

* Bump version + update CHANGELOG.md

* Add tests

* Update tests for refresh_access_token()

* Assert that mock_refresh_token_call was called

* Remove init file
  • Loading branch information
Zirochkaa authored Sep 22, 2021
1 parent a6c48dc commit 4dca327
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 12 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.1.22
Allow passing custom headers to request in `OAuth2Authenticator.refresh_access_token()`: https://github.com/airbytehq/airbyte/pull/6219

## 0.1.21
Resolve nested schema references and move external references to single schema definitions.

Expand Down
18 changes: 16 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ class Oauth2Authenticator(HttpAuthenticator):
The generated access token is attached to each request via the Authorization header.
"""

def __init__(self, token_refresh_endpoint: str, client_id: str, client_secret: str, refresh_token: str, scopes: List[str] = None):
def __init__(
self,
token_refresh_endpoint: str,
client_id: str,
client_secret: str,
refresh_token: str,
scopes: List[str] = None,
refresh_access_token_headers: Mapping[str, Any] = None,
):
self.token_refresh_endpoint = token_refresh_endpoint
self.client_secret = client_secret
self.client_id = client_id
self.refresh_token = refresh_token
self.scopes = scopes
self.refresh_access_token_headers = refresh_access_token_headers

self._token_expiry_date = pendulum.now().subtract(days=1)
self._access_token = None
Expand Down Expand Up @@ -83,7 +92,12 @@ def refresh_access_token(self) -> Tuple[str, int]:
returns a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body())
response = requests.request(
method="POST",
url=self.token_refresh_endpoint,
data=self.get_refresh_request_body(),
headers=self.refresh_access_token_headers,
)
response.raise_for_status()
response_json = response.json()
return response_json["access_token"], response_json["expires_in"]
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

setup(
name="airbyte-cdk",
version="0.1.21",
version="0.1.22",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

import logging

import requests
from airbyte_cdk.sources.streams.http.auth import MultipleTokenAuthenticator, NoAuth, Oauth2Authenticator, TokenAuthenticator
from requests import Response

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,10 +66,11 @@ class TestOauth2Authenticator:
Test class for OAuth2Authenticator.
"""

refresh_endpoint = "refresh_end"
refresh_endpoint = "https://some_url.com/v1"
client_id = "client_id"
client_secret = "client_secret"
refresh_token = "refresh_token"
refresh_access_token_headers = {"Header_1": "value 1", "Header_2": "value 2"}

def test_get_auth_header_fresh(self, mocker):
"""
Expand Down Expand Up @@ -130,18 +129,22 @@ def test_refresh_request_body(self):
}
assert body == expected

def test_refresh_access_token(self, mocker):
def test_refresh_access_token(self, requests_mock):
mock_refresh_token_call = requests_mock.post(TestOauth2Authenticator.refresh_endpoint,
json={"access_token": "token", "expires_in": 10})

oauth = Oauth2Authenticator(
TestOauth2Authenticator.refresh_endpoint,
TestOauth2Authenticator.client_id,
TestOauth2Authenticator.client_secret,
TestOauth2Authenticator.refresh_token,
refresh_access_token_headers=TestOauth2Authenticator.refresh_access_token_headers,
)
resp = Response()
resp.status_code = 200

mocker.patch.object(requests, "request", return_value=resp)
mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000})
token = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert ("token", 10) == token
for header in self.refresh_access_token_headers:
assert header in mock_refresh_token_call.last_request.headers
assert self.refresh_access_token_headers[header] == mock_refresh_token_call.last_request.headers[header]
assert mock_refresh_token_call.called
Empty file.

0 comments on commit 4dca327

Please sign in to comment.