Skip to content

Commit

Permalink
feat: add reauth feature to user credentials (#727)
Browse files Browse the repository at this point in the history
* feat: add reauth support to oauth2 credentials

* update
  • Loading branch information
arithmetic1728 authored Apr 14, 2021
1 parent e383636 commit 82293fe
Show file tree
Hide file tree
Showing 11 changed files with 1,152 additions and 48 deletions.
9 changes: 9 additions & 0 deletions google/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ class ClientCertError(GoogleAuthError):
class OAuthError(GoogleAuthError):
"""Used to indicate an error occurred during an OAuth related HTTP
request."""


class ReauthFailError(RefreshError):
"""An exception for when reauth failed."""

def __init__(self, message=None):
super(ReauthFailError, self).__init__(
"Reauthentication failed. {0}".format(message)
)
130 changes: 99 additions & 31 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,29 @@
from google.auth import jwt

_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
_JSON_CONTENT_TYPE = "application/json"
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
_REFRESH_GRANT_TYPE = "refresh_token"


def _handle_error_response(response_body):
""""Translates an error response into an exception.
def _handle_error_response(response_data):
"""Translates an error response into an exception.
Args:
response_body (str): The decoded response data.
response_data (Mapping): The decoded response data.
Raises:
google.auth.exceptions.RefreshError
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
try:
error_data = json.loads(response_body)
error_details = "{}: {}".format(
error_data["error"], error_data.get("error_description")
response_data["error"], response_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
error_details = response_body
error_details = json.dumps(response_data)

raise exceptions.RefreshError(error_details, response_body)
raise exceptions.RefreshError(error_details, response_data)


def _parse_expiry(response_data):
Expand All @@ -78,25 +78,35 @@ def _parse_expiry(response_data):
return None


def _token_endpoint_request(request, token_uri, body):
def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
successful, and a mapping for the JSON-decoded response data.
"""
body = urllib.parse.urlencode(body).encode("utf-8")
headers = {"content-type": _URLENCODED_CONTENT_TYPE}
if use_json:
headers = {"Content-Type": _JSON_CONTENT_TYPE}
body = json.dumps(body).encode("utf-8")
else:
headers = {"Content-Type": _URLENCODED_CONTENT_TYPE}
body = urllib.parse.urlencode(body).encode("utf-8")

if access_token:
headers["Authorization"] = "Bearer {}".format(access_token)

retry = 0
# retry to fetch token for maximum of two times if any internal failure
Expand All @@ -121,8 +131,38 @@ def _token_endpoint_request(request, token_uri, body):
):
retry += 1
continue
_handle_error_response(response_body)
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
)
if not response_status_ok:
_handle_error_response(response_data)
return response_data


Expand Down Expand Up @@ -204,8 +244,43 @@ def id_token_jwt_grant(request, token_uri, assertion):
return id_token, expiry, response_data


def _handle_refresh_grant_response(response_data, refresh_token):
"""Extract tokens from refresh grant response.
Args:
response_data (Mapping[str, str]): Refresh grant response data.
refresh_token (str): Current refresh token.
Returns:
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token,
refresh token, expiration, and additional data returned by the token
endpoint. If response_data doesn't have refresh token, then the current
refresh token will be returned.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data


def refresh_grant(
request, token_uri, refresh_token, client_id, client_secret, scopes=None
request,
token_uri,
refresh_token,
client_id,
client_secret,
scopes=None,
rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
Expand All @@ -224,10 +299,11 @@ def refresh_grant(
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The reauth Proof Token.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
access token, new refresh token, expiration, and additional data
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access
token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
Expand All @@ -244,16 +320,8 @@ def refresh_grant(
}
if scopes:
body["scope"] = " ".join(scopes)
if rapt_token:
body["rapt"] = rapt_token

response_data = _token_endpoint_request(request, token_uri, body)

try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data
return _handle_refresh_grant_response(response_data, refresh_token)
157 changes: 157 additions & 0 deletions google/oauth2/challenges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Challenges for reauthentication.
"""

import abc
import base64
import getpass
import sys

import six

from google.auth import _helpers
from google.auth import exceptions


REAUTH_ORIGIN = "https://accounts.google.com"


def get_user_password(text):
"""Get password from user.
Override this function with a different logic if you are using this library
outside a CLI.
Args:
text (str): message for the password prompt.
Returns:
str: password string.
"""
return getpass.getpass(text)


@six.add_metaclass(abc.ABCMeta)
class ReauthChallenge(object):
"""Base class for reauth challenges."""

@property
@abc.abstractmethod
def name(self): # pragma: NO COVER
"""Returns the name of the challenge."""
raise NotImplementedError("name property must be implemented")

@property
@abc.abstractmethod
def is_locally_eligible(self): # pragma: NO COVER
"""Returns true if a challenge is supported locally on this machine."""
raise NotImplementedError("is_locally_eligible property must be implemented")

@abc.abstractmethod
def obtain_challenge_input(self, metadata): # pragma: NO COVER
"""Performs logic required to obtain credentials and returns it.
Args:
metadata (Mapping): challenge metadata returned in the 'challenges' field in
the initial reauth request. Includes the 'challengeType' field
and other challenge-specific fields.
Returns:
response that will be send to the reauth service as the content of
the 'proposalResponse' field in the request body. Usually a dict
with the keys specific to the challenge. For example,
``{'credential': password}`` for password challenge.
"""
raise NotImplementedError("obtain_challenge_input method must be implemented")


class PasswordChallenge(ReauthChallenge):
"""Challenge that asks for user's password."""

@property
def name(self):
return "PASSWORD"

@property
def is_locally_eligible(self):
return True

@_helpers.copy_docstring(ReauthChallenge)
def obtain_challenge_input(self, unused_metadata):
passwd = get_user_password("Please enter your password:")
if not passwd:
passwd = " " # avoid the server crashing in case of no password :D
return {"credential": passwd}


class SecurityKeyChallenge(ReauthChallenge):
"""Challenge that asks for user's security key touch."""

@property
def name(self):
return "SECURITY_KEY"

@property
def is_locally_eligible(self):
return True

@_helpers.copy_docstring(ReauthChallenge)
def obtain_challenge_input(self, metadata):
try:
import pyu2f.convenience.authenticator
import pyu2f.errors
import pyu2f.model
except ImportError:
raise exceptions.ReauthFailError(
"pyu2f dependency is required to use Security key reauth feature. "
"It can be installed via `pip install pyu2f` or `pip install google-auth[reauth]`."
)
sk = metadata["securityKey"]
challenges = sk["challenges"]
app_id = sk["applicationId"]

challenge_data = []
for c in challenges:
kh = c["keyHandle"].encode("ascii")
key = pyu2f.model.RegisteredKey(bytearray(base64.urlsafe_b64decode(kh)))
challenge = c["challenge"].encode("ascii")
challenge = base64.urlsafe_b64decode(challenge)
challenge_data.append({"key": key, "challenge": challenge})

try:
api = pyu2f.convenience.authenticator.CreateCompositeAuthenticator(
REAUTH_ORIGIN
)
response = api.Authenticate(
app_id, challenge_data, print_callback=sys.stderr.write
)
return {"securityKey": response}
except pyu2f.errors.U2FError as e:
if e.code == pyu2f.errors.U2FError.DEVICE_INELIGIBLE:
sys.stderr.write("Ineligible security key.\n")
elif e.code == pyu2f.errors.U2FError.TIMEOUT:
sys.stderr.write("Timed out while waiting for security key touch.\n")
else:
raise e
except pyu2f.errors.NoDeviceFoundError:
sys.stderr.write("No security key found.\n")
return None


AVAILABLE_CHALLENGES = {
challenge.name: challenge
for challenge in [SecurityKeyChallenge(), PasswordChallenge()]
}
Loading

0 comments on commit 82293fe

Please sign in to comment.