[FEATURE] Async support #131

Closed
csn-eric-cheung opened this issue Feb 15, 2022 · 15 comments · Fixed by #254
Labels
enhancement New feature or request

Comments

@csn-eric-cheung

Is the async feature stable to use? I tried the following code but got the error "AttributeError: 'AWS4Auth' object has no attribute 'encode'".

async_search = AsyncOpenSearch(
    hosts = [{'host': host, 'port': 443}],
    http_auth = awsauth,
    use_ssl = True,
    verify_certs = True,
    transport_class=AsyncTransport,
)

async_search.search(sth...)
csn-eric-cheung added the enhancement (New feature or request) and untriaged (Need triage) labels on Feb 15, 2022
@VijayanB
Member

@Shivamdhar can you verify whether you see any issue with the AsyncOpenSearch client?

@atishbits

atishbits commented Mar 17, 2022

+1

Attempting to run a search using awsauth with AsyncOpenSearch results in the error below.

    es_response = await es_client.search(index="test-index", body=some_search_payload)
  File "/usr/local/lib/python3.10/site-packages/elasticsearch/_async/client/__init__.py", line 1676, in search
    return await self.transport.perform_request(
  File "/usr/local/lib/python3.10/site-packages/elasticsearch/_async/transport.py", line 296, in perform_request
    status, headers, data = await connection.perform_request(
TypeError: object tuple can't be used in 'await' expression

code snippet:

async def run_search(some_search_payload):
    session = boto3.Session()
    credentials = session.get_credentials()
    awsauth = AWS4Auth(
        credentials.access_key, credentials.secret_key, session.region_name, 'es',
        session_token=credentials.token,
    )
    es_client = AsyncOpenSearch(
        [host],
        http_auth=awsauth,
        port= 443,
        timeout=60,
        max_retries=3,
        retry_on_timeout=True,
        connection_class=RequestsHttpConnection,
    )
    es_response = await es_client.search(index="test-index", body=some_search_payload)

cc: @VijayanB @Shivamdhar
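
The `TypeError` in the traceback above is consistent with a synchronous connection class (here `RequestsHttpConnection`) being handed to an async transport: the synchronous `perform_request` returns a plain tuple, and the transport then tries to `await` it. A minimal stdlib-only sketch of that failure mode follows. `SyncConnection`, `AsyncConnection`, and `transport_perform_request` are hypothetical stand-ins, not the library's classes.

```python
import asyncio


class SyncConnection:
    """Hypothetical stand-in for a blocking connection class."""

    def perform_request(self, method, url):
        # Returns a plain (status, headers, body) tuple, not an awaitable.
        return 200, {}, "{}"


class AsyncConnection:
    """Hypothetical stand-in for an asyncio-based connection class."""

    async def perform_request(self, method, url):
        return 200, {}, "{}"


async def transport_perform_request(connection, method, url):
    # An async transport awaits whatever the connection returns.
    return await connection.perform_request(method, url)


def try_request(connection):
    """Run one request and report the outcome as a string."""
    try:
        status, _headers, _body = asyncio.run(
            transport_perform_request(connection, "GET", "/")
        )
        return f"ok {status}"
    except TypeError as exc:
        return f"TypeError: {exc}"


if __name__ == "__main__":
    print(try_request(AsyncConnection()))
    print(try_request(SyncConnection()))
```

If this reading is right, dropping `connection_class=RequestsHttpConnection` removes the `TypeError`, but request signing then still has to be solved separately for the async connection.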

@Shivamdhar
Contributor

This will require some code changes here: https://github.com/opensearch-project/opensearch-py/blob/main/opensearchpy/_async/http_aiohttp.py#L148-L151. It looks like the async client doesn't support custom auth as of now. @atishbits / @csn-eric-cheung, please feel free to contribute.
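
One plausible reading of the original `AttributeError`, assuming the linked lines treat `http_auth` as basic-auth credentials (a string or a user/password pair) and encode them into an `Authorization` header: a signer object such as `AWS4Auth` is neither, so the string handling fails. The sketch below is a hypothetical stand-in for that logic, not the library's code; `build_basic_auth_header` and `FakeSigner` are made-up names.

```python
import base64


def build_basic_auth_header(http_auth):
    """Hypothetical stand-in: build a basic-auth header from str or (user, pw)."""
    if isinstance(http_auth, (tuple, list)):
        http_auth = ":".join(http_auth)
    # Anything that is not a string at this point has no .encode() method.
    token = base64.b64encode(http_auth.encode("latin-1")).decode()
    return {"authorization": f"Basic {token}"}


class FakeSigner:
    """Made-up placeholder for a request-signing auth object."""


if __name__ == "__main__":
    print(build_basic_auth_header(("user", "secret")))
    try:
        build_basic_auth_header(FakeSigner())
    except AttributeError as exc:
        print(f"AttributeError: {exc}")
```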

@atishbits

atishbits commented Apr 18, 2022

@Shivamdhar - Haven't had a chance to make the suggested code changes.

Meanwhile, here's a workaround (at the application code level) to ensure that requests are signed with AWS headers: https://stackoverflow.com/a/69272472/940154

@jpllana

jpllana commented Apr 27, 2022

@atishbits - Any update on this feature? BTW, the workaround that you mentioned didn't work for me.

@nadobando

nadobando commented May 20, 2022

I got this workaround working thanks to samuelcolvin/aioaws.

AIOHttpConnection uses aiohttp.ClientSession, so I created a subclass of AIOHttpConnection and a subclass of aiohttp.ClientSession that signs the requests. Enjoy 😉

import asyncio
import base64
import hashlib
import hmac
from asyncio import get_running_loop
from binascii import hexlify
from datetime import datetime, timezone
from functools import reduce
from ssl import SSLContext
from types import SimpleNamespace
from typing import (
    Any,
    Dict,
    Iterable,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import quote as url_quote

import aiohttp
import boto3
from aiohttp import BasicAuth, ClientResponse, ClientTimeout, Fingerprint
from aiohttp.helpers import sentinel
from aiohttp.typedefs import LooseCookies, LooseHeaders, StrOrURL
from botocore.credentials import Credentials
from opensearchpy import AIOHttpConnection, AsyncOpenSearch
from opensearchpy._async.http_aiohttp import OpenSearchClientResponse
from yarl import URL

_HTTP_METHODS = Literal[
    "HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"
]
_AWS_AUTH_REQUEST = "aws4_request"
_CONTENT_TYPE = "application/json"
_AUTH_ALGORITHM = "AWS4-HMAC-SHA256"


class AWSSignedSession(aiohttp.ClientSession):
    service = 'es'

    def __init__(self, aws_credentials: Credentials, aws_region: str, **kwargs):
        super().__init__(**kwargs)
        self.aws_region = aws_region
        self.aws_credentials = aws_credentials

    def _auth_headers(
            self,
            method: _HTTP_METHODS,
            url: URL,
            *,
            data: Optional[bytes] = None,
            content_type: Optional[str] = _CONTENT_TYPE,
    ) -> Dict[str, str]:
        now = datetime.now(timezone.utc)
        data = data or b""

        # WARNING! order is important here, headers need to be in alphabetical order
        headers = {
            "content-md5": base64.b64encode(hashlib.md5(data).digest()).decode(),
            "content-type": content_type,
            "host": url.host,
            "x-amz-date": self._aws4_x_amz_date(now),
        }

        payload_sha256_hash = hashlib.sha256(data).hexdigest()
        signed_headers, signature = self._aws4_signature(
            now, method, url, headers, payload_sha256_hash
        )
        credential = self._aws4_credential(now)
        authorization_header = f"{_AUTH_ALGORITHM} Credential={credential},SignedHeaders={signed_headers},Signature={signature}"
        headers.update(
            {
                "authorization": authorization_header,
                "x-amz-content-sha256": payload_sha256_hash,
            }
        )
        return headers

    def _aws4_signature(
            self,
            dt: datetime,
            method: _HTTP_METHODS,
            url: URL,
            headers: Dict[str, str],
            payload_hash: str,
    ) -> Tuple[str, str]:
        header_keys = sorted(headers)
        signed_headers = ";".join(header_keys)
        canonical_request_parts = (
            method,
            url_quote(url.path),
            url.query_string,
            "".join(f"{k}:{headers[k]}\n" for k in header_keys),
            signed_headers,
            payload_hash,
        )
        canonical_request = "\n".join(canonical_request_parts)
        string_to_sign_parts = (
            _AUTH_ALGORITHM,
            self._aws4_x_amz_date(dt),
            self._aws4_scope(dt),
            hashlib.sha256(canonical_request.encode()).hexdigest(),
        )
        string_to_sign = "\n".join(string_to_sign_parts)
        return signed_headers, self._aws4_sign_string(string_to_sign, dt)

    def _aws4_sign_string(self, string_to_sign: str, dt: datetime) -> str:
        key_parts = (
            b"AWS4" + self.aws_credentials.secret_key.encode(),
            self._aws4_date_stamp(dt),
            self.aws_region,
            self.service,
            _AWS_AUTH_REQUEST,
            string_to_sign,
        )
        signature_bytes: bytes = reduce(self._aws4_reduce_signature, key_parts)  # type: ignore
        return hexlify(signature_bytes).decode()

    def _aws4_scope(self, dt: datetime) -> str:
        return f"{self._aws4_date_stamp(dt)}/{self.aws_region}/{self.service}/{_AWS_AUTH_REQUEST}"

    def _aws4_credential(self, dt: datetime) -> str:
        return f"{self.aws_credentials.access_key}/{self._aws4_scope(dt)}"

    @staticmethod
    def _aws4_date_stamp(dt: datetime) -> str:
        return dt.strftime("%Y%m%d")

    @staticmethod
    def _aws4_x_amz_date(dt: datetime) -> str:
        return dt.strftime("%Y%m%dT%H%M%SZ")

    @staticmethod
    def _aws4_reduce_signature(key: bytes, msg: str) -> bytes:
        return hmac.new(key, msg.encode(), hashlib.sha256).digest()

    async def _request(
            self,
            method: str,
            str_or_url: StrOrURL,
            *,
            params: Optional[Mapping[str, str]] = None,
            data: Any = None,
            json: Any = None,
            cookies: Optional[LooseCookies] = None,
            headers: Optional[LooseHeaders] = None,
            skip_auto_headers: Optional[Iterable[str]] = None,
            auth: Optional[BasicAuth] = None,
            allow_redirects: bool = True,
            max_redirects: int = 10,
            compress: Optional[str] = None,
            chunked: Optional[bool] = None,
            expect100: bool = False,
            raise_for_status: Optional[bool] = None,
            read_until_eof: bool = True,
            proxy: Optional[StrOrURL] = None,
            proxy_auth: Optional[BasicAuth] = None,
            timeout: Union[ClientTimeout, object] = sentinel,
            verify_ssl: Optional[bool] = None,
            fingerprint: Optional[bytes] = None,
            ssl_context: Optional[SSLContext] = None,
            ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
            proxy_headers: Optional[LooseHeaders] = None,
            trace_request_ctx: Optional[SimpleNamespace] = None,
            read_bufsize: Optional[int] = None,
    ) -> ClientResponse:
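        # Copy the headers so a missing mapping is handled, and accept the URL as str or yarl.URL.
        headers = dict(headers or {})
        headers.update(self._auth_headers(method, URL(str_or_url), data=data))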
        return await super(AWSSignedSession, self)._request(
            method,
            str_or_url,
            params=params,
            data=data,
            json=json,
            cookies=cookies,
            headers=headers,
            skip_auto_headers=skip_auto_headers,
            auth=auth,
            allow_redirects=allow_redirects,
            max_redirects=max_redirects,
            compress=compress,
            chunked=chunked,
            expect100=expect100,
            raise_for_status=raise_for_status,
            read_until_eof=read_until_eof,
            proxy=proxy,
            proxy_auth=proxy_auth,
            proxy_headers=proxy_headers,
            timeout=timeout,
            verify_ssl=verify_ssl,
            fingerprint=fingerprint,
            ssl_context=ssl_context,
            ssl=ssl,
            trace_request_ctx=trace_request_ctx,
            read_bufsize=read_bufsize
        )


class AWSAsyncConnection(AIOHttpConnection):
    service = "es"

    def __init__(self, aws_credentials: Credentials, aws_region: str, **kwargs: Any):
        super().__init__(**kwargs)
        self.aws_credentials = aws_credentials
        self.aws_region = aws_region

    async def _create_aiohttp_session(self):
        if self.loop is None:
            self.loop = get_running_loop()
        self.session = AWSSignedSession(
            aws_credentials=self.aws_credentials,
            aws_region=self.aws_region,
            headers=self.headers,
            skip_auto_headers=("accept", "accept-encoding"),
            auto_decompress=True,
            loop=self.loop,
            cookie_jar=aiohttp.DummyCookieJar(),
            response_class=OpenSearchClientResponse,
            connector=aiohttp.TCPConnector(
                limit=self._limit, use_dns_cache=True, ssl=self._ssl_context
            ),
        )

async def run():
    session = boto3.Session()
    search = AsyncOpenSearch(
        hosts=endpoint,  # placeholder: set `endpoint` to your domain endpoint
        connection_class=AWSAsyncConnection,
        aws_region=session.region_name,
        aws_credentials=session.get_credentials(),
        port=443,
        use_ssl=True,
        verify_certs=True,
    )

    print(await search.cat.indices())
    await search.close()

if __name__ == "__main__":
    asyncio.run(run())
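
The key-derivation step in `_aws4_sign_string` above chains HMAC-SHA256 over the date stamp, region, service, and the literal `aws4_request`, then signs the string-to-sign with the result. A stdlib-only sketch of that chain in isolation is shown here, using placeholder credentials and made-up function names (`derive_signing_key`, `sign`); it is an illustration of the chaining, not a full signer.

```python
import hashlib
import hmac


def _hmac_sha256(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode(), hashlib.sha256).digest()


def derive_signing_key(secret_key: str, date_stamp: str, region: str, service: str) -> bytes:
    """Chain HMAC-SHA256 over date, region, service, and 'aws4_request'."""
    key = ("AWS4" + secret_key).encode()
    for part in (date_stamp, region, service, "aws4_request"):
        key = _hmac_sha256(key, part)
    return key


def sign(signing_key: bytes, string_to_sign: str) -> str:
    """Return the hex signature that goes into the Authorization header."""
    return _hmac_sha256(signing_key, string_to_sign).hex()


if __name__ == "__main__":
    # Placeholder values only; never hard-code real credentials.
    key = derive_signing_key("EXAMPLE-SECRET", "20220520", "us-east-1", "es")
    print(sign(key, "example string to sign"))
```

Because the date stamp, region, and service are all folded into the key, a signature computed for one region or day cannot be reused for another.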

@dacevedo12

dacevedo12 commented Jul 27, 2022

Based on the previous comments, I modified it a bit and ended up with:

from boto3 import Session
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from opensearchpy import AIOHttpConnection, AsyncOpenSearch
from opensearchpy.helpers.signer import OPENSEARCH_SERVICE
from typing import Any, Optional
from urllib.parse import urlencode

class AsyncAWSConnection(AIOHttpConnection):
    def __init__(
        self, aws_credentials: Credentials, aws_region: str, **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.aws_credentials = aws_credentials
        self.aws_region = aws_region
        self.signer = SigV4Auth(
            self.aws_credentials,
            OPENSEARCH_SERVICE,
            self.aws_region,
        )

    async def perform_request(
        self,
        method: str,
        url: str,
        params: Optional[dict[str, Any]] = None,
        body: Optional[bytes] = None,
        timeout: Optional[float] = None,
        ignore: tuple[int, ...] = (),
        headers: Optional[dict[str, str]] = None,
    ) -> tuple[int, dict[str, str], str]:
        headers_ = headers if headers else {}
        aws_body = (
            self._gzip_compress(body) if self.http_compress and body else body
        )
        query_string = "?" + urlencode(params) if params else ""
        aws_request = AWSRequest(
            data=aws_body,
            headers=headers_,
            method=method.upper(),
            # Scheme and host come first, then the URL prefix, so the signed URL matches the requested one.
            url="".join([self.host, self.url_prefix, url, query_string]),
        )

        self.signer.add_auth(aws_request)
        signed_headers = dict(aws_request.headers.items())
        all_headers = {**headers_, **signed_headers}

        return await super().perform_request(
            method, url, params, body, timeout, ignore, all_headers
        )

SESSION = Session()
CLIENT = AsyncOpenSearch(
    aws_credentials=SESSION.get_credentials(),
    aws_region=SESSION.region_name,
    connection_class=AsyncAWSConnection,
    hosts=["your_host_goes_here"],
    http_compress=True,
)

If you agree, I can open a PR so we can get this moving toward a merge.
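
For the signature to validate, the URL handed to `AWSRequest` has to match the URL that is actually requested. Assuming the connection builds its request URL as scheme and host, then any URL prefix, then the path and the encoded query string, the assembly can be sketched with the standard library alone. `build_signed_url` is a hypothetical helper and the host value is a placeholder.

```python
from typing import Any, Mapping, Optional
from urllib.parse import urlencode


def build_signed_url(
    host: str,
    url_prefix: str,
    path: str,
    params: Optional[Mapping[str, Any]] = None,
) -> str:
    """Join host, prefix, path, and query string in request order."""
    query_string = "?" + urlencode(params) if params else ""
    return "".join([host, url_prefix, path, query_string])


if __name__ == "__main__":
    print(build_signed_url("https://example.com:443", "/prefix", "/_search", {"q": "a b"}))
```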

@dblock
Member

dblock commented Jul 27, 2022

My only concern with the example above is that a bunch of vendor-specific stuff is in non-vendor-specific classes (e.g. aws_credentials and aws_region are options to AsyncOpenSearch). Those should be split into generic classes and AWS-vendor-specific implementations for when talking to Amazon OpenSearch.
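
One way to read this suggestion, sketched with hypothetical names only (none of these are the library's API): the generic connection accepts any callable that turns a request into extra headers, and the AWS-specific signing lives in its own class that satisfies that interface.

```python
import asyncio
from typing import Callable, Dict, Optional

# A signer takes (method, url, body) and returns headers to add to the request.
Signer = Callable[[str, str, Optional[bytes]], Dict[str, str]]


class GenericAsyncConnection:
    """Hypothetical vendor-neutral connection: knows nothing about AWS."""

    def __init__(self, host: str, signer: Optional[Signer] = None) -> None:
        self.host = host
        self.signer = signer

    async def perform_request(self, method, path, body=None, headers=None):
        all_headers = dict(headers or {})
        if self.signer is not None:
            all_headers.update(self.signer(method, self.host + path, body))
        # A real connection would send the request here; this sketch
        # only returns what it would have sent.
        return method, self.host + path, all_headers


class FakeAWSSigner:
    """Hypothetical vendor-specific piece; a real one would do SigV4 signing."""

    def __init__(self, region: str, service: str = "es") -> None:
        self.region = region
        self.service = service

    def __call__(self, method, url, body):
        # Placeholder header only, to show where the signature would go.
        return {"authorization": f"placeholder-signature {self.region}/{self.service}"}


if __name__ == "__main__":
    conn = GenericAsyncConnection("https://example.com", signer=FakeAWSSigner("us-east-1"))
    print(asyncio.run(conn.perform_request("GET", "/_cat/indices")))
```

With this shape, the client options stay vendor-neutral (a single signer argument), and region or credential handling never appears in the generic classes.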

@niffler92

niffler92 commented Aug 21, 2022

Any updates on this? Would love to use Async operations from opensearch-py.

@dblock
Member

dblock commented Aug 25, 2022

@niffler92 Want to help turn this into a PR?

@sharp-pixel

Is anyone working on this? This is important to support IAM auth for OpenSearch Benchmark.

@dblock
Member

dblock commented Oct 26, 2022

@sharp-pixel AFAIK nobody is working on this. Please feel free to pick it up.

@sharp-pixel

Alright, I will see what I can do on this.

wbeckler removed the untriaged (Need triage) label on Nov 3, 2022
@harshavamsi
Collaborator

@saimedhi can you take a look?

@saimedhi
Collaborator

Okay, I will work on it.
