Skip to content

Commit

Permalink
[Low Code CDK] configurable oauth request payload (#13993)
Browse files Browse the repository at this point in the history
* configurable oauth request payload

* support interpolation for dictionaries that are not new subcomponents

* rewrite a declarative oauth authenticator that performs interpolation at runtime

* formatting

* whatever i don't know why factory gets flagged w/ the newline change

* we java now

* remove duplicate oauth

* add some comments

* parse time properly from string interpolation

* move declarative oauth to its own package in declarative module

* add changelog info
  • Loading branch information
brianjlai authored Jul 8, 2022
1 parent f809270 commit 374e265
Show file tree
Hide file tree
Showing 12 changed files with 498 additions and 70 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.1.64
- Add support for configurable oauth request payload and declarative oauth authenticator.

## 0.1.63
- Define `namespace` property on the `Stream` class inside `core.py`.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator

__all__ = [
"DeclarativeOauth2Authenticator",
]
137 changes: 137 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping

import pendulum
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator


class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator):
"""
Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on
a declarative connector configuration file. Credentials can be defined explicitly or via interpolation
at runtime. The generated access token is attached to each request via the Authorization header.
"""

def __init__(
self,
token_refresh_endpoint: str,
client_id: str,
client_secret: str,
refresh_token: str,
config: Mapping[str, Any],
scopes: List[str] = None,
token_expiry_date: str = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_request_body: Mapping[str, Any] = None,
):
self.config = config
self.token_refresh_endpoint = InterpolatedString(token_refresh_endpoint)
self.client_secret = InterpolatedString(client_secret)
self.client_id = InterpolatedString(client_id)
self.refresh_token = InterpolatedString(refresh_token)
self.scopes = scopes
self.access_token_name = InterpolatedString(access_token_name)
self.expires_in_name = InterpolatedString(expires_in_name)
self.refresh_request_body = InterpolatedMapping(refresh_request_body)

self.token_expiry_date = (
pendulum.parse(InterpolatedString(token_expiry_date).eval(self.config))
if token_expiry_date
else pendulum.now().subtract(days=1)
)
self.access_token = None

@property
def config(self) -> Mapping[str, Any]:
return self._config

@config.setter
def config(self, value: Mapping[str, Any]):
self._config = value

@property
def token_refresh_endpoint(self) -> InterpolatedString:
get_some = self._token_refresh_endpoint.eval(self.config)
return get_some

@token_refresh_endpoint.setter
def token_refresh_endpoint(self, value: InterpolatedString):
self._token_refresh_endpoint = value

@property
def client_id(self) -> InterpolatedString:
return self._client_id.eval(self.config)

@client_id.setter
def client_id(self, value: InterpolatedString):
self._client_id = value

@property
def client_secret(self) -> InterpolatedString:
return self._client_secret.eval(self.config)

@client_secret.setter
def client_secret(self, value: InterpolatedString):
self._client_secret = value

@property
def refresh_token(self) -> InterpolatedString:
return self._refresh_token.eval(self.config)

@refresh_token.setter
def refresh_token(self, value: InterpolatedString):
self._refresh_token = value

@property
def scopes(self) -> [str]:
return self._scopes

@scopes.setter
def scopes(self, value: [str]):
self._scopes = value

@property
def token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date

@token_expiry_date.setter
def token_expiry_date(self, value: pendulum.DateTime):
self._token_expiry_date = value

@property
def access_token_name(self) -> InterpolatedString:
return self._access_token_name.eval(self.config)

@access_token_name.setter
def access_token_name(self, value: InterpolatedString):
self._access_token_name = value

@property
def expires_in_name(self) -> InterpolatedString:
return self._expires_in_name.eval(self.config)

@expires_in_name.setter
def expires_in_name(self, value: InterpolatedString):
self._expires_in_name = value

@property
def refresh_request_body(self) -> InterpolatedMapping:
return self._refresh_request_body.eval(self.config)

@refresh_request_body.setter
def refresh_request_body(self, value: InterpolatedMapping):
self._refresh_request_body = value

@property
def access_token(self) -> str:
return self._access_token

@access_token.setter
def access_token(self, value: str):
self._access_token = value
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def _create_subcomponent(self, key, definition, kwargs, config, parent_class):
if self.is_object_definition_with_class_name(definition):
# propagate kwargs to inner objects
definition["options"] = self._merge_dicts(kwargs.get("options", dict()), definition.get("options", dict()))

return self.create_component(definition, config)()
elif self.is_object_definition_with_type(definition):
# If type is set instead of class_name, get the class_name from the CLASS_TYPES_REGISTRY
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from .oauth import Oauth2Authenticator
from .token import MultipleTokenAuthenticator, TokenAuthenticator

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from abc import abstractmethod
from typing import Any, Mapping, MutableMapping, Tuple

import pendulum
import requests
from requests.auth import AuthBase


class AbstractOauth2Authenticator(AuthBase):
"""
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
is designed to generically perform the refresh flow without regard to how config fields are get/set by
delegating that behavior to the classes implementing the interface.
"""

def __call__(self, request):
request.headers.update(self.get_auth_header())
return request

def get_auth_header(self) -> Mapping[str, Any]:
return {"Authorization": f"Bearer {self.get_access_token()}"}

def get_access_token(self):
if self.token_has_expired():
t0 = pendulum.now()
token, expires_in = self.refresh_access_token()
self.access_token = token
self.token_expiry_date = t0.add(seconds=expires_in)

return self.access_token

def token_has_expired(self) -> bool:
return pendulum.now() > self.token_expiry_date

def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Override to define additional parameters"""
payload: MutableMapping[str, Any] = {
"grant_type": "refresh_token",
"client_id": self.client_id,
"client_secret": self.client_secret,
"refresh_token": self.refresh_token,
}

if self.scopes:
payload["scopes"] = self.scopes

if self.refresh_request_body:
for key, val in self.refresh_request_body.items():
# We defer to existing oauth constructs over custom configured fields
if key not in payload:
payload[key] = val

return payload

def refresh_access_token(self) -> Tuple[str, int]:
"""
returns a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body())
response.raise_for_status()
response_json = response.json()
return response_json[self.access_token_name], response_json[self.expires_in_name]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

@property
@abstractmethod
def token_refresh_endpoint(self):
pass

@property
@abstractmethod
def client_id(self):
pass

@property
@abstractmethod
def client_secret(self):
pass

@property
@abstractmethod
def refresh_token(self):
pass

@property
@abstractmethod
def scopes(self):
pass

@property
@abstractmethod
def token_expiry_date(self):
pass

@token_expiry_date.setter
@abstractmethod
def token_expiry_date(self, value):
pass

@property
@abstractmethod
def access_token_name(self):
pass

@property
@abstractmethod
def expires_in_name(self):
pass

@property
@abstractmethod
def refresh_request_body(self):
pass

@property
@abstractmethod
def access_token(self):
pass

@access_token.setter
@abstractmethod
def access_token(self, value):
pass
Loading

0 comments on commit 374e265

Please sign in to comment.