-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement OAuth refresh token flow #520
Changes from 17 commits
1b7bc5a
e785bcf
8eda24e
7b69e86
9be9379
3b556e6
e7ce730
233d36a
a34f01d
fc18f7a
2953698
e0d5ed4
6828b02
f28ec56
d59f165
985c53c
3c3bcd7
dfe1d6f
4d92ce2
c96ca2f
4268d18
68a345b
1693576
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -16,7 +16,7 @@ | |||||||
""" | ||||||||
import io | ||||||||
from datetime import datetime | ||||||||
from typing import List | ||||||||
from typing import List, Dict | ||||||||
|
||||||||
import numpy as np | ||||||||
import requests | ||||||||
|
@@ -96,7 +96,8 @@ def __init__( | |||||||
self._verbose = verbose | ||||||||
|
||||||||
self._base_url = "http{}://{}:{}".format("s" if self.use_ssl else "", self.host, self.port) | ||||||||
self._headers = {"Authorization": self.token, "Accept-Version": self.api_version} | ||||||||
|
||||||||
self._headers = {"Accept-Version": self.api_version} | ||||||||
|
||||||||
self.log = create_logger(__name__) | ||||||||
|
||||||||
|
@@ -156,7 +157,7 @@ def get_device_spec(self, target: str) -> DeviceSpec: | |||||||
def _get_device_dict(self, target: str) -> dict: | ||||||||
"""Returns the device specifications as a dictionary""" | ||||||||
path = f"/devices/{target}/specifications" | ||||||||
response = requests.get(self._url(path), headers=self._headers) | ||||||||
response = self._request("GET", self._url(path), headers=self._headers) | ||||||||
|
||||||||
if response.status_code == 200: | ||||||||
self.log.info("The device spec %s has been successfully retrieved.", target) | ||||||||
|
@@ -185,7 +186,9 @@ def create_job(self, target: str, program: Program, run_options: dict = None) -> | |||||||
circuit = bb.serialize() | ||||||||
|
||||||||
path = "/jobs" | ||||||||
response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit}) | ||||||||
response = self._request( | ||||||||
"POST", self._url(path), headers=self._headers, json={"circuit": circuit} | ||||||||
) | ||||||||
if response.status_code == 201: | ||||||||
job_id = response.json()["id"] | ||||||||
if self._verbose: | ||||||||
|
@@ -221,7 +224,7 @@ def get_job(self, job_id: str) -> Job: | |||||||
strawberryfields.api.Job: the job | ||||||||
""" | ||||||||
path = "/jobs/{}".format(job_id) | ||||||||
response = requests.get(self._url(path), headers=self._headers) | ||||||||
response = self._request("GET", self._url(path), headers=self._headers) | ||||||||
if response.status_code == 200: | ||||||||
return Job( | ||||||||
id_=response.json()["id"], | ||||||||
|
@@ -254,8 +257,8 @@ def get_job_result(self, job_id: str) -> Result: | |||||||
strawberryfields.api.Result: the job result | ||||||||
""" | ||||||||
path = "/jobs/{}/result".format(job_id) | ||||||||
response = requests.get( | ||||||||
self._url(path), headers={"Accept": "application/x-numpy", **self._headers} | ||||||||
response = self._request( | ||||||||
"GET", self._url(path), headers={"Accept": "application/x-numpy", **self._headers} | ||||||||
) | ||||||||
if response.status_code == 200: | ||||||||
# Read the numpy binary data in the payload into memory | ||||||||
|
@@ -283,8 +286,11 @@ def cancel_job(self, job_id: str): | |||||||
job_id (str): the job ID | ||||||||
""" | ||||||||
path = "/jobs/{}".format(job_id) | ||||||||
response = requests.patch( | ||||||||
self._url(path), headers=self._headers, json={"status": JobStatus.CANCELLED.value} | ||||||||
response = self._request( | ||||||||
"PATCH", | ||||||||
self._url(path), | ||||||||
headers=self._headers, | ||||||||
json={"status": JobStatus.CANCELLED.value}, | ||||||||
) | ||||||||
if response.status_code == 204: | ||||||||
if self._verbose: | ||||||||
|
@@ -301,12 +307,56 @@ def ping(self) -> bool: | |||||||
bool: ``True`` if the connection is successful, and ``False`` otherwise | ||||||||
""" | ||||||||
path = "/healthz" | ||||||||
response = requests.get(self._url(path), headers=self._headers) | ||||||||
response = self._request("GET", self._url(path), headers=self._headers) | ||||||||
return response.status_code == 200 | ||||||||
|
||||||||
def _url(self, path: str) -> str: | ||||||||
return self._base_url + path | ||||||||
|
||||||||
def _refresh_access_token(self): | ||||||||
"""Use the offline token to request a new access token.""" | ||||||||
self._headers.pop("Authorization", None) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (minor / thinking aloud) I'm always partial to using |
||||||||
path = "/auth/realms/platform/protocol/openid-connect/token" | ||||||||
headers = {**self._headers} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps better to do the following instead of the pop above?
Suggested change
It feels slightly 'safer' than mutating the instance attribute above, even if it has the same effect There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This way, if the method fails midway, there are no side effects There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My thought here was that since we are in the refresh token method, we already know the existing Authorization header is invalid, so removing it from the instance headers is a housekeeping action before we replace it with a new one. Whether the method fails midway or not, any future requests are still going to fail regardless of if we've popped the header or not. |
||||||||
response = requests.post( | ||||||||
self._url(path), | ||||||||
headers=headers, | ||||||||
data={ | ||||||||
"grant_type": "refresh_token", | ||||||||
"refresh_token": self._token, | ||||||||
"client_id": "public", | ||||||||
}, | ||||||||
) | ||||||||
if response.status_code == 200: | ||||||||
access_token = response.json().get("access_token") | ||||||||
self._headers["Authorization"] = f"Bearer {access_token}" | ||||||||
else: | ||||||||
raise RequestFailedError( | ||||||||
"Authorization failed for request, please check your token provided." | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By the way, you could do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh @jswinarton, it seems that it would allow other non-error response status codes though (like |
||||||||
) | ||||||||
|
||||||||
def _request(self, method: str, path: str, headers: Dict = None, **kwargs): | ||||||||
"""Wrap all API requests with an authentication token refresh if a 401 status | ||||||||
is received from the initial request. | ||||||||
|
||||||||
Args: | ||||||||
method (str): the HTTP request method to use | ||||||||
path (str): path of the endpoint to use | ||||||||
headers (dict): dictionary containing the headers of the request | ||||||||
|
||||||||
Returns: | ||||||||
requests.Response: the response received for the sent request | ||||||||
""" | ||||||||
Comment on lines
+338
to
+349
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, this explains my question above! 💯 |
||||||||
headers = headers or {} | ||||||||
request_headers = {**headers, **self._headers} | ||||||||
response = requests.request(method, path, headers=request_headers, **kwargs) | ||||||||
if response.status_code == 401: | ||||||||
corvust marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
# Refresh the access_token and retry the request | ||||||||
self._refresh_access_token() | ||||||||
request_headers = {**headers, **self._headers} | ||||||||
response = requests.request(method, path, headers=request_headers, **kwargs) | ||||||||
return response | ||||||||
|
||||||||
@staticmethod | ||||||||
def _format_error_message(response: requests.Response) -> str: | ||||||||
body = response.json() | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for my own understanding; what does this change do?Figured it out!