Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing extra headers to Connection.request #36

Merged
merged 8 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

### New features since last release

* Extra request headers may be passed to `Connection.request()`.
[(#36)](https://github.com/XanaduAI/xanadu-cloud-client/pull/36)

* Job lists can now be filtered by status.
[(#30)](https://github.com/XanaduAI/xanadu-cloud-client/pull/30)

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ build==0.7.0
isort[colors]==5.9.3
pylint==2.11.1
pytest-cov==3.0.0
responses==0.14.0
responses==0.22.0
wheel==0.37.1
31 changes: 31 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
import responses
from requests.exceptions import HTTPError, RequestException
from responses import matchers

import xcc

Expand Down Expand Up @@ -271,6 +272,36 @@ def mock_request(*args, **kwargs):
with pytest.raises(RequestException, match=r"Failed to resolve hostname 'test.xanadu.ai'"):
connection.request("GET", "/healthz")

@responses.activate
def test_request_headers(self, connection):
"""Test that request() passes the headers attribute to _request if the headers argument
is not provided."""
brownj85 marked this conversation as resolved.
Show resolved Hide resolved

responses.add(
url=connection.url("path"),
method="POST",
status=200,
match=(matchers.header_matcher(connection.headers),),
)

assert connection.request(method="POST", path="path").status_code == 200
brownj85 marked this conversation as resolved.
Show resolved Hide resolved

@responses.activate
@pytest.mark.parametrize("extra_headers", [{"X-Test": "data"}, {}])
def test_request_extra_headers(self, connection, extra_headers):
"""Tests that request() passes the combined headers from the headers attribute
and the headers argument to the _request method."""
brownj85 marked this conversation as resolved.
Show resolved Hide resolved
responses.add(
url=connection.url("path"),
method="POST",
status=200,
match=(matchers.header_matcher({**connection.headers, **extra_headers}),),
)

assert (
connection.request(method="POST", path="path", headers=extra_headers).status_code == 200
)
brownj85 marked this conversation as resolved.
Show resolved Hide resolved

@responses.activate
def test_update_access_token_success(self, connection):
"""Tests that the access token of a connection can be updated."""
Expand Down
13 changes: 11 additions & 2 deletions xcc/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,15 @@ def ping(self) -> requests.Response:
"""
return self.request(method="GET", path="/healthz")

def request(self, method: str, path: str, **kwargs) -> requests.Response:
def request(
self, method: str, path: str, *, headers: Optional[Dict[str, str]] = None, **kwargs
) -> requests.Response:
brownj85 marked this conversation as resolved.
Show resolved Hide resolved
"""Sends an HTTP request to the Xanadu Cloud.

Args:
method (str): HTTP request method
path (str): HTTP request path
headers (Mapping[str, str]): Extra headers to pass to request
brownj85 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: optional arguments to pass to :func:`requests.request()`

Returns:
Expand All @@ -235,7 +238,12 @@ def request(self, method: str, path: str, **kwargs) -> requests.Response:
"""
url = self.url(path)

response = self._request(method=method, url=url, headers=self.headers, **kwargs)
if headers:
headers = {**self.headers, **headers}
else:
headers = self.headers

response = self._request(method=method, url=url, headers=headers, **kwargs)

if response.status_code == 401:
self.update_access_token()
Expand Down Expand Up @@ -349,6 +357,7 @@ def _request(self, method: str, url: str, **kwargs) -> requests.Response:
"""
try:
timeout = kwargs.pop("timeout", 10)

return requests.request(method=method, url=url, timeout=timeout, **kwargs)

except requests.exceptions.Timeout as exc:
Expand Down