Skip to content

Commit

Permalink
fix: auth_session uses transport_options (#550)
Browse files Browse the repository at this point in the history
Now auth_session login/logout/login_user will pass along any original
transport_options from the sdk call.
  • Loading branch information
joeldodge79 authored Apr 2, 2021
1 parent e40e8b3 commit 94d6047
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 47 deletions.
97 changes: 66 additions & 31 deletions python/looker_sdk/rtl/auth_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@


class AuthSession:
"""AuthSession to provide automatic authentication
"""
"""AuthSession to provide automatic authentication"""

def __init__(
self,
Expand Down Expand Up @@ -72,31 +71,41 @@ def is_sudo_authenticated(self) -> bool:
def is_authenticated(self) -> bool:
return self._is_authenticated(self.token)

def _get_sudo_token(self) -> auth_token.AuthToken:
def _get_sudo_token(
self, transport_options: transport.TransportOptions
) -> auth_token.AuthToken:
"""Returns an active sudo token."""
if not self.is_sudo_authenticated:
self._login_sudo()
self._login_sudo(transport_options)
return self.sudo_token

def _get_token(self) -> auth_token.AuthToken:
def _get_token(
self, transport_options: transport.TransportOptions
) -> auth_token.AuthToken:
"""Returns an active token."""
if not self.is_authenticated:
self._login()
self._login(transport_options)
return self.token

def authenticate(self) -> Dict[str, str]:
def authenticate(
self, transport_options: transport.TransportOptions
) -> Dict[str, str]:
"""Return the Authorization header to authenticate each API call.
Expired token renewal happens automatically.
"""
if self._sudo_id:
token = self._get_sudo_token()
token = self._get_sudo_token(transport_options)
else:
token = self._get_token()
token = self._get_token(transport_options)

return {"Authorization": f"Bearer {token.access_token}"}

def login_user(self, sudo_id: int) -> None:
def login_user(
self,
sudo_id: int,
transport_options: Optional[transport.TransportOptions] = None,
) -> None:
"""Authenticate using settings credentials and sudo as sudo_id.
Make API calls as if authenticated as sudo_id. The sudo_id
Expand All @@ -106,7 +115,7 @@ def login_user(self, sudo_id: int) -> None:
if self._sudo_id is None:
self._sudo_id = sudo_id
try:
self._login_sudo()
self._login_sudo(transport_options or {})
except error.SDKError:
self._sudo_id = None
raise
Expand All @@ -118,9 +127,9 @@ def login_user(self, sudo_id: int) -> None:
"is already logged in. Log them out first."
)
elif not self.is_sudo_authenticated:
self._login_sudo()
self._login_sudo(transport_options or {})

def _login(self) -> None:
def _login(self, transport_options: transport.TransportOptions) -> None:
client_id = self.settings.read_config().get("client_id")
client_secret = self.settings.read_config().get("client_secret")
if not (client_id and client_secret):
Expand All @@ -133,14 +142,15 @@ def _login(self) -> None:
}
).encode("utf-8")

transport_options.setdefault("headers", {}).update(
{"Content-Type": "application/x-www-form-urlencoded"}
)
response = self._ok(
self.transport.request(
transport.HttpMethod.POST,
f"{self.settings.base_url}/api/{self.api_version}/login",
body=serialized,
transport_options={
"headers": {"Content-Type": "application/x-www-form-urlencoded"}
},
transport_options=transport_options,
)
)

Expand All @@ -151,14 +161,20 @@ def _login(self) -> None:
assert isinstance(access_token, auth_token.AccessToken)
self.token = auth_token.AuthToken(access_token)

def _login_sudo(self) -> None:
def _login_sudo(self, transport_options: transport.TransportOptions) -> None:
def authenticator(
transport_options: transport.TransportOptions,
) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self._get_token(transport_options).access_token}"
}

response = self._ok(
self.transport.request(
transport.HttpMethod.POST,
f"{self.settings.base_url}/api/{self.api_version}/login/{self._sudo_id}",
authenticator=lambda: {
"Authorization": f"Bearer {self._get_token().access_token}"
},
authenticator=authenticator,
transport_options=transport_options,
)
)
# ignore type: mypy bug doesn't recognized kwarg `structure` to partial func
Expand All @@ -168,7 +184,11 @@ def _login_sudo(self) -> None:
assert isinstance(access_token, auth_token.AccessToken)
self.sudo_token = auth_token.AuthToken(access_token)

def logout(self, full: bool = False) -> None:
def logout(
self,
full: bool = False,
transport_options: Optional[transport.TransportOptions] = None,
) -> None:
"""Logout of API.
If the session is authenticated as sudo_id, logout() "undoes"
Expand All @@ -181,14 +201,18 @@ def logout(self, full: bool = False) -> None:
if self._sudo_id:
self._sudo_id = None
if self.is_sudo_authenticated:
self._logout(sudo=True)
self._logout(sudo=True, transport_options=transport_options)
if full:
self._logout()
self._logout(transport_options=transport_options)

elif self.is_authenticated:
self._logout()
self._logout(transport_options=transport_options)

def _logout(self, sudo: bool = False) -> None:
def _logout(
self,
sudo: bool = False,
transport_options: Optional[transport.TransportOptions] = None,
) -> None:

if sudo:
token = self.sudo_token.access_token
Expand All @@ -197,11 +221,17 @@ def _logout(self, sudo: bool = False) -> None:
token = self.token.access_token
self.token = auth_token.AuthToken()

def authenticator(
_transport_options: transport.TransportOptions,
) -> Dict[str, str]:
return {"Authorization": f"Bearer {token}"}

self._ok(
self.transport.request(
transport.HttpMethod.DELETE,
f"{self.settings.base_url}/api/logout",
authenticator=lambda: {"Authorization": f"Bearer {token}"},
authenticator=authenticator,
transport_options=transport_options,
)
)

Expand Down Expand Up @@ -280,7 +310,9 @@ class RefreshTokenGrantTypeParams(GrantTypeParams):
grant_type: str = "refresh_token"

def _request_token(
self, grant_type: Union[AuthCodeGrantTypeParams, RefreshTokenGrantTypeParams]
self,
grant_type: Union[AuthCodeGrantTypeParams, RefreshTokenGrantTypeParams],
transport_options: transport.TransportOptions,
) -> auth_token.AccessToken:
response = self.transport.request(
transport.HttpMethod.POST,
Expand All @@ -296,7 +328,10 @@ def _request_token(
) # type: ignore

def redeem_auth_code(
self, auth_code: str, code_verifier: Optional[str] = None
self,
auth_code: str,
code_verifier: Optional[str] = None,
transport_options: Optional[transport.TransportOptions] = None,
) -> None:
params = self.AuthCodeGrantTypeParams(
client_id=self.client_id,
Expand All @@ -305,14 +340,14 @@ def redeem_auth_code(
code_verifier=code_verifier or self.code_verifier,
)

access_token = self._request_token(params)
access_token = self._request_token(params, transport_options or {})
self.token = auth_token.AuthToken(access_token)

def _login(self) -> None:
def _login(self, transport_options: transport.TransportOptions) -> None:
params = self.RefreshTokenGrantTypeParams(
client_id=self.client_id,
redirect_uri=self.redirect_uri,
refresh_token=self.token.refresh_token,
)
access_token = self._request_token(params)
access_token = self._request_token(params, transport_options)
self.token = auth_token.AuthToken(access_token)
13 changes: 7 additions & 6 deletions python/looker_sdk/rtl/requests_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@


class RequestsTransport(transport.Transport):
"""RequestsTransport implementation of Transport.
"""
"""RequestsTransport implementation of Transport."""

def __init__(
self, settings: transport.PTransportSettings, session: requests.Session
Expand All @@ -57,14 +56,14 @@ def request(
path: str,
query_params: Optional[MutableMapping[str, str]] = None,
body: Optional[bytes] = None,
authenticator: Optional[Callable[[], Dict[str, str]]] = None,
authenticator: transport.TAuthenticator = None,
transport_options: Optional[transport.TransportOptions] = None,
) -> transport.Response:

headers: Dict[str, str] = {}
headers = {}
timeout = self.settings.timeout
if authenticator:
headers.update(authenticator())
headers.update(authenticator(transport_options or {}))
if transport_options:
if transport_options.get("headers"):
headers.update(transport_options["headers"])
Expand All @@ -83,7 +82,9 @@ def request(
)
except IOError as exc:
ret = transport.Response(
False, bytes(str(exc), encoding="utf-8"), transport.ResponseMode.STRING,
False,
bytes(str(exc), encoding="utf-8"),
transport.ResponseMode.STRING,
)
else:
ret = transport.Response(
Expand Down
2 changes: 1 addition & 1 deletion python/looker_sdk/rtl/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class TransportOptions(TypedDict, total=False):
headers: MutableMapping[str, str]


TAuthenticator = Optional[Callable[[], Dict[str, str]]]
TAuthenticator = Optional[Callable[[TransportOptions], Dict[str, str]]]


class ResponseMode(enum.Enum):
Expand Down
26 changes: 17 additions & 9 deletions python/tests/rtl/test_auth_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,17 @@ def request(
transport_options=None,
):
if authenticator:
authenticator()
authenticator(transport_options)
if method == transport.HttpMethod.POST:
if path.endswith(("login", "login/5")):
if path.endswith("login"):
token = "AdminAccessToken"
expected_header = {
expected_headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
if transport_options["headers"] != expected_header:
raise TypeError(f"Must send {expected_header}")
expected_headers.update(transport_options.get("headers", {}))
if transport_options["headers"] != expected_headers:
raise TypeError(f"Must send {expected_headers}")
else:
token = "UserAccessToken"
access_token = json.dumps(
Expand Down Expand Up @@ -132,26 +133,33 @@ def request(

def test_auto_login(auth_session: auth.AuthSession):
assert not auth_session.is_authenticated
auth_header = auth_session.authenticate()
auth_header = auth_session.authenticate({})
assert auth_header["Authorization"] == "Bearer AdminAccessToken"
assert auth_session.is_authenticated

# even after explicit logout
auth_session.logout()
assert not auth_session.is_authenticated
auth_header = auth_session.authenticate()
auth_header = auth_session.authenticate({})
assert isinstance(auth_header, dict)
assert auth_header["Authorization"] == "Bearer AdminAccessToken"
assert auth_session.is_authenticated


def test_auto_login_with_transport_options(auth_session: auth.AuthSession):
assert not auth_session.is_authenticated
auth_header = auth_session.authenticate({"headers": {"foo": "bar"}})
assert auth_header["Authorization"] == "Bearer AdminAccessToken"
assert auth_session.is_authenticated


def test_sudo_login_auto_logs_in(auth_session: auth.AuthSession):
assert not auth_session.is_authenticated
assert not auth_session.is_sudo_authenticated
auth_session.login_user(5)
assert auth_session.is_authenticated
assert auth_session.is_sudo_authenticated
auth_header = auth_session.authenticate()
auth_header = auth_session.authenticate({})
assert auth_header["Authorization"] == "Bearer UserAccessToken"


Expand Down Expand Up @@ -193,7 +201,7 @@ def test_it_fails_with_missing_credentials(
)

with pytest.raises(error.SDKError) as exc_info:
auth_session.authenticate()
auth_session.authenticate({})
assert "auth credentials not found" in str(exc_info.value)


Expand Down Expand Up @@ -228,7 +236,7 @@ def test_env_variables_override_config_file_credentials(
response_mode=transport.ResponseMode.STRING,
)

auth_session.authenticate()
auth_session.authenticate({})

expected_body = urllib.parse.urlencode(
{"client_id": expected_id, "client_secret": expected_secret}
Expand Down

0 comments on commit 94d6047

Please sign in to comment.