-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Low Code CDK] configurable oauth request payload (#13993)
* configurable oauth request payload * support interpolation for dictionaries that are not new subcomponents * rewrite a declarative oauth authenticator that performs interpolation at runtime * formatting * whatever i don't know why factory gets flagged w/ the newline change * we java now * remove duplicate oauth * add some comments * parse time properly from string interpolation * move declarative oauth to its own package in declarative module * add changelog info
- Loading branch information
Showing
12 changed files
with
498 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator | ||
|
||
__all__ = [ | ||
"DeclarativeOauth2Authenticator", | ||
] |
137 changes: 137 additions & 0 deletions
137
airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import Any, List, Mapping | ||
|
||
import pendulum | ||
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping | ||
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString | ||
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator | ||
|
||
|
||
class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator): | ||
""" | ||
Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on | ||
a declarative connector configuration file. Credentials can be defined explicitly or via interpolation | ||
at runtime. The generated access token is attached to each request via the Authorization header. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
token_refresh_endpoint: str, | ||
client_id: str, | ||
client_secret: str, | ||
refresh_token: str, | ||
config: Mapping[str, Any], | ||
scopes: List[str] = None, | ||
token_expiry_date: str = None, | ||
access_token_name: str = "access_token", | ||
expires_in_name: str = "expires_in", | ||
refresh_request_body: Mapping[str, Any] = None, | ||
): | ||
self.config = config | ||
self.token_refresh_endpoint = InterpolatedString(token_refresh_endpoint) | ||
self.client_secret = InterpolatedString(client_secret) | ||
self.client_id = InterpolatedString(client_id) | ||
self.refresh_token = InterpolatedString(refresh_token) | ||
self.scopes = scopes | ||
self.access_token_name = InterpolatedString(access_token_name) | ||
self.expires_in_name = InterpolatedString(expires_in_name) | ||
self.refresh_request_body = InterpolatedMapping(refresh_request_body) | ||
|
||
self.token_expiry_date = ( | ||
pendulum.parse(InterpolatedString(token_expiry_date).eval(self.config)) | ||
if token_expiry_date | ||
else pendulum.now().subtract(days=1) | ||
) | ||
self.access_token = None | ||
|
||
@property | ||
def config(self) -> Mapping[str, Any]: | ||
return self._config | ||
|
||
@config.setter | ||
def config(self, value: Mapping[str, Any]): | ||
self._config = value | ||
|
||
@property | ||
def token_refresh_endpoint(self) -> InterpolatedString: | ||
get_some = self._token_refresh_endpoint.eval(self.config) | ||
return get_some | ||
|
||
@token_refresh_endpoint.setter | ||
def token_refresh_endpoint(self, value: InterpolatedString): | ||
self._token_refresh_endpoint = value | ||
|
||
@property | ||
def client_id(self) -> InterpolatedString: | ||
return self._client_id.eval(self.config) | ||
|
||
@client_id.setter | ||
def client_id(self, value: InterpolatedString): | ||
self._client_id = value | ||
|
||
@property | ||
def client_secret(self) -> InterpolatedString: | ||
return self._client_secret.eval(self.config) | ||
|
||
@client_secret.setter | ||
def client_secret(self, value: InterpolatedString): | ||
self._client_secret = value | ||
|
||
@property | ||
def refresh_token(self) -> InterpolatedString: | ||
return self._refresh_token.eval(self.config) | ||
|
||
@refresh_token.setter | ||
def refresh_token(self, value: InterpolatedString): | ||
self._refresh_token = value | ||
|
||
@property | ||
def scopes(self) -> [str]: | ||
return self._scopes | ||
|
||
@scopes.setter | ||
def scopes(self, value: [str]): | ||
self._scopes = value | ||
|
||
@property | ||
def token_expiry_date(self) -> pendulum.DateTime: | ||
return self._token_expiry_date | ||
|
||
@token_expiry_date.setter | ||
def token_expiry_date(self, value: pendulum.DateTime): | ||
self._token_expiry_date = value | ||
|
||
@property | ||
def access_token_name(self) -> InterpolatedString: | ||
return self._access_token_name.eval(self.config) | ||
|
||
@access_token_name.setter | ||
def access_token_name(self, value: InterpolatedString): | ||
self._access_token_name = value | ||
|
||
@property | ||
def expires_in_name(self) -> InterpolatedString: | ||
return self._expires_in_name.eval(self.config) | ||
|
||
@expires_in_name.setter | ||
def expires_in_name(self, value: InterpolatedString): | ||
self._expires_in_name = value | ||
|
||
@property | ||
def refresh_request_body(self) -> InterpolatedMapping: | ||
return self._refresh_request_body.eval(self.config) | ||
|
||
@refresh_request_body.setter | ||
def refresh_request_body(self, value: InterpolatedMapping): | ||
self._refresh_request_body = value | ||
|
||
@property | ||
def access_token(self) -> str: | ||
return self._access_token | ||
|
||
@access_token.setter | ||
def access_token(self, value: str): | ||
self._access_token = value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 1 addition & 2 deletions
3
airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
129 changes: 129 additions & 0 deletions
129
airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# | ||
# Copyright (c) 2022 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from abc import abstractmethod | ||
from typing import Any, Mapping, MutableMapping, Tuple | ||
|
||
import pendulum | ||
import requests | ||
from requests.auth import AuthBase | ||
|
||
|
||
class AbstractOauth2Authenticator(AuthBase): | ||
""" | ||
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator | ||
is designed to generically perform the refresh flow without regard to how config fields are get/set by | ||
delegating that behavior to the classes implementing the interface. | ||
""" | ||
|
||
def __call__(self, request): | ||
request.headers.update(self.get_auth_header()) | ||
return request | ||
|
||
def get_auth_header(self) -> Mapping[str, Any]: | ||
return {"Authorization": f"Bearer {self.get_access_token()}"} | ||
|
||
def get_access_token(self): | ||
if self.token_has_expired(): | ||
t0 = pendulum.now() | ||
token, expires_in = self.refresh_access_token() | ||
self.access_token = token | ||
self.token_expiry_date = t0.add(seconds=expires_in) | ||
|
||
return self.access_token | ||
|
||
def token_has_expired(self) -> bool: | ||
return pendulum.now() > self.token_expiry_date | ||
|
||
def get_refresh_request_body(self) -> Mapping[str, Any]: | ||
"""Override to define additional parameters""" | ||
payload: MutableMapping[str, Any] = { | ||
"grant_type": "refresh_token", | ||
"client_id": self.client_id, | ||
"client_secret": self.client_secret, | ||
"refresh_token": self.refresh_token, | ||
} | ||
|
||
if self.scopes: | ||
payload["scopes"] = self.scopes | ||
|
||
if self.refresh_request_body: | ||
for key, val in self.refresh_request_body.items(): | ||
# We defer to existing oauth constructs over custom configured fields | ||
if key not in payload: | ||
payload[key] = val | ||
|
||
return payload | ||
|
||
def refresh_access_token(self) -> Tuple[str, int]: | ||
""" | ||
returns a tuple of (access_token, token_lifespan_in_seconds) | ||
""" | ||
try: | ||
response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body()) | ||
response.raise_for_status() | ||
response_json = response.json() | ||
return response_json[self.access_token_name], response_json[self.expires_in_name] | ||
except Exception as e: | ||
raise Exception(f"Error while refreshing access token: {e}") from e | ||
|
||
@property | ||
@abstractmethod | ||
def token_refresh_endpoint(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def client_id(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def client_secret(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def refresh_token(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def scopes(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def token_expiry_date(self): | ||
pass | ||
|
||
@token_expiry_date.setter | ||
@abstractmethod | ||
def token_expiry_date(self, value): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def access_token_name(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def expires_in_name(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def refresh_request_body(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def access_token(self): | ||
pass | ||
|
||
@access_token.setter | ||
@abstractmethod | ||
def access_token(self, value): | ||
pass |
Oops, something went wrong.