Skip to content

Commit

Permalink
Python OAuth client_credentials support (#563)
Browse files Browse the repository at this point in the history
The current OAuth implementation in the Python client does not reuse access tokens. Instead, it exchanges the client-id and client-secret for a new access token with every request. This behavior increases the load on the tokenEndpoint and introduces unnecessary latency in the request processing.

This PR updates the OAuth implementation to oauth reuse access tokens, aligning the Python client with the OAuth behavior of the Scala client, as detailed in this [PR for the Scala client](#553).

For additional context on the usage of auth_provider and its functionality, please refer to the description in the original PR for the Spark client.
  • Loading branch information
moderakh authored Aug 14, 2024
1 parent 432f4c3 commit 1be4eed
Show file tree
Hide file tree
Showing 7 changed files with 633 additions and 69 deletions.
233 changes: 233 additions & 0 deletions python/delta_sharing/_internal_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
#
# Copyright (C) 2021 The Delta Lake Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
import requests
import base64
import json
import threading
import requests.sessions
import time
from typing import Dict

from delta_sharing.protocol import (
DeltaSharingProfile,
)

# This module contains internal implementation classes.
# These classes are not part of the public API and should not be used directly by users.
# Internal classes may change or be removed at any time without notice.


class AuthConfig:
def __init__(self, token_exchange_max_retries=5,
token_exchange_max_retry_duration_in_seconds=60,
token_renewal_threshold_in_seconds=600):
self.token_exchange_max_retries = token_exchange_max_retries
self.token_exchange_max_retry_duration_in_seconds = (
token_exchange_max_retry_duration_in_seconds)
self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds


class AuthCredentialProvider(ABC):
@abstractmethod
def add_auth_header(self, session: requests.Session) -> None:
pass

def is_expired(self) -> bool:
return False

@abstractmethod
def get_expiration_time(self) -> Optional[str]:
return None


class BearerTokenAuthProvider(AuthCredentialProvider):
def __init__(self, bearer_token: str, expiration_time: Optional[str]):
self.bearer_token = bearer_token
self.expiration_time = expiration_time

def add_auth_header(self, session: requests.Session) -> None:
session.headers.update(
{
"Authorization": f"Bearer {self.bearer_token}",
}
)

def is_expired(self) -> bool:
if self.expiration_time is None:
return False
try:
expiration_time_as_timestamp = datetime.fromisoformat(self.expiration_time)
return expiration_time_as_timestamp < datetime.now()
except ValueError:
return False

def get_expiration_time(self) -> Optional[str]:
return self.expiration_time


class BasicAuthProvider(AuthCredentialProvider):
def __init__(self, endpoint: str, username: str, password: str):
self.username = username
self.password = password
self.endpoint = endpoint

def add_auth_header(self, session: requests.Session) -> None:
session.auth = (self.username, self.password)
session.post(self.endpoint, data={"grant_type": "client_credentials"},)

def is_expired(self) -> bool:
return False

def get_expiration_time(self) -> Optional[str]:
return None


class OAuthClientCredentials:
def __init__(self, access_token: str, expires_in: int, creation_timestamp: int):
self.access_token = access_token
self.expires_in = expires_in
self.creation_timestamp = creation_timestamp


class OAuthClient:
def __init__(self,
token_endpoint: str,
client_id: str,
client_secret: str,
scope: Optional[str] = None):
self.token_endpoint = token_endpoint
self.client_id = client_id
self.client_secret = client_secret
self.scope = scope

def client_credentials(self) -> OAuthClientCredentials:
credentials = base64.b64encode(
f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8')
headers = {
'accept': 'application/json',
'authorization': f'Basic {credentials}',
'content-type': 'application/x-www-form-urlencoded'
}
body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}"
response = requests.post(self.token_endpoint, headers=headers, data=body)
response.raise_for_status()
return self.parse_oauth_token_response(response.text)

def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials:
if not response:
raise RuntimeError("Empty response from OAuth token endpoint")
json_node = json.loads(response)
if 'access_token' not in json_node or not isinstance(json_node['access_token'], str):
raise RuntimeError("Missing 'access_token' field in OAuth token response")
if 'expires_in' not in json_node or not isinstance(json_node['expires_in'], int):
raise RuntimeError("Missing 'expires_in' field in OAuth token response")
return OAuthClientCredentials(
json_node['access_token'],
json_node['expires_in'],
int(datetime.now().timestamp())
)


class OAuthClientCredentialsAuthProvider(AuthCredentialProvider):
def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()):
self.auth_config = auth_config
self.oauth_client = oauth_client
self.current_token: Optional[OAuthClientCredentials] = None
self.lock = threading.RLock()

def add_auth_header(self,session: requests.Session) -> None:
token = self.maybe_refresh_token()
with self.lock:
session.headers.update(
{
"Authorization": f"Bearer {token.access_token}",
}
)

def maybe_refresh_token(self) -> OAuthClientCredentials:
with self.lock:
if self.current_token and not self.needs_refresh(self.current_token):
return self.current_token
new_token = self.oauth_client.client_credentials()
self.current_token = new_token
return new_token

def needs_refresh(self, token: OAuthClientCredentials) -> bool:
now = int(time.time())
expiration_time = token.creation_timestamp + token.expires_in
return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds

def get_expiration_time(self) -> Optional[str]:
return None


class AuthCredentialProviderFactory:
__oauth_auth_provider_cache : Dict[
DeltaSharingProfile,
OAuthClientCredentialsAuthProvider] = {}

@staticmethod
def create_auth_credential_provider(profile: DeltaSharingProfile):
if profile.share_credentials_version == 2:
if profile.type == "oauth_client_credentials":
return AuthCredentialProviderFactory.__oauth_client_credentials(profile)
elif profile.type == "basic":
return AuthCredentialProviderFactory.__auth_basic(profile)
elif (profile.share_credentials_version == 1 and
(profile.type is None or profile.type == "bearer_token")):
return AuthCredentialProviderFactory.__auth_bearer_token(profile)

# any other scenario is unsupported
raise RuntimeError(f"unsupported profile.type: {profile.type}"
f" profile.share_credentials_version"
f" {profile.share_credentials_version}")

@staticmethod
def __oauth_client_credentials(profile):
# Once a clientId/clientSecret is exchanged for an accessToken,
# the accessToken can be reused until it expires.
# The Python client re-creates DeltaSharingClient for different requests.
# To ensure the OAuth access_token is reused,
# we keep a mapping from profile -> OAuthClientCredentialsAuthProvider.
# This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile,
# ensuring the access_token can be reused.
if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache:
return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile]

oauth_client = OAuthClient(
token_endpoint=profile.token_endpoint,
client_id=profile.client_id,
client_secret=profile.client_secret,
scope=profile.scope
)
provider = OAuthClientCredentialsAuthProvider(
oauth_client=oauth_client,
auth_config=AuthConfig()
)
AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider
return provider

@staticmethod
def __auth_bearer_token(profile):
return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time)

@staticmethod
def __auth_basic(profile):
return BasicAuthProvider(profile.endpoint, profile.username, profile.password)
6 changes: 4 additions & 2 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DeltaSharingProfile:
client_secret: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None
scope: Optional[str] = None

def __post_init__(self):
if self.share_credentials_version > DeltaSharingProfile.CURRENT:
Expand Down Expand Up @@ -77,7 +78,7 @@ def from_json(json) -> "DeltaSharingProfile":
)
elif share_credentials_version == 2:
type = json["type"]
if type == "persistent_oauth2.0":
if type == "oauth_client_credentials":
token_endpoint = json["tokenEndpoint"]
if token_endpoint is not None and token_endpoint.endswith("/"):
token_endpoint = token_endpoint[:-1]
Expand All @@ -88,6 +89,7 @@ def from_json(json) -> "DeltaSharingProfile":
token_endpoint=token_endpoint,
client_id=json["clientId"],
client_secret=json["clientSecret"],
scope=json.get("scope"),
)
elif type == "bearer_token":
return DeltaSharingProfile(
Expand All @@ -107,7 +109,7 @@ def from_json(json) -> "DeltaSharingProfile":
)
else:
raise ValueError(
"The current release does not supports {type} type. "
f"The current release does not supports {type} type. "
"Please check type.")
else:
raise ValueError(
Expand Down
68 changes: 11 additions & 57 deletions python/delta_sharing/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time
import logging
import pprint
from datetime import datetime

import requests
from requests.exceptions import HTTPError, ConnectionError
Expand All @@ -39,6 +38,8 @@
Table,
)

from delta_sharing._internal_auth import AuthCredentialProviderFactory


@dataclass(frozen=True)
class ListSharesResponse:
Expand Down Expand Up @@ -151,65 +152,20 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10):
self._profile = profile
self._num_retries = num_retries
self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000)
self.auth_session(profile)

def auth_session(self, profile):
self._session = requests.Session()
self.__auth_broker(profile)
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def __auth_broker(self, profile):
if profile.share_credentials_version == 2:
if profile.type == "persistent_oauth2.0":
self.__auth_persistent_oauth2(profile)
elif profile.type == "bearer_token":
self.__auth_bearer_token(profile)
elif profile.type == "basic":
self.__auth_basic(profile)
else:
self.__auth_bearer_token(profile)
else:
self.__auth_bearer_token(profile)

def __auth_bearer_token(self, profile):
self._session.headers.update(
{
"Authorization": f"Bearer {profile.bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

def __auth_persistent_oauth2(self, profile):
headers = {"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json"}

response = requests.post(profile.token_endpoint,
data={"grant_type": "client_credentials"},
headers=headers,
auth=(profile.client_id,
profile.client_secret),)

bearer_token = "{}".format(response.json()["access_token"])
self.__auth_session(profile)

self._session.headers.update(
{
"Authorization": f"Bearer {bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

def __auth_basic(self, profile):
self._session.auth = (profile.username, profile.password)

response = self._session.post(profile.endpoint,
data={"grant_type": "client_credentials"},)

self._session.headers.update(
{
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)
def __auth_session(self, profile):
self._session = requests.Session()
self._auth_credential_provider = (
AuthCredentialProviderFactory.create_auth_credential_provider(profile))
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def set_sharing_capabilities_header(self):
delta_sharing_capabilities = (
Expand Down Expand Up @@ -502,6 +458,7 @@ def _request_internal(
**kwargs,
):
assert target.startswith("/"), "Targets should start with '/'"
self._auth_credential_provider.add_auth_header(self._session)
response = request(f"{self._profile.endpoint}{target}", **kwargs)
try:
response.raise_for_status()
Expand Down Expand Up @@ -541,10 +498,7 @@ def _should_retry(self, error):
def _error_on_expired_token(self, error):
if isinstance(error, HTTPError) and error.response.status_code == 401:
try:
expiration_time = datetime.strptime(
self._profile.expiration_time, "%Y-%m-%dT%H:%M:%S.%fZ"
)
return datetime.now() > expiration_time
self._auth_credential_provider.is_expired()
except Exception:
return False
else:
Expand Down
Loading

0 comments on commit 1be4eed

Please sign in to comment.