diff --git a/google_auth_oauthlib/flow.py b/google_auth_oauthlib/flow.py index d7e1c08..d4336d7 100644 --- a/google_auth_oauthlib/flow.py +++ b/google_auth_oauthlib/flow.py @@ -51,9 +51,15 @@ .. _OAuth 2.0 Authorization Flow: https://tools.ietf.org/html/rfc6749#section-1.2 """ - +from base64 import urlsafe_b64encode +import hashlib import json import logging +try: + from secrets import SystemRandom +except ImportError: # pragma: NO COVER + from random import SystemRandom +from string import ascii_letters, digits import webbrowser import wsgiref.simple_server import wsgiref.util @@ -89,7 +95,7 @@ class Flow(object): def __init__( self, oauth2session, client_type, client_config, - redirect_uri=None): + redirect_uri=None, code_verifier=None): """ Args: oauth2session (requests_oauthlib.OAuth2Session): @@ -101,6 +107,8 @@ def __init__( redirect_uri (str): The OAuth 2.0 redirect URI if known at flow creation time. Otherwise, it will need to be set using :attr:`redirect_uri`. + code_verifier (str): random string of 43-128 chars used to verify + the key exchange.using PKCE. Auto-generated if not provided. .. _client secrets: https://developers.google.com/api-client-library/python/guide @@ -113,6 +121,7 @@ def __init__( self.oauth2session = oauth2session """requests_oauthlib.OAuth2Session: The OAuth 2.0 session.""" self.redirect_uri = redirect_uri + self.code_verifier = code_verifier @classmethod def from_client_config(cls, client_config, scopes, **kwargs): @@ -208,6 +217,18 @@ def authorization_url(self, **kwargs): specify the ``state`` when constructing the :class:`Flow`. """ kwargs.setdefault('access_type', 'offline') + if not self.code_verifier: + chars = ascii_letters+digits+'-._~' + rnd = SystemRandom() + random_verifier = [rnd.choice(chars) for _ in range(0, 128)] + self.code_verifier = ''.join(random_verifier) + code_hash = hashlib.sha256() + code_hash.update(str.encode(self.code_verifier)) + unencoded_challenge = code_hash.digest() + b64_challenge = urlsafe_b64encode(unencoded_challenge) + code_challenge = b64_challenge.decode().split('=')[0] + kwargs.setdefault('code_challenge', code_challenge) + kwargs.setdefault('code_challenge_method', 'S256') url, state = self.oauth2session.authorization_url( self.client_config['auth_uri'], **kwargs) @@ -237,6 +258,7 @@ def fetch_token(self, **kwargs): :class:`~google.auth.credentials.Credentials` instance. """ kwargs.setdefault('client_secret', self.client_config['client_secret']) + kwargs.setdefault('code_verifier', self.code_verifier) return self.oauth2session.fetch_token( self.client_config['token_uri'], **kwargs) diff --git a/tests/test_flow.py b/tests/test_flow.py index 3379140..c8a2390 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -17,6 +17,7 @@ from functools import partial import json import os +import re import mock import pytest @@ -87,6 +88,7 @@ def test_redirect_uri(self, instance): def test_authorization_url(self, instance): scope = 'scope_one' instance.oauth2session.scope = [scope] + instance.code_verifier = 'amanaplanacanalpanama' authorization_url_patch = mock.patch.object( instance.oauth2session, 'authorization_url', wraps=instance.oauth2session.authorization_url) @@ -99,11 +101,14 @@ def test_authorization_url(self, instance): authorization_url_spy.assert_called_with( CLIENT_SECRETS_INFO['web']['auth_uri'], access_type='offline', - prompt='consent') + prompt='consent', + code_challenge='2yN0TOdl0gkGwFOmtfx3f913tgEaLM2d2S0WlmG1Z6Q', + code_challenge_method='S256') def test_authorization_url_access_type(self, instance): scope = 'scope_one' instance.oauth2session.scope = [scope] + instance.code_verifier = 'amanaplanacanalpanama' authorization_url_patch = mock.patch.object( instance.oauth2session, 'authorization_url', wraps=instance.oauth2session.authorization_url) @@ -115,9 +120,31 @@ def test_authorization_url_access_type(self, instance): assert scope in url authorization_url_spy.assert_called_with( CLIENT_SECRETS_INFO['web']['auth_uri'], - access_type='meep') + access_type='meep', + code_challenge='2yN0TOdl0gkGwFOmtfx3f913tgEaLM2d2S0WlmG1Z6Q', + code_challenge_method='S256') + + def test_authorization_url_generated_verifier(self, instance): + scope = 'scope_one' + instance.oauth2session.scope = [scope] + authorization_url_path = mock.patch.object( + instance.oauth2session, 'authorization_url', + wraps=instance.oauth2session.authorization_url) + + with authorization_url_path as authorization_url_spy: + instance.authorization_url() + + _, kwargs = authorization_url_spy.call_args_list[0] + assert kwargs['code_challenge_method'] == 'S256' + assert len(instance.code_verifier) == 128 + assert len(kwargs['code_challenge']) == 43 + valid_verifier = r'^[A-Za-z0-9-._~]*$' + valid_challenge = r'^[A-Za-z0-9-_]*$' + assert re.match(valid_verifier, instance.code_verifier) + assert re.match(valid_challenge, kwargs['code_challenge']) def test_fetch_token(self, instance): + instance.code_verifier = 'amanaplanacanalpanama' fetch_token_patch = mock.patch.object( instance.oauth2session, 'fetch_token', autospec=True, return_value=mock.sentinel.token) @@ -129,7 +156,8 @@ def test_fetch_token(self, instance): fetch_token_mock.assert_called_with( CLIENT_SECRETS_INFO['web']['token_uri'], client_secret=CLIENT_SECRETS_INFO['web']['client_secret'], - code=mock.sentinel.code) + code=mock.sentinel.code, + code_verifier='amanaplanacanalpanama') def test_credentials(self, instance): instance.oauth2session.token = { @@ -194,7 +222,7 @@ def set_token(*args, **kwargs): @mock.patch('google_auth_oauthlib.flow.input', autospec=True) def test_run_console(self, input_mock, instance, mock_fetch_token): input_mock.return_value = mock.sentinel.code - + instance.code_verifier = 'amanaplanacanalpanama' credentials = instance.run_console() assert credentials.token == mock.sentinel.access_token @@ -204,7 +232,8 @@ def test_run_console(self, input_mock, instance, mock_fetch_token): mock_fetch_token.assert_called_with( CLIENT_SECRETS_INFO['web']['token_uri'], client_secret=CLIENT_SECRETS_INFO['web']['client_secret'], - code=mock.sentinel.code) + code=mock.sentinel.code, + code_verifier='amanaplanacanalpanama') @pytest.mark.webtest @mock.patch('google_auth_oauthlib.flow.webbrowser', autospec=True) @@ -213,6 +242,7 @@ def test_run_local_server( auth_redirect_url = urllib.parse.urljoin( 'http://localhost:60452', self.REDIRECT_REQUEST_PATH) + instance.code_verifier = 'amanaplanacanalpanama' with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: future = pool.submit(partial( @@ -235,7 +265,8 @@ def test_run_local_server( mock_fetch_token.assert_called_with( CLIENT_SECRETS_INFO['web']['token_uri'], client_secret=CLIENT_SECRETS_INFO['web']['client_secret'], - authorization_response=expected_auth_response) + authorization_response=expected_auth_response, + code_verifier='amanaplanacanalpanama') @mock.patch('google_auth_oauthlib.flow.webbrowser', autospec=True) @mock.patch('wsgiref.simple_server.make_server', autospec=True)