From 1b7bc5a51f84f5d152f388cc3891e4f1b81a718e Mon Sep 17 00:00:00 2001 From: "timl@xanadu.ai" Date: Mon, 11 Jan 2021 12:46:30 -0500 Subject: [PATCH 01/21] Implement OAuth refresh token flow --- strawberryfields/api/connection.py | 55 ++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 368095b86..3d2980849 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -16,7 +16,7 @@ """ import io from datetime import datetime -from typing import List +from typing import List, Dict import numpy as np import requests @@ -96,7 +96,9 @@ def __init__( self._verbose = verbose self._base_url = "http{}://{}:{}".format("s" if self.use_ssl else "", self.host, self.port) - self._headers = {"Authorization": self.token, "Accept-Version": self.api_version} + + self._headers = {"Accept-Version": self.api_version} + self._refresh_access_token() self.log = create_logger(__name__) @@ -156,7 +158,7 @@ def get_device_spec(self, target: str) -> DeviceSpec: def _get_device_dict(self, target: str) -> dict: """Returns the device specifications as a dictionary""" path = f"/devices/{target}/specifications" - response = requests.get(self._url(path), headers=self._headers) + response = self._request("GET", self._url(path), headers=self._headers) if response.status_code == 200: self.log.info("The device spec %s has been successfully retrieved.", target) @@ -185,7 +187,7 @@ def create_job(self, target: str, program: Program, run_options: dict = None) -> circuit = bb.serialize() path = "/jobs" - response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit}) + response = self._request("POST", self._url(path), headers=self._headers, json={"circuit": circuit}) if response.status_code == 201: job_id = response.json()["id"] if self._verbose: @@ -221,7 +223,7 @@ def get_job(self, job_id: str) -> Job: strawberryfields.api.Job: the job """ path = "/jobs/{}".format(job_id) - response = requests.get(self._url(path), headers=self._headers) + response = self._request("GET", self._url(path), headers=self._headers) if response.status_code == 200: return Job( id_=response.json()["id"], @@ -254,8 +256,9 @@ def get_job_result(self, job_id: str) -> Result: strawberryfields.api.Result: the job result """ path = "/jobs/{}/result".format(job_id) - response = requests.get( - self._url(path), headers={"Accept": "application/x-numpy", **self._headers} + response = self._request( + "GET", self._url(path), + headers={"Accept": "application/x-numpy", **self._headers} ) if response.status_code == 200: # Read the numpy binary data in the payload into memory @@ -283,8 +286,9 @@ def cancel_job(self, job_id: str): job_id (str): the job ID """ path = "/jobs/{}".format(job_id) - response = requests.patch( - self._url(path), headers=self._headers, json={"status": JobStatus.CANCELLED.value} + response = self._request( + "PATCH", self._url(path), + headers=self._headers, json={"status": JobStatus.CANCELLED.value} ) if response.status_code == 204: if self._verbose: @@ -301,12 +305,43 @@ def ping(self) -> bool: bool: ``True`` if the connection is successful, and ``False`` otherwise """ path = "/healthz" - response = requests.get(self._url(path), headers=self._headers) + response = self._request("GET", self._url(path), headers=self._headers) return response.status_code == 200 def _url(self, path: str) -> str: return self._base_url + path + def _refresh_access_token(self): + """ Use the offline token to request a new access token """ + self._headers.pop("Authorization", None) + # TODO: Make sure this is the right path + path = "/auth/token" + headers = {**self._headers} + response = self._request("POST", self._url(path), headers=headers, data={ + 'grant_type': 'refresh_token', + 'refresh_token': self._token, + 'client_id': 'public', + }) + if response.status_code == 200: + self._headers["Authorization"] = "Bearer {}".format(response.cookies["access_token"]) + else: + raise RequestFailedError( + "Authorization failed for request" + ) + + def _request(self, method: str, path: str, headers: Dict = {}, **kwargs ): + """ Wrap all API requests with an auth token refresh if a 401 is recevied + from the initial request. + """ + request_headers = {**headers, **self._headers} + response = requests.request(method, self._url(path), headers=request_headers, **kwargs) + if response.status_code == 401: + # Refresh the access_token and retry the request + self._refresh_access_token() + request_headers = {**headers, **self._headers} + response = requests.request(method, self._url(path), headers=request_headers, **kwargs) + return response + @staticmethod def _format_error_message(response: requests.Response) -> str: body = response.json() From e785bcf339140c1710088e9bf4645afd6b68c4bf Mon Sep 17 00:00:00 2001 From: corvust Date: Tue, 12 Jan 2021 09:23:54 -0500 Subject: [PATCH 02/21] Apply suggestions from code review Formatting changes Co-authored-by: antalszava --- strawberryfields/api/connection.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 3d2980849..bc874d782 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -312,7 +312,7 @@ def _url(self, path: str) -> str: return self._base_url + path def _refresh_access_token(self): - """ Use the offline token to request a new access token """ + """Use the offline token to request a new access token.""" self._headers.pop("Authorization", None) # TODO: Make sure this is the right path path = "/auth/token" @@ -330,8 +330,16 @@ def _refresh_access_token(self): ) def _request(self, method: str, path: str, headers: Dict = {}, **kwargs ): - """ Wrap all API requests with an auth token refresh if a 401 is recevied - from the initial request. + """Wrap all API requests with an authentication token refresh if a 401 status + is received from the initial request. + + Args: + method (str): the HTTP request method to use + path (str): path of the endpoint to use + headers (dict): dictionary containing the headers of the request + + Returns: + requests.Response: the response received for the sent request """ request_headers = {**headers, **self._headers} response = requests.request(method, self._url(path), headers=request_headers, **kwargs) From 8eda24eb6772da29287151ba915cfe311cd0b88b Mon Sep 17 00:00:00 2001 From: "timl@xanadu.ai" Date: Tue, 12 Jan 2021 09:24:58 -0500 Subject: [PATCH 03/21] Add correct token refresh path --- strawberryfields/api/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index bc874d782..71714f739 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -314,8 +314,7 @@ def _url(self, path: str) -> str: def _refresh_access_token(self): """Use the offline token to request a new access token.""" self._headers.pop("Authorization", None) - # TODO: Make sure this is the right path - path = "/auth/token" + path = "/auth/realms/platform/protocol/openid-connect/token" headers = {**self._headers} response = self._request("POST", self._url(path), headers=headers, data={ 'grant_type': 'refresh_token', From 7b69e862e87d1f9314aa8369e8d0a710953efbd6 Mon Sep 17 00:00:00 2001 From: Antal Szava Date: Tue, 12 Jan 2021 13:02:47 -0500 Subject: [PATCH 04/21] update path, update getting the access token, remove url wrapping in _request --- strawberryfields/api/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 3d2980849..854e30360 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -315,7 +315,7 @@ def _refresh_access_token(self): """ Use the offline token to request a new access token """ self._headers.pop("Authorization", None) # TODO: Make sure this is the right path - path = "/auth/token" + path = "/auth/realms/platform/protocol/openid-connect/token" headers = {**self._headers} response = self._request("POST", self._url(path), headers=headers, data={ 'grant_type': 'refresh_token', @@ -323,7 +323,7 @@ def _refresh_access_token(self): 'client_id': 'public', }) if response.status_code == 200: - self._headers["Authorization"] = "Bearer {}".format(response.cookies["access_token"]) + self._headers["Authorization"] = "Bearer {}".format(response.json().get('access_token')) else: raise RequestFailedError( "Authorization failed for request" @@ -334,7 +334,7 @@ def _request(self, method: str, path: str, headers: Dict = {}, **kwargs ): from the initial request. """ request_headers = {**headers, **self._headers} - response = requests.request(method, self._url(path), headers=request_headers, **kwargs) + response = requests.request(method, path, headers=request_headers, **kwargs) if response.status_code == 401: # Refresh the access_token and retry the request self._refresh_access_token() From 3b556e6d37484e3777d1a34cc5390ceb2f87d4ad Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 13:17:40 -0500 Subject: [PATCH 05/21] Formatting with black --- strawberryfields/api/connection.py | 38 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 639028990..581def33e 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -187,7 +187,9 @@ def create_job(self, target: str, program: Program, run_options: dict = None) -> circuit = bb.serialize() path = "/jobs" - response = self._request("POST", self._url(path), headers=self._headers, json={"circuit": circuit}) + response = self._request( + "POST", self._url(path), headers=self._headers, json={"circuit": circuit} + ) if response.status_code == 201: job_id = response.json()["id"] if self._verbose: @@ -257,8 +259,7 @@ def get_job_result(self, job_id: str) -> Result: """ path = "/jobs/{}/result".format(job_id) response = self._request( - "GET", self._url(path), - headers={"Accept": "application/x-numpy", **self._headers} + "GET", self._url(path), headers={"Accept": "application/x-numpy", **self._headers} ) if response.status_code == 200: # Read the numpy binary data in the payload into memory @@ -287,8 +288,10 @@ def cancel_job(self, job_id: str): """ path = "/jobs/{}".format(job_id) response = self._request( - "PATCH", self._url(path), - headers=self._headers, json={"status": JobStatus.CANCELLED.value} + "PATCH", + self._url(path), + headers=self._headers, + json={"status": JobStatus.CANCELLED.value}, ) if response.status_code == 204: if self._verbose: @@ -316,22 +319,25 @@ def _refresh_access_token(self): self._headers.pop("Authorization", None) path = "/auth/realms/platform/protocol/openid-connect/token" headers = {**self._headers} - response = self._request("POST", self._url(path), headers=headers, data={ - 'grant_type': 'refresh_token', - 'refresh_token': self._token, - 'client_id': 'public', - }) + response = self._request( + "POST", + self._url(path), + headers=headers, + data={ + "grant_type": "refresh_token", + "refresh_token": self._token, + "client_id": "public", + }, + ) if response.status_code == 200: - self._headers["Authorization"] = "Bearer {}".format(response.json().get('access_token')) + self._headers["Authorization"] = "Bearer {}".format(response.json().get("access_token")) else: - raise RequestFailedError( - "Authorization failed for request" - ) + raise RequestFailedError("Authorization failed for request") - def _request(self, method: str, path: str, headers: Dict = {}, **kwargs ): + def _request(self, method: str, path: str, headers: Dict = {}, **kwargs): """Wrap all API requests with an authentication token refresh if a 401 status is received from the initial request. - + Args: method (str): the HTTP request method to use path (str): path of the endpoint to use From e7ce730b11fd6eb93904c6643ad1de1cbecc9462 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 13:19:52 -0500 Subject: [PATCH 06/21] move dict init into func --- strawberryfields/api/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 581def33e..afddbee30 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -334,7 +334,7 @@ def _refresh_access_token(self): else: raise RequestFailedError("Authorization failed for request") - def _request(self, method: str, path: str, headers: Dict = {}, **kwargs): + def _request(self, method: str, path: str, headers: Dict = None, **kwargs): """Wrap all API requests with an authentication token refresh if a 401 status is received from the initial request. @@ -346,6 +346,7 @@ def _request(self, method: str, path: str, headers: Dict = {}, **kwargs): Returns: requests.Response: the response received for the sent request """ + headers = headers or {} request_headers = {**headers, **self._headers} response = requests.request(method, path, headers=request_headers, **kwargs) if response.status_code == 401: From 233d36a868179b9d8b9768fc8ca9141bbf292318 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 17:01:58 -0500 Subject: [PATCH 07/21] Updates --- strawberryfields/api/connection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index afddbee30..980e9d956 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -330,9 +330,10 @@ def _refresh_access_token(self): }, ) if response.status_code == 200: - self._headers["Authorization"] = "Bearer {}".format(response.json().get("access_token")) + access_token = response.json().get("access_token") + self._headers["Authorization"] = f"Bearer {access_token}" else: - raise RequestFailedError("Authorization failed for request") + raise RequestFailedError("Authorization failed for request.") def _request(self, method: str, path: str, headers: Dict = None, **kwargs): """Wrap all API requests with an authentication token refresh if a 401 status @@ -353,7 +354,7 @@ def _request(self, method: str, path: str, headers: Dict = None, **kwargs): # Refresh the access_token and retry the request self._refresh_access_token() request_headers = {**headers, **self._headers} - response = requests.request(method, self._url(path), headers=request_headers, **kwargs) + response = requests.request(method, path, headers=request_headers, **kwargs) return response @staticmethod From a34f01d786c2e597a08906c3ac81a77c334bb906 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 17:50:52 -0500 Subject: [PATCH 08/21] refresh access token unit tests --- strawberryfields/api/connection.py | 3 ++- tests/api/test_connection.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 980e9d956..ebba149ad 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -89,6 +89,7 @@ def __init__( ): default_config = load_config() + print(token) self._token = token or default_config["api"]["authentication_token"] self._host = host or default_config["api"]["hostname"] self._port = port or default_config["api"]["port"] @@ -333,7 +334,7 @@ def _refresh_access_token(self): access_token = response.json().get("access_token") self._headers["Authorization"] = f"Bearer {access_token}" else: - raise RequestFailedError("Authorization failed for request.") + raise RequestFailedError("Authorization failed for request, please check your token provided.") def _request(self, method: str, path: str, headers: Dict = None, **kwargs): """Wrap all API requests with an authentication token refresh if a 401 status diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 8cd0a7177..3f226d3b7 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -279,6 +279,42 @@ def test_ping_failure(self, connection, monkeypatch): assert not connection.ping() + def test_refresh_access_token_called(self, mocker, monkeypatch): + """Test that an access token is granted once a Connection object is created.""" + monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) + spy = mocker.spy(Connection, "_refresh_access_token") + conn = Connection() + spy.assert_called_once_with(conn) + + def test_refresh_access_token(self, mocker, monkeypatch): + """Test that the access token is created by passing the expected headers.""" + host = "SomeHost" + path = "/auth/realms/platform/protocol/openid-connect/token" + + token = "SomeToken" + data={ + "grant_type": "refresh_token", + "refresh_token": token, + "client_id": "public", + } + + monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {}))) + spy = mocker.spy(requests, "request") + + conn = Connection(token=token, host=host) + expected_headers = {'Accept-Version': conn.api_version} + expected_url = f"https://{host}:443{path}" + spy.assert_called_once_with("POST", expected_url, headers=expected_headers, data=data) + + def test_refresh_access_token_raises(self, monkeypatch): + """Test that an error is raised when the access token could not be + generated while creating the Connection object.""" + monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(500, {}))) + with pytest.raises(RequestFailedError, match="Authorization failed for request"): + Connection(token="SomeToken", host="SomeHost") + class TestConnectionIntegration: """Integration tests for using instances of the Connection.""" From fc18f7abefa93912b329dcd99e06feaa211ed896 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 17:56:18 -0500 Subject: [PATCH 09/21] no print --- strawberryfields/api/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index ebba149ad..f6f02bb71 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -89,7 +89,6 @@ def __init__( ): default_config = load_config() - print(token) self._token = token or default_config["api"]["authentication_token"] self._host = host or default_config["api"]["hostname"] self._port = port or default_config["api"]["port"] From 2953698ee9287c68b54d6b446baf2804fc846255 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 18:38:45 -0500 Subject: [PATCH 10/21] User request.post directly, update tests --- strawberryfields/api/connection.py | 3 +-- tests/api/test_connection.py | 40 +++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index f6f02bb71..2cdfcc3b2 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -319,8 +319,7 @@ def _refresh_access_token(self): self._headers.pop("Authorization", None) path = "/auth/realms/platform/protocol/openid-connect/token" headers = {**self._headers} - response = self._request( - "POST", + response = requests.post( self._url(path), headers=headers, data={ diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 3f226d3b7..7a9c4b812 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -299,22 +299,54 @@ def test_refresh_access_token(self, mocker, monkeypatch): } monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) - monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {}))) - spy = mocker.spy(requests, "request") + monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) + spy = mocker.spy(requests, "post") conn = Connection(token=token, host=host) expected_headers = {'Accept-Version': conn.api_version} expected_url = f"https://{host}:443{path}" - spy.assert_called_once_with("POST", expected_url, headers=expected_headers, data=data) + spy.assert_called_once_with(expected_url, headers=expected_headers, data=data) def test_refresh_access_token_raises(self, monkeypatch): """Test that an error is raised when the access token could not be generated while creating the Connection object.""" monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) - monkeypatch.setattr(requests, "request", mock_return(MockResponse(500, {}))) + monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {}))) with pytest.raises(RequestFailedError, match="Authorization failed for request"): Connection(token="SomeToken", host="SomeHost") + def test_wrapped_request(self, monkeypatch): + """Test that the access token is created by passing the expected headers.""" + def mock_request(): + count = [] + def func(*args, **kwargs): + if len(count) > 2: + return MockResponse(401, {}) + + count.append(1) + return MockResponse(201, {}) + return func + + monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) + monkeypatch.setattr(requests, "request", mock_request()) + host = "SomeHost" + path = "/auth/realms/platform/protocol/openid-connect/token" + token = "SomeToken" + + data={ + "grant_type": "refresh_token", + "refresh_token": token, + "client_id": "public", + } + expected_url = f"https://{host}:443{path}" + + conn = Connection(token=token, host=host) + expected_headers = {'Accept-Version': conn.api_version} + + spy = mocker.spy(conn, "_refresh_access_token") + conn._request("SomeRequestMethod", "SomePath") + spy.assert_called_once_with("POST", expected_url, headers=expected_headers, data=data) + class TestConnectionIntegration: """Integration tests for using instances of the Connection.""" From e0d5ed49420eb728bfeb8c3ccecb843c338a2464 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 18:52:49 -0500 Subject: [PATCH 11/21] Wrapped request test, updates --- tests/api/test_connection.py | 53 ++++++++++++------------------------ 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 7a9c4b812..94df54dda 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -50,6 +50,8 @@ port = 443 """ +test_host = "SomeHost" +test_token = "SomeToken" class MockResponse: """A mock response with a JSON or binary body.""" @@ -281,71 +283,50 @@ def test_ping_failure(self, connection, monkeypatch): def test_refresh_access_token_called(self, mocker, monkeypatch): """Test that an access token is granted once a Connection object is created.""" - monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) + monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) spy = mocker.spy(Connection, "_refresh_access_token") - conn = Connection() + conn = Connection(token=test_token) spy.assert_called_once_with(conn) def test_refresh_access_token(self, mocker, monkeypatch): """Test that the access token is created by passing the expected headers.""" - host = "SomeHost" path = "/auth/realms/platform/protocol/openid-connect/token" - token = "SomeToken" data={ "grant_type": "refresh_token", - "refresh_token": token, + "refresh_token": test_token, "client_id": "public", } - monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) spy = mocker.spy(requests, "post") - conn = Connection(token=token, host=host) + conn = Connection(token=test_token, host=test_host) expected_headers = {'Accept-Version': conn.api_version} - expected_url = f"https://{host}:443{path}" + expected_url = f"https://{test_host}:443{path}" spy.assert_called_once_with(expected_url, headers=expected_headers, data=data) def test_refresh_access_token_raises(self, monkeypatch): """Test that an error is raised when the access token could not be generated while creating the Connection object.""" - monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {}))) with pytest.raises(RequestFailedError, match="Authorization failed for request"): - Connection(token="SomeToken", host="SomeHost") + Connection(token=test_token, host=test_host) - def test_wrapped_request(self, monkeypatch): - """Test that the access token is created by passing the expected headers.""" - def mock_request(): - count = [] - def func(*args, **kwargs): - if len(count) > 2: - return MockResponse(401, {}) - - count.append(1) - return MockResponse(201, {}) - return func - - monkeypatch.delenv("SF_API_AUTHENTICATION_TOKEN", raising=False) - monkeypatch.setattr(requests, "request", mock_request()) - host = "SomeHost" - path = "/auth/realms/platform/protocol/openid-connect/token" - token = "SomeToken" + def test_wrapped_request_refreshes(self, mocker, monkeypatch): + """Test that a wrapped request refreshes the access token when getting + a 401 response.""" + # Mock post used while refreshing + monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) - data={ - "grant_type": "refresh_token", - "refresh_token": token, - "client_id": "public", - } - expected_url = f"https://{host}:443{path}" + # Mock request used for general requests + monkeypatch.setattr(requests, "request", mock_return(MockResponse(401, {}))) - conn = Connection(token=token, host=host) - expected_headers = {'Accept-Version': conn.api_version} + conn = Connection(token=test_token, host=test_host) spy = mocker.spy(conn, "_refresh_access_token") conn._request("SomeRequestMethod", "SomePath") - spy.assert_called_once_with("POST", expected_url, headers=expected_headers, data=data) + spy.assert_called_once_with() class TestConnectionIntegration: From 6828b027f7887a654620cc1edc8b84ef8e86c1b0 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 18:54:30 -0500 Subject: [PATCH 12/21] Updates --- tests/api/test_connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 94df54dda..188cb3f54 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -314,12 +314,12 @@ def test_refresh_access_token_raises(self, monkeypatch): Connection(token=test_token, host=test_host) def test_wrapped_request_refreshes(self, mocker, monkeypatch): - """Test that a wrapped request refreshes the access token when getting - a 401 response.""" - # Mock post used while refreshing + """Test that the _request method refreshes the access token when + getting a 401 response.""" + # Mock post function used while refreshing monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) - # Mock request used for general requests + # Mock request function used for general requests monkeypatch.setattr(requests, "request", mock_return(MockResponse(401, {}))) conn = Connection(token=test_token, host=test_host) From f28ec567d22a7a01fdcd685140a37dc812c5bee7 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 18:57:04 -0500 Subject: [PATCH 13/21] Formatting --- strawberryfields/api/connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 2cdfcc3b2..552dc37b3 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -332,7 +332,9 @@ def _refresh_access_token(self): access_token = response.json().get("access_token") self._headers["Authorization"] = f"Bearer {access_token}" else: - raise RequestFailedError("Authorization failed for request, please check your token provided.") + raise RequestFailedError( + "Authorization failed for request, please check your token provided." + ) def _request(self, method: str, path: str, headers: Dict = None, **kwargs): """Wrap all API requests with an authentication token refresh if a 401 status From d59f165afaedec90627e17e0df07472d0c68e059 Mon Sep 17 00:00:00 2001 From: antalszava Date: Tue, 12 Jan 2021 23:50:49 -0500 Subject: [PATCH 14/21] Remove access token refreshing from init --- strawberryfields/api/connection.py | 1 - tests/api/test_connection.py | 48 ++++++++++++++---------------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 552dc37b3..831ff75f9 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -98,7 +98,6 @@ def __init__( self._base_url = "http{}://{}:{}".format("s" if self.use_ssl else "", self.host, self.port) self._headers = {"Accept-Version": self.api_version} - self._refresh_access_token() self.log = create_logger(__name__) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 188cb3f54..f73cd56c8 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -74,9 +74,11 @@ def content(self): class TestConnection: """Tests for the ``Connection`` class.""" - def test_init(self): + def test_init(self, monkeypatch): """Tests that a ``Connection`` is initialized correctly.""" token, host, port, use_ssl = "token", "host", 123, True + + monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {}))) connection = Connection(token, host, port, use_ssl) assert connection.token == token @@ -97,12 +99,13 @@ def test_get_device_spec(self, prog, connection, monkeypatch): monkeypatch.setattr( requests, - "get", + "request", mock_return(MockResponse( 200, {"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}} )), ) + monkeypatch.setattr(connection, "_refresh_access_token", lambda: None) device_spec = connection.get_device_spec(target) @@ -116,7 +119,7 @@ def test_get_device_spec(self, prog, connection, monkeypatch): def test_get_device_spec_error(self, connection, monkeypatch): """Tests a failed device spec request.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to get device specifications"): connection.get_device_spec("123") @@ -126,7 +129,7 @@ def test_create_job(self, prog, connection, monkeypatch): id_, status = "123", JobStatus.QUEUED monkeypatch.setattr( - requests, "post", mock_return(MockResponse(201, {"id": id_, "status": status})), + requests, "request", mock_return(MockResponse(201, {"id": id_, "status": status})), ) job = connection.create_job("X8_01", prog, {"shots": 1}) @@ -136,7 +139,7 @@ def test_create_job(self, prog, connection, monkeypatch): def test_create_job_error(self, prog, connection, monkeypatch): """Tests a failed job creation flow.""" - monkeypatch.setattr(requests, "post", mock_return(MockResponse(400, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(400, {}))) with pytest.raises(RequestFailedError, match="Failed to create job"): connection.create_job("X8_01", prog, {"shots": 1}) @@ -153,7 +156,7 @@ def test_get_all_jobs(self, connection, monkeypatch): for i in range(1, 10) ] monkeypatch.setattr( - requests, "get", mock_return(MockResponse(200, {"data": jobs})), + requests, "request", mock_return(MockResponse(200, {"data": jobs})), ) jobs = connection.get_all_jobs(after=datetime(2020, 1, 5)) @@ -163,7 +166,7 @@ def test_get_all_jobs(self, connection, monkeypatch): @pytest.mark.xfail(reason="method not yet implemented") def test_get_all_jobs_error(self, connection, monkeypatch): """Tests a failed job list request.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to get all jobs"): connection.get_all_jobs() @@ -174,7 +177,7 @@ def test_get_job(self, connection, monkeypatch): monkeypatch.setattr( requests, - "get", + "request", mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": meta})), ) @@ -186,7 +189,7 @@ def test_get_job(self, connection, monkeypatch): def test_get_job_error(self, connection, monkeypatch): """Tests a failed job request.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to get job"): connection.get_job("123") @@ -197,7 +200,7 @@ def test_get_job_status(self, connection, monkeypatch): monkeypatch.setattr( requests, - "get", + "request", mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": {}})), ) @@ -205,7 +208,7 @@ def test_get_job_status(self, connection, monkeypatch): def test_get_job_status_error(self, connection, monkeypatch): """Tests a failed job status request.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to get job"): connection.get_job_status("123") @@ -233,7 +236,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch): np.save(buf, result_samples) buf.seek(0) monkeypatch.setattr( - requests, "get", mock_return(MockResponse(200, binary_body=buf.getvalue())), + requests, "request", mock_return(MockResponse(200, binary_body=buf.getvalue())), ) result = connection.get_job_result("123") @@ -242,7 +245,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch): def test_get_job_result_error(self, connection, monkeypatch): """Tests a failed job result request.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to get job result"): connection.get_job_result("123") @@ -257,37 +260,30 @@ def function(*args, **kwargs): return function - monkeypatch.setattr(requests, "patch", _mock_return(MockResponse(204, {}))) + monkeypatch.setattr(requests, "request", _mock_return(MockResponse(204, {}))) # A successful cancellation does not raise an exception connection.cancel_job("123") def test_cancel_job_error(self, connection, monkeypatch): """Tests a failed job cancellation request.""" - monkeypatch.setattr(requests, "patch", mock_return(MockResponse(404, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {}))) with pytest.raises(RequestFailedError, match="Failed to cancel job"): connection.cancel_job("123") def test_ping_success(self, connection, monkeypatch): """Tests a successful ping to the remote host.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(200, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {}))) assert connection.ping() def test_ping_failure(self, connection, monkeypatch): """Tests a failed ping to the remote host.""" - monkeypatch.setattr(requests, "get", mock_return(MockResponse(500, {}))) + monkeypatch.setattr(requests, "request", mock_return(MockResponse(500, {}))) assert not connection.ping() - def test_refresh_access_token_called(self, mocker, monkeypatch): - """Test that an access token is granted once a Connection object is created.""" - monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {}))) - spy = mocker.spy(Connection, "_refresh_access_token") - conn = Connection(token=test_token) - spy.assert_called_once_with(conn) - def test_refresh_access_token(self, mocker, monkeypatch): """Test that the access token is created by passing the expected headers.""" path = "/auth/realms/platform/protocol/openid-connect/token" @@ -302,6 +298,7 @@ def test_refresh_access_token(self, mocker, monkeypatch): spy = mocker.spy(requests, "post") conn = Connection(token=test_token, host=test_host) + conn._refresh_access_token() expected_headers = {'Accept-Version': conn.api_version} expected_url = f"https://{test_host}:443{path}" spy.assert_called_once_with(expected_url, headers=expected_headers, data=data) @@ -310,8 +307,9 @@ def test_refresh_access_token_raises(self, monkeypatch): """Test that an error is raised when the access token could not be generated while creating the Connection object.""" monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {}))) + conn = Connection(token=test_token, host=test_host) with pytest.raises(RequestFailedError, match="Authorization failed for request"): - Connection(token=test_token, host=test_host) + conn._refresh_access_token() def test_wrapped_request_refreshes(self, mocker, monkeypatch): """Test that the _request method refreshes the access token when From 985c53cd114dcf9a552aa7474c98bfb083e78319 Mon Sep 17 00:00:00 2001 From: antalszava Date: Wed, 13 Jan 2021 00:30:32 -0500 Subject: [PATCH 15/21] Update tests/api/test_connection.py --- tests/api/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index f73cd56c8..c4ea69219 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -74,7 +74,7 @@ def content(self): class TestConnection: """Tests for the ``Connection`` class.""" - def test_init(self, monkeypatch): + def test_init(self): """Tests that a ``Connection`` is initialized correctly.""" token, host, port, use_ssl = "token", "host", 123, True From 3c3bcd76907617956acca08a49d76785bdc9bbf4 Mon Sep 17 00:00:00 2001 From: antalszava Date: Wed, 13 Jan 2021 00:31:15 -0500 Subject: [PATCH 16/21] Update tests/api/test_connection.py --- tests/api/test_connection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index c4ea69219..527603c12 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -77,8 +77,6 @@ class TestConnection: def test_init(self): """Tests that a ``Connection`` is initialized correctly.""" token, host, port, use_ssl = "token", "host", 123, True - - monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {}))) connection = Connection(token, host, port, use_ssl) assert connection.token == token From dfe1d6f3c649ee005a11edbc5aeb44ab080ad1fa Mon Sep 17 00:00:00 2001 From: antalszava Date: Wed, 13 Jan 2021 00:31:58 -0500 Subject: [PATCH 17/21] Update tests/api/test_connection.py --- tests/api/test_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 527603c12..6fa1343e4 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -103,7 +103,6 @@ def test_get_device_spec(self, prog, connection, monkeypatch): {"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}} )), ) - monkeypatch.setattr(connection, "_refresh_access_token", lambda: None) device_spec = connection.get_device_spec(target) From 4d92ce299888d4c2a25c8feba562d106f20c83fa Mon Sep 17 00:00:00 2001 From: antalszava Date: Wed, 13 Jan 2021 00:31:58 -0500 Subject: [PATCH 18/21] Update tests/api/test_connection.py --- tests/api/test_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 527603c12..6fa1343e4 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -103,7 +103,6 @@ def test_get_device_spec(self, prog, connection, monkeypatch): {"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}} )), ) - monkeypatch.setattr(connection, "_refresh_access_token", lambda: None) device_spec = connection.get_device_spec(target) From 4268d1889211cf5dee3df8ffbc83a04104d3c341 Mon Sep 17 00:00:00 2001 From: antalszava Date: Thu, 14 Jan 2021 11:31:14 -0500 Subject: [PATCH 19/21] Update strawberryfields/api/connection.py Co-authored-by: Jeremy Swinarton --- strawberryfields/api/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index 831ff75f9..a433ded46 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -332,7 +332,7 @@ def _refresh_access_token(self): self._headers["Authorization"] = f"Bearer {access_token}" else: raise RequestFailedError( - "Authorization failed for request, please check your token provided." + "Could not retrieve access token. Please check that your API key is correct." ) def _request(self, method: str, path: str, headers: Dict = None, **kwargs): From 68a345bde2f89d4dbf1a48a05a46120b9a4a635a Mon Sep 17 00:00:00 2001 From: antalszava Date: Thu, 14 Jan 2021 11:35:19 -0500 Subject: [PATCH 20/21] update msg in test --- tests/api/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index 6fa1343e4..a78cee14c 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -305,7 +305,7 @@ def test_refresh_access_token_raises(self, monkeypatch): generated while creating the Connection object.""" monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {}))) conn = Connection(token=test_token, host=test_host) - with pytest.raises(RequestFailedError, match="Authorization failed for request"): + with pytest.raises(RequestFailedError, match="Could not retrieve access token"): conn._refresh_access_token() def test_wrapped_request_refreshes(self, mocker, monkeypatch): From 169357657a8f805837b2ad0a5ce8500ec3fdba5f Mon Sep 17 00:00:00 2001 From: antalszava Date: Thu, 14 Jan 2021 18:27:07 -0500 Subject: [PATCH 21/21] changelog --- .github/CHANGELOG.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 0c20b7610..cff4a0245 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -104,7 +104,11 @@ * `TDMProgram` objects can now be serialized into Blackbird scripts, and vice versa. [(#476)](https://github.com/XanaduAI/strawberryfields/pull/476) -

Breaking changes

+

Breaking Changes

+ +* Jobs are submitted to the Xanadu Quantum Cloud through a new OAuth based + authentication flow using offline refresh tokens and access tokens. + [(#520)](https://github.com/XanaduAI/strawberryfields/pull/520)

Bug fixes

@@ -138,8 +142,8 @@ This release contains contributions from (in alphabetical order): -Tom Bromley, Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Nicolas Quesada, -Antal Száva. +Tom Bromley, Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Tim Leisti, +Nicolas Quesada, Antal Száva. # Release 0.16.0 (current release)