Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete unnecessary auth configuration #858

Merged
merged 12 commits into from
Feb 23, 2022
165 changes: 124 additions & 41 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import base64 as _base64
import logging as _logging
import subprocess
import time
from typing import List
from typing import Optional

import requests as _requests
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import auth_pb2
from flyteidl.service import auth_pb2_grpc as auth_service
from google.protobuf.json_format import MessageToJson as _MessageToJson
from grpc import RpcError as _RpcError
from grpc import StatusCode as _GrpcStatusCode
Expand All @@ -11,14 +18,12 @@
from grpc import ssl_channel_credentials as _ssl_channel_credentials

from flytekit.clis.auth import credentials as _credentials_access
from flytekit.clis.sdk_in_container import basic_auth as _basic_auth
from flytekit.configuration import creds as _creds_config
from flytekit.configuration.creds import _DEPRECATED_CLIENT_CREDENTIALS_SCOPE as _DEPRECATED_SCOPE
from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET
from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID
from flytekit.configuration.creds import COMMAND as _COMMAND
from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES, SCOPES
from flytekit.configuration.platform import AUTH as _AUTH
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlyteAuthenticationException
from flytekit.loggers import cli_logger


Expand All @@ -29,31 +34,26 @@ def _refresh_credentials_standard(flyte_client):
:param flyte_client: RawSynchronousFlyteClient
:return:
"""

client = _credentials_access.get_client(flyte_client.url)
if not flyte_client.oauth2_metadata or not flyte_client.public_client_config:
raise ValueError(
"Raw Flyte client attempting client credentials flow but no response from Admin detected. "
"Check your Admin server's .well-known endpoints to make sure they're working as expected."
)
client = _credentials_access.get_client(
redirect_endpoint=flyte_client.public_client_config.redirect_uri,
client_id=flyte_client.public_client_config.client_id,
scopes=flyte_client.public_client_config.scopes,
auth_endpoint=flyte_client.oauth2_metadata.authorization_endpoint,
token_endpoint=flyte_client.oauth2_metadata.token_endpoint,
)
if client.can_refresh_token:
client.refresh_access_token()

flyte_client.set_access_token(client.credentials.access_token)


def _get_basic_flow_scopes() -> List[str]:
"""
Merge the scope value between the old scope config option and the new list option.

:return: The scopes to use for basic auth flow.
"""
deprecated_single_scope = _DEPRECATED_SCOPE.get()
if deprecated_single_scope:
return [deprecated_single_scope]
scopes = DEPRECATED_OAUTH_SCOPES.get() or SCOPES.get()
if "openid" in scopes:
cli_logger.warning("Basic flow authentication should never use openid.")

return scopes
authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
flyte_client.set_access_token(client.credentials.access_token, authorization_header_key)


def _refresh_credentials_basic(flyte_client):
def _refresh_credentials_basic(flyte_client: RawSynchronousFlyteClient):
"""
This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin,
Expand All @@ -63,16 +63,24 @@ def _refresh_credentials_basic(flyte_client):
:param flyte_client: RawSynchronousFlyteClient
:return:
"""
auth_endpoints = _credentials_access.get_authorization_endpoints(flyte_client.url)
token_endpoint = auth_endpoints.token_endpoint
client_secret = _basic_auth.get_secret()
cli_logger.debug(
"Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), _get_basic_flow_scopes())
)
authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret)
token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _get_basic_flow_scopes())
if not flyte_client.oauth2_metadata or not flyte_client.public_client_config:
raise ValueError(
"Raw Flyte client attempting client credentials flow but no response from Admin detected. "
"Check your Admin server's .well-known endpoints to make sure they're working as expected."
)

token_endpoint = flyte_client.oauth2_metadata.token_endpoint
scopes = flyte_client.oauth2_metadata.scopes_supported
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For public client case, we should request scopes listed here instead... the supported scopes can be a huge list...

For client_credentials flow, users should be able to set that just like they set client id and client_secret... if not passed, we default to the public client's scopes here.

scopes = ",".join(scopes)

# Note that unlike the Pkce flow, the client ID does not come from Admin.
client_secret = get_secret()
cli_logger.debug("Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), scopes))
authorization_header = get_basic_authorization_header(_CLIENT_ID.get(), client_secret)
token, expires_in = get_token(token_endpoint, authorization_header, scopes)
cli_logger.info("Retrieved new token, expires in {}".format(expires_in))
flyte_client.set_access_token(token)
authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None
flyte_client.set_access_token(token, authorization_header_key)


def _refresh_credentials_from_command(flyte_client):
Expand Down Expand Up @@ -101,7 +109,7 @@ def _refresh_credentials_noop(flyte_client):
def _get_refresh_handler(auth_mode):
if auth_mode == "standard":
return _refresh_credentials_standard
elif auth_mode == "basic":
elif auth_mode == "basic" or auth_mode == "client_credentials":
return _refresh_credentials_basic
elif auth_mode == "external_process":
return _refresh_credentials_from_command
Expand Down Expand Up @@ -210,22 +218,42 @@ def __init__(self, url, insecure=False, credentials=None, options=None, root_cer
options=list((options or {}).items()),
)
self._stub = _admin_service.AdminServiceStub(self._channel)
self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
try:
resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
self._public_client_config = resp
except _RpcError:
cli_logger.debug("No public client auth config found, skipping.")
self._public_client_config = None
try:
resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest())
self._oauth2_metadata = resp
except _RpcError:
cli_logger.debug("No OAuth2 Metadata found, skipping.")
self._oauth2_metadata = None

# metadata will hold the value of the token to send to the various endpoints.
self._metadata = None
if _AUTH.get():
self.force_auth_flow()

@property
def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]:
return self._public_client_config

@property
def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]:
return self._oauth2_metadata

@property
def url(self) -> str:
return self._url

def set_access_token(self, access_token):
def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"):
# Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses
# to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for.
authorization_metadata_key = _creds_config.AUTHORIZATION_METADATA_KEY.get().lower()
cli_logger.debug(f"Adding authorization header. Header name: {authorization_metadata_key}.")
cli_logger.debug(f"Adding authorization header. Header name: {authorization_header_key}.")
self._metadata = [
(
authorization_metadata_key,
authorization_header_key,
f"Bearer {access_token}",
)
]
Expand Down Expand Up @@ -749,3 +777,58 @@ def list_matchable_attributes(self, matchable_attributes_list_request):

# TODO: (P2) Implement the event endpoints in case there becomes a use-case for third-parties to submit events
# through the client in Python.


def get_token(token_endpoint, authorization_header, scope):
"""
:param Text token_endpoint:
:param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123')
:param Text scope:
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Authorization": authorization_header,
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
body = {
"grant_type": "client_credentials",
}
if scope is not None:
body["scope"] = scope
response = _requests.post(token_endpoint, data=body, headers=headers)
if response.status_code != 200:
_logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
raise FlyteAuthenticationException("Non-200 received from IDP")

response = response.json()
return response["access_token"], response["expires_in"]


def get_secret():
"""
This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION
config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object.
:rtype: Text
"""
secret = _CREDENTIALS_SECRET.get()
if secret:
return secret
raise FlyteAuthenticationException("No secret could be found")


def get_basic_authorization_header(client_id, client_secret):
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text.
:param Text client_id:
:param Text client_secret:
:rtype: Text
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8))


_utf_8 = "utf-8"
2 changes: 0 additions & 2 deletions flytekit/clis/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __init__(
scopes=None,
client_id=None,
redirect_uri=None,
client_secret=None,
):
self._auth_endpoint = auth_endpoint
self._token_endpoint = token_endpoint
Expand All @@ -161,7 +160,6 @@ def __init__(
self._refresh_token = None
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._expired = False
self._client_secret = client_secret

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down
56 changes: 11 additions & 45 deletions flytekit/clis/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,28 @@
import urllib.parse as _urlparse
from typing import List

from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient
from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient
from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CLIENT_SECRET
from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID
from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES
from flytekit.configuration.creds import REDIRECT_URI as _REDIRECT_URI
from flytekit.configuration.creds import SCOPES
from flytekit.configuration.platform import HTTP_URL as _HTTP_URL
from flytekit.configuration.platform import INSECURE as _INSECURE
from flytekit.configuration.platform import URL as _URL
from flytekit.clis.auth.auth import AuthorizationClient
from flytekit.loggers import auth_logger

# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3.
discovery_endpoint_path = "./.well-known/oauth-authorization-server"


def _get_discovery_endpoint(http_config_val, platform_url_val, insecure_val):
if http_config_val:
scheme, netloc, path, _, _, _ = _urlparse.urlparse(http_config_val)
if not scheme:
scheme = "http" if insecure_val else "https"
else: # Use the main _URL config object effectively
scheme = "http" if insecure_val else "https"
netloc = platform_url_val
path = ""

computed_endpoint = _urlparse.urlunparse((scheme, netloc, path, None, None, None))
# The urljoin function needs a trailing slash in order to append things correctly. Also, having an extra slash
# at the end is okay, it just gets stripped out.
computed_endpoint = _urlparse.urljoin(computed_endpoint + "/", discovery_endpoint_path)
auth_logger.debug(f"Using {computed_endpoint} as discovery endpoint")
return computed_endpoint


# Lazy initialized authorization client singleton
_authorization_client = None


def get_client(flyte_client_url):
def get_client(
redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str
) -> AuthorizationClient:
global _authorization_client
if _authorization_client is not None and not _authorization_client.expired:
return _authorization_client
authorization_endpoints = get_authorization_endpoints(flyte_client_url)

_authorization_client = _AuthorizationClient(
redirect_uri=_REDIRECT_URI.get(),
client_id=_CLIENT_ID.get(),
scopes=DEPRECATED_OAUTH_SCOPES.get() or SCOPES.get(),
auth_endpoint=authorization_endpoints.auth_endpoint,
token_endpoint=authorization_endpoints.token_endpoint,
client_secret=_CLIENT_SECRET.get(),
_authorization_client = AuthorizationClient(
redirect_uri=redirect_endpoint,
client_id=client_id,
scopes=scopes,
auth_endpoint=auth_endpoint,
token_endpoint=token_endpoint,
)

auth_logger.debug(f"Created oauth client with redirect {_authorization_client}")
Expand All @@ -59,9 +31,3 @@ def get_client(flyte_client_url):
_authorization_client.start_authorization_flow()

return _authorization_client


def get_authorization_endpoints(flyte_client_url):
discovery_endpoint = _get_discovery_endpoint(_HTTP_URL.get(), flyte_client_url or _URL.get(), _INSECURE.get())
discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint)
return discovery_client.get_authorization_endpoints()
73 changes: 0 additions & 73 deletions flytekit/clis/auth/discovery.py

This file was deleted.

Loading