From 470882603e1d497cb3d9652389828e363ee51dd4 Mon Sep 17 00:00:00 2001 From: dblock Date: Fri, 17 Nov 2023 14:29:04 -0500 Subject: [PATCH] Fix Amazon OpenSearch Serverless integration with LangChain. Signed-off-by: dblock --- CHANGELOG.md | 1 + opensearchpy/helpers/signer.py | 2 ++ .../test_connection/test_requests_http_connection.py | 2 ++ .../test_connection/test_urllib3_http_connection.py | 2 ++ 4 files changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93658586..69bdda60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Removed ### Fixed - Fix `TypeError` on `parallel_bulk` ([#601](https://github.com/opensearch-project/opensearch-py/pull/601)) +- Fix Amazon OpenSearch Serverless integration with LangChain ([#603](https://github.com/opensearch-project/opensearch-py/pull/603)) ### Security ## [2.4.1] diff --git a/opensearchpy/helpers/signer.py b/opensearchpy/helpers/signer.py index 930b8d25..43b5ee3c 100644 --- a/opensearchpy/helpers/signer.py +++ b/opensearchpy/helpers/signer.py @@ -78,6 +78,7 @@ class RequestsAWSV4SignerAuth(requests.auth.AuthBase): def __init__(self, credentials, region, service: str = "es") -> None: # type: ignore self.signer = AWSV4Signer(credentials, region, service) + self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600 def __call__(self, request): # type: ignore return self._sign_request(request) # type: ignore @@ -133,6 +134,7 @@ class AWSV4SignerAuth(RequestsAWSV4SignerAuth): class Urllib3AWSV4SignerAuth(Callable): # type: ignore def __init__(self, credentials, region, service: str = "es") -> None: # type: ignore self.signer = AWSV4Signer(credentials, region, service) + self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600 def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]: return self.signer.sign(method, url, body) diff --git a/test_opensearchpy/test_connection/test_requests_http_connection.py b/test_opensearchpy/test_connection/test_requests_http_connection.py index bdfb97d7..62adf39f 100644 --- a/test_opensearchpy/test_connection/test_requests_http_connection.py +++ b/test_opensearchpy/test_connection/test_requests_http_connection.py @@ -460,6 +460,7 @@ def test_aws_signer_as_http_auth(self) -> None: from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth auth = RequestsAWSV4SignerAuth(self.mock_session(), region) + self.assertEqual(auth.service, "es") con = RequestsHttpConnection(http_auth=auth) prepared_request = requests.Request("GET", "http://localhost").prepare() auth(prepared_request) @@ -478,6 +479,7 @@ def test_aws_signer_when_service_is_specified(self) -> None: from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth auth = RequestsAWSV4SignerAuth(self.mock_session(), region, service) + self.assertEqual(auth.service, service) con = RequestsHttpConnection(http_auth=auth) prepared_request = requests.Request("GET", "http://localhost").prepare() auth(prepared_request) diff --git a/test_opensearchpy/test_connection/test_urllib3_http_connection.py b/test_opensearchpy/test_connection/test_urllib3_http_connection.py index e22e943f..971a3254 100644 --- a/test_opensearchpy/test_connection/test_urllib3_http_connection.py +++ b/test_opensearchpy/test_connection/test_urllib3_http_connection.py @@ -192,6 +192,7 @@ def test_aws_signer_as_http_auth_adds_headers(self, mock_open: Any) -> None: from opensearchpy.helpers.signer import Urllib3AWSV4SignerAuth auth = Urllib3AWSV4SignerAuth(self.mock_session(), "us-west-2") + self.assertEqual(auth.service, "es") con = Urllib3HttpConnection(http_auth=auth, headers={"x": "y"}) con.perform_request("GET", "/") self.assertEqual(mock_open.call_count, 1) @@ -249,6 +250,7 @@ def test_aws_signer_when_service_is_specified(self) -> None: from opensearchpy.helpers.signer import Urllib3AWSV4SignerAuth auth = Urllib3AWSV4SignerAuth(self.mock_session(), region, service) + self.assertEqual(auth.service, service) headers = auth("GET", "http://localhost", None) self.assertIn("Authorization", headers) self.assertIn("X-Amz-Date", headers)