Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OAuth refresh token flow #520

Merged
merged 23 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions strawberryfields/api/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
import io
from datetime import datetime
from typing import List
from typing import List, Dict

import numpy as np
import requests
Expand Down Expand Up @@ -96,7 +96,8 @@ def __init__(
self._verbose = verbose

self._base_url = "http{}://{}:{}".format("s" if self.use_ssl else "", self.host, self.port)
self._headers = {"Authorization": self.token, "Accept-Version": self.api_version}

self._headers = {"Accept-Version": self.api_version}

self.log = create_logger(__name__)

Expand Down Expand Up @@ -156,7 +157,7 @@ def get_device_spec(self, target: str) -> DeviceSpec:
def _get_device_dict(self, target: str) -> dict:
"""Returns the device specifications as a dictionary"""
path = f"/devices/{target}/specifications"
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my own understanding; what does this change do? Figured it out!


if response.status_code == 200:
self.log.info("The device spec %s has been successfully retrieved.", target)
Expand Down Expand Up @@ -185,7 +186,9 @@ def create_job(self, target: str, program: Program, run_options: dict = None) ->
circuit = bb.serialize()

path = "/jobs"
response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit})
response = self._request(
"POST", self._url(path), headers=self._headers, json={"circuit": circuit}
)
if response.status_code == 201:
job_id = response.json()["id"]
if self._verbose:
Expand Down Expand Up @@ -221,7 +224,7 @@ def get_job(self, job_id: str) -> Job:
strawberryfields.api.Job: the job
"""
path = "/jobs/{}".format(job_id)
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)
if response.status_code == 200:
return Job(
id_=response.json()["id"],
Expand Down Expand Up @@ -254,8 +257,8 @@ def get_job_result(self, job_id: str) -> Result:
strawberryfields.api.Result: the job result
"""
path = "/jobs/{}/result".format(job_id)
response = requests.get(
self._url(path), headers={"Accept": "application/x-numpy", **self._headers}
response = self._request(
"GET", self._url(path), headers={"Accept": "application/x-numpy", **self._headers}
)
if response.status_code == 200:
# Read the numpy binary data in the payload into memory
Expand Down Expand Up @@ -283,8 +286,11 @@ def cancel_job(self, job_id: str):
job_id (str): the job ID
"""
path = "/jobs/{}".format(job_id)
response = requests.patch(
self._url(path), headers=self._headers, json={"status": JobStatus.CANCELLED.value}
response = self._request(
"PATCH",
self._url(path),
headers=self._headers,
json={"status": JobStatus.CANCELLED.value},
)
if response.status_code == 204:
if self._verbose:
Expand All @@ -301,12 +307,56 @@ def ping(self) -> bool:
bool: ``True`` if the connection is successful, and ``False`` otherwise
"""
path = "/healthz"
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)
return response.status_code == 200

def _url(self, path: str) -> str:
return self._base_url + path

def _refresh_access_token(self):
"""Use the offline token to request a new access token."""
self._headers.pop("Authorization", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor / thinking aloud) I'm always partial to using del instead of pop (since the latter unnecessarily returns the item as well), but I see that pop might be a cleaner solution, since it works even if the item doesn't exist. 😆

path = "/auth/realms/platform/protocol/openid-connect/token"
headers = {**self._headers}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps better to do the following instead of the pop above?

Suggested change
headers = {**self._headers}
headers = {**self._headers}
headers.pop("Authorization", None)

It feels slightly 'safer' than mutating the instance attribute above, even if it has the same effect

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way, if the method fails midway, there are no side effects

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thought here was that since we are in the refresh token method, we already know the existing Authorization header is invalid, so removing it from the instance headers is a housekeeping action before we replace it with a new one. Whether the method fails midway or not, any future requests are still going to fail regardless of if we've popped the header or not.

response = requests.post(
self._url(path),
headers=headers,
data={
"grant_type": "refresh_token",
"refresh_token": self._token,
"client_id": "public",
},
)
if response.status_code == 200:
access_token = response.json().get("access_token")
self._headers["Authorization"] = f"Bearer {access_token}"
else:
raise RequestFailedError(
"Authorization failed for request, please check your token provided."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Authorization failed for request, please check your token provided."
"Could not retrieve access token. Please check that your API key is correct."

?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, you could do response.raise_for_status() instead on line 330, which will raise a requests exception automatically if the status code isn't 200. That could be helpful because you could catch and re-raise the way you're doing here, but chain the exceptions together to get a more detailed traceback that includes the actual status code it failed with.

Copy link
Contributor

@antalszava antalszava Jan 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh @jswinarton, it seems that it would allow other non-error response status codes though (like 201), correct?

)

def _request(self, method: str, path: str, headers: Dict = None, **kwargs):
"""Wrap all API requests with an authentication token refresh if a 401 status
is received from the initial request.

Args:
method (str): the HTTP request method to use
path (str): path of the endpoint to use
headers (dict): dictionary containing the headers of the request

Returns:
requests.Response: the response received for the sent request
"""
Comment on lines +338 to +349
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this explains my question above! 💯

headers = headers or {}
request_headers = {**headers, **self._headers}
response = requests.request(method, path, headers=request_headers, **kwargs)
if response.status_code == 401:
corvust marked this conversation as resolved.
Show resolved Hide resolved
# Refresh the access_token and retry the request
self._refresh_access_token()
request_headers = {**headers, **self._headers}
response = requests.request(method, path, headers=request_headers, **kwargs)
return response

@staticmethod
def _format_error_message(response: requests.Response) -> str:
body = response.json()
Expand Down
81 changes: 64 additions & 17 deletions tests/api/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
port = 443
"""

test_host = "SomeHost"
test_token = "SomeToken"

class MockResponse:
"""A mock response with a JSON or binary body."""
Expand All @@ -72,9 +74,11 @@ def content(self):
class TestConnection:
"""Tests for the ``Connection`` class."""

def test_init(self):
def test_init(self, monkeypatch):
antalszava marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that a ``Connection`` is initialized correctly."""
token, host, port, use_ssl = "token", "host", 123, True

monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {})))
antalszava marked this conversation as resolved.
Show resolved Hide resolved
connection = Connection(token, host, port, use_ssl)

assert connection.token == token
Expand All @@ -95,12 +99,13 @@ def test_get_device_spec(self, prog, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(
200,
{"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}}
)),
)
monkeypatch.setattr(connection, "_refresh_access_token", lambda: None)
antalszava marked this conversation as resolved.
Show resolved Hide resolved

device_spec = connection.get_device_spec(target)

Expand All @@ -114,7 +119,7 @@ def test_get_device_spec(self, prog, connection, monkeypatch):

def test_get_device_spec_error(self, connection, monkeypatch):
"""Tests a failed device spec request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get device specifications"):
connection.get_device_spec("123")
Expand All @@ -124,7 +129,7 @@ def test_create_job(self, prog, connection, monkeypatch):
id_, status = "123", JobStatus.QUEUED

monkeypatch.setattr(
requests, "post", mock_return(MockResponse(201, {"id": id_, "status": status})),
requests, "request", mock_return(MockResponse(201, {"id": id_, "status": status})),
)

job = connection.create_job("X8_01", prog, {"shots": 1})
Expand All @@ -134,7 +139,7 @@ def test_create_job(self, prog, connection, monkeypatch):

def test_create_job_error(self, prog, connection, monkeypatch):
"""Tests a failed job creation flow."""
monkeypatch.setattr(requests, "post", mock_return(MockResponse(400, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(400, {})))

with pytest.raises(RequestFailedError, match="Failed to create job"):
connection.create_job("X8_01", prog, {"shots": 1})
Expand All @@ -151,7 +156,7 @@ def test_get_all_jobs(self, connection, monkeypatch):
for i in range(1, 10)
]
monkeypatch.setattr(
requests, "get", mock_return(MockResponse(200, {"data": jobs})),
requests, "request", mock_return(MockResponse(200, {"data": jobs})),
)

jobs = connection.get_all_jobs(after=datetime(2020, 1, 5))
Expand All @@ -161,7 +166,7 @@ def test_get_all_jobs(self, connection, monkeypatch):
@pytest.mark.xfail(reason="method not yet implemented")
def test_get_all_jobs_error(self, connection, monkeypatch):
"""Tests a failed job list request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get all jobs"):
connection.get_all_jobs()
Expand All @@ -172,7 +177,7 @@ def test_get_job(self, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": meta})),
)

Expand All @@ -184,7 +189,7 @@ def test_get_job(self, connection, monkeypatch):

def test_get_job_error(self, connection, monkeypatch):
"""Tests a failed job request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job"):
connection.get_job("123")
Expand All @@ -195,15 +200,15 @@ def test_get_job_status(self, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": {}})),
)

assert connection.get_job_status(id_) == status.value

def test_get_job_status_error(self, connection, monkeypatch):
"""Tests a failed job status request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job"):
connection.get_job_status("123")
Expand Down Expand Up @@ -231,7 +236,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch):
np.save(buf, result_samples)
buf.seek(0)
monkeypatch.setattr(
requests, "get", mock_return(MockResponse(200, binary_body=buf.getvalue())),
requests, "request", mock_return(MockResponse(200, binary_body=buf.getvalue())),
)

result = connection.get_job_result("123")
Expand All @@ -240,7 +245,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch):

def test_get_job_result_error(self, connection, monkeypatch):
"""Tests a failed job result request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job result"):
connection.get_job_result("123")
Expand All @@ -255,30 +260,72 @@ def function(*args, **kwargs):

return function

monkeypatch.setattr(requests, "patch", _mock_return(MockResponse(204, {})))
monkeypatch.setattr(requests, "request", _mock_return(MockResponse(204, {})))

# A successful cancellation does not raise an exception
connection.cancel_job("123")

def test_cancel_job_error(self, connection, monkeypatch):
"""Tests a failed job cancellation request."""
monkeypatch.setattr(requests, "patch", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to cancel job"):
connection.cancel_job("123")

def test_ping_success(self, connection, monkeypatch):
"""Tests a successful ping to the remote host."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(200, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {})))

assert connection.ping()

def test_ping_failure(self, connection, monkeypatch):
"""Tests a failed ping to the remote host."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(500, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(500, {})))

assert not connection.ping()

def test_refresh_access_token(self, mocker, monkeypatch):
"""Test that the access token is created by passing the expected headers."""
path = "/auth/realms/platform/protocol/openid-connect/token"

data={
"grant_type": "refresh_token",
"refresh_token": test_token,
"client_id": "public",
}

monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {})))
spy = mocker.spy(requests, "post")

conn = Connection(token=test_token, host=test_host)
conn._refresh_access_token()
expected_headers = {'Accept-Version': conn.api_version}
expected_url = f"https://{test_host}:443{path}"
spy.assert_called_once_with(expected_url, headers=expected_headers, data=data)

def test_refresh_access_token_raises(self, monkeypatch):
"""Test that an error is raised when the access token could not be
generated while creating the Connection object."""
monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {})))
conn = Connection(token=test_token, host=test_host)
with pytest.raises(RequestFailedError, match="Authorization failed for request"):
conn._refresh_access_token()

def test_wrapped_request_refreshes(self, mocker, monkeypatch):
"""Test that the _request method refreshes the access token when
getting a 401 response."""
# Mock post function used while refreshing
monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {})))

# Mock request function used for general requests
monkeypatch.setattr(requests, "request", mock_return(MockResponse(401, {})))

conn = Connection(token=test_token, host=test_host)

spy = mocker.spy(conn, "_refresh_access_token")
conn._request("SomeRequestMethod", "SomePath")
spy.assert_called_once_with()


class TestConnectionIntegration:
"""Integration tests for using instances of the Connection."""
Expand Down