diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bbf825a..c914f65a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Added 'point in time' APIs to the pyi files in sync and async client ([#378](https://github.com/opensearch-project/opensearch-py/pull/378)) - Added MacOS and Windows CI workflows ([#390](https://github.com/opensearch-project/opensearch-py/pull/390)) - Compatibility with OpenSearch 2.1.0 - 2.6.0 ([#381](https://github.com/opensearch-project/opensearch-py/pull/381)) +- Added 'allow_redirects' parameter in perform_request function for RequestsHttpConnection ([#401](https://github.com/opensearch-project/opensearch-py/pull/401)) ### Changed - Upgrading pytest-asyncio to latest version - 0.21.0 ([#339](https://github.com/opensearch-project/opensearch-py/pull/339)) - Fixed flaky CI tests by replacing httpbin with a simple http_server ([#395](https://github.com/opensearch-project/opensearch-py/pull/395)) diff --git a/opensearchpy/connection/http_requests.py b/opensearchpy/connection/http_requests.py index 316bf8ef..e0b6d143 100644 --- a/opensearchpy/connection/http_requests.py +++ b/opensearchpy/connection/http_requests.py @@ -155,7 +155,15 @@ def __init__( ) def perform_request( - self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None + self, + method, + url, + params=None, + body=None, + timeout=None, + allow_redirects=True, + ignore=(), + headers=None, ): url = self.base_url + url headers = headers or {} @@ -173,7 +181,10 @@ def perform_request( settings = self.session.merge_environment_settings( prepared_request.url, {}, None, None, None ) - send_kwargs = {"timeout": timeout or self.timeout} + send_kwargs = { + "timeout": timeout or self.timeout, + "allow_redirects": allow_redirects, + } send_kwargs.update(settings) try: response = self.session.send(prepared_request, **send_kwargs) diff --git a/test_opensearchpy/TestHttpServer.py b/test_opensearchpy/TestHttpServer.py index 4c35f621..7460aad1 100644 --- a/test_opensearchpy/TestHttpServer.py +++ b/test_opensearchpy/TestHttpServer.py @@ -14,9 +14,16 @@ class TestHTTPRequestHandler(BaseHTTPRequestHandler): def do_GET(self): - self.send_response(200) headers = self.headers - self.send_header("Content-type", "application/json") + + if self.path == "/redirect": + new_location = "http://localhost:8090" + self.send_response(302) + self.send_header("Location", new_location) + else: + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() Headers = {} diff --git a/test_opensearchpy/test_connection.py b/test_opensearchpy/test_connection.py index 8d501ffe..c2480946 100644 --- a/test_opensearchpy/test_connection.py +++ b/test_opensearchpy/test_connection.py @@ -1020,6 +1020,46 @@ def test_requests_connection_error(self): conn.perform_request("GET", "/") +@pytest.mark.skipif( + sys.version_info < (3, 0), + reason="http_server is only available from python 3.x", +) +class TestRequestsConnectionRedirect: + @classmethod + def setup_class(cls): + # Start servers + cls.server1 = TestHTTPServer(port=8080) + cls.server1.start() + cls.server2 = TestHTTPServer(port=8090) + cls.server2.start() + + @classmethod + def teardown_class(cls): + # Stop servers + cls.server2.stop() + cls.server1.stop() + + # allow_redirects = False + def test_redirect_failure_when_allow_redirect_false(self): + conn = RequestsHttpConnection("localhost", port=8080, use_ssl=False, timeout=60) + with pytest.raises(TransportError) as e: + conn.perform_request("GET", "/redirect", allow_redirects=False) + assert e.value.status_code == 302 + + # allow_redirects = True (Default) + def test_redirect_success_when_allow_redirect_true(self): + conn = RequestsHttpConnection("localhost", port=8080, use_ssl=False, timeout=60) + user_agent = conn._get_default_user_agent() + status, headers, data = conn.perform_request("GET", "/redirect") + assert status == 200 + data = json.loads(data) + assert data["headers"] == { + "Host": "localhost:8090", + "Accept-Encoding": "identity", + "User-Agent": user_agent, + } + + def test_default_connection_is_returned_by_default(): c = connections.Connections()