Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing extra headers to Connection.request #36

Merged
merged 8 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

### New features since last release

* Extra request headers may be passed to `Connection.request()`.
[(#36)](https://github.com/XanaduAI/xanadu-cloud-client/pull/36)

* Job lists can now be filtered by status.
[(#30)](https://github.com/XanaduAI/xanadu-cloud-client/pull/30)

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ build==0.7.0
isort[colors]==5.9.3
pylint==2.11.1
pytest-cov==3.0.0
responses==0.14.0
responses==0.22.0
wheel==0.37.1
27 changes: 27 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
import responses
from requests.exceptions import HTTPError, RequestException
from responses import matchers

import xcc

Expand Down Expand Up @@ -271,6 +272,32 @@ def mock_request(*args, **kwargs):
with pytest.raises(RequestException, match=r"Failed to resolve hostname 'test.xanadu.ai'"):
connection.request("GET", "/healthz")

@responses.activate
def test_request_headers(self, connection):
"""Tests that the correct headers are passed when the headers argument is not provided."""

responses.add(
url=connection.url("path"),
method="POST",
status=200,
match=(matchers.header_matcher(connection.headers),),
)

connection.request(method="POST", path="path")

@responses.activate
@pytest.mark.parametrize("extra_headers", [{"X-Test": "data"}, {}])
def test_request_extra_headers(self, connection, extra_headers):
"""Tests that the correct headers are passed when the headers argument is provided."""
responses.add(
url=connection.url("path"),
method="POST",
status=200,
match=(matchers.header_matcher({**connection.headers, **extra_headers}),),
)

connection.request(method="POST", path="path", headers=extra_headers)

@responses.activate
def test_update_access_token_success(self, connection):
"""Tests that the access token of a connection can be updated."""
Expand Down
13 changes: 11 additions & 2 deletions xcc/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,15 @@ def ping(self) -> requests.Response:
"""
return self.request(method="GET", path="/healthz")

def request(self, method: str, path: str, **kwargs) -> requests.Response:
def request(
self, method: str, path: str, *, headers: Optional[Dict[str, str]] = None, **kwargs
) -> requests.Response:
brownj85 marked this conversation as resolved.
Show resolved Hide resolved
"""Sends an HTTP request to the Xanadu Cloud.

Args:
method (str): HTTP request method
path (str): HTTP request path
headers (Mapping[str, str]): extra headers to pass to the request
**kwargs: optional arguments to pass to :func:`requests.request()`

Returns:
Expand All @@ -235,7 +238,12 @@ def request(self, method: str, path: str, **kwargs) -> requests.Response:
"""
url = self.url(path)

response = self._request(method=method, url=url, headers=self.headers, **kwargs)
if headers:
headers = {**self.headers, **headers}
else:
headers = self.headers

response = self._request(method=method, url=url, headers=headers, **kwargs)

if response.status_code == 401:
self.update_access_token()
Expand Down Expand Up @@ -349,6 +357,7 @@ def _request(self, method: str, url: str, **kwargs) -> requests.Response:
"""
try:
timeout = kwargs.pop("timeout", 10)

return requests.request(method=method, url=url, timeout=timeout, **kwargs)

except requests.exceptions.Timeout as exc:
Expand Down