diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 83a3b3e..b565361 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,6 +2,9 @@ ### New features since last release +* Extra request headers may be passed to `Connection.request()`. + [(#36)](https://github.com/XanaduAI/xanadu-cloud-client/pull/36) + * Job lists can now be filtered by status. [(#30)](https://github.com/XanaduAI/xanadu-cloud-client/pull/30) diff --git a/requirements-dev.txt b/requirements-dev.txt index 45088be..cf9fdc3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,5 +3,5 @@ build==0.7.0 isort[colors]==5.9.3 pylint==2.11.1 pytest-cov==3.0.0 -responses==0.14.0 +responses==0.22.0 wheel==0.37.1 diff --git a/tests/test_connection.py b/tests/test_connection.py index d4147b2..89fe53b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,6 +9,7 @@ import requests import responses from requests.exceptions import HTTPError, RequestException +from responses import matchers import xcc @@ -271,6 +272,32 @@ def mock_request(*args, **kwargs): with pytest.raises(RequestException, match=r"Failed to resolve hostname 'test.xanadu.ai'"): connection.request("GET", "/healthz") + @responses.activate + def test_request_headers(self, connection): + """Tests that the correct headers are passed when the headers argument is not provided.""" + + responses.add( + url=connection.url("path"), + method="POST", + status=200, + match=(matchers.header_matcher(connection.headers),), + ) + + connection.request(method="POST", path="path") + + @responses.activate + @pytest.mark.parametrize("extra_headers", [{"X-Test": "data"}, {}]) + def test_request_extra_headers(self, connection, extra_headers): + """Tests that the correct headers are passed when the headers argument is provided.""" + responses.add( + url=connection.url("path"), + method="POST", + status=200, + match=(matchers.header_matcher({**connection.headers, **extra_headers}),), + ) + + connection.request(method="POST", path="path", headers=extra_headers) + @responses.activate def test_update_access_token_success(self, connection): """Tests that the access token of a connection can be updated.""" diff --git a/xcc/connection.py b/xcc/connection.py index 80d34c8..632b259 100644 --- a/xcc/connection.py +++ b/xcc/connection.py @@ -210,12 +210,15 @@ def ping(self) -> requests.Response: """ return self.request(method="GET", path="/healthz") - def request(self, method: str, path: str, **kwargs) -> requests.Response: + def request( + self, method: str, path: str, *, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> requests.Response: """Sends an HTTP request to the Xanadu Cloud. Args: method (str): HTTP request method path (str): HTTP request path + headers (Mapping[str, str]): extra headers to pass to the request **kwargs: optional arguments to pass to :func:`requests.request()` Returns: @@ -235,7 +238,12 @@ def request(self, method: str, path: str, **kwargs) -> requests.Response: """ url = self.url(path) - response = self._request(method=method, url=url, headers=self.headers, **kwargs) + if headers: + headers = {**self.headers, **headers} + else: + headers = self.headers + + response = self._request(method=method, url=url, headers=headers, **kwargs) if response.status_code == 401: self.update_access_token() @@ -349,6 +357,7 @@ def _request(self, method: str, url: str, **kwargs) -> requests.Response: """ try: timeout = kwargs.pop("timeout", 10) + return requests.request(method=method, url=url, timeout=timeout, **kwargs) except requests.exceptions.Timeout as exc: