Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose http_session_cls in AioConfig #1102

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changes
-------

xxx
^^^^^^^^^^^^^^^^^^^
* expose configuration of ``http_session_cls`` in ``AioConfig``

2.12.1 (2024-03-04)
^^^^^^^^^^^^^^^^^^^
* fix use of proxies #1070
Expand Down
6 changes: 5 additions & 1 deletion aiobotocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .config import AioConfig
from .endpoint import AioEndpointCreator
from .httpsession import AIOHTTPSession
from .regions import AioEndpointRulesetResolver
from .signers import AioRequestSigner

Expand Down Expand Up @@ -67,8 +68,10 @@ def get_client_args(
# aiobotocore addition
if isinstance(client_config, AioConfig):
connector_args = client_config.connector_args
http_session_cls = client_config.http_session_cls
else:
connector_args = None
http_session_cls = AIOHTTPSession

new_config = AioConfig(connector_args, **config_kwargs)
endpoint_creator = AioEndpointCreator(event_emitter)
Expand All @@ -79,9 +82,10 @@ def get_client_args(
endpoint_url=endpoint_config['endpoint_url'],
verify=verify,
response_parser_factory=self._response_parser_factory,
timeout=(new_config.connect_timeout, new_config.read_timeout),
max_pool_connections=new_config.max_pool_connections,
http_session_cls=http_session_cls,
proxies=new_config.proxies,
timeout=(new_config.connect_timeout, new_config.read_timeout),
socket_options=socket_options,
client_cert=new_config.client_cert,
proxies_config=new_config.proxies_config,
Expand Down
7 changes: 6 additions & 1 deletion aiobotocore/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import botocore.client
from botocore.exceptions import ParamValidationError

from aiobotocore.httpsession import AIOHTTPSession


class AioConfig(botocore.client.Config):
def __init__(self, connector_args=None, **kwargs):
def __init__(
self, connector_args=None, http_session_cls=AIOHTTPSession, **kwargs
thehesiod marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__(**kwargs)

self._validate_connector_args(connector_args)
self.connector_args = copy.copy(connector_args)
self.http_session_cls = http_session_cls
if not self.connector_args:
self.connector_args = dict()

Expand Down
25 changes: 25 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from botocore.exceptions import ParamValidationError, ReadTimeoutError

from aiobotocore.config import AioConfig
from aiobotocore.httpsession import AIOHTTPSession
from aiobotocore.session import AioSession, get_session
from tests.mock_server import AIOServer

Expand Down Expand Up @@ -132,3 +133,27 @@ def test_merge():
assert isinstance(new_config, AioConfig)
assert new_config is not config
assert new_config is not other_config


# Check that it's possible to specify custom http_session_cls
@pytest.mark.moto
@pytest.mark.asyncio
async def test_config_http_session_cls():
class SuccessExc(Exception):
...

class MyHttpSession(AIOHTTPSession):
async def send(self, request):
raise SuccessExc

config = AioConfig(http_session_cls=MyHttpSession)
session = AioSession()
async with AIOServer() as server, session.create_client(
's3',
config=config,
endpoint_url=server.endpoint_url,
aws_secret_access_key='xxx',
aws_access_key_id='xxx',
) as s3_client:
with pytest.raises(SuccessExc):
await s3_client.get_object(Bucket='foo', Key='bar')
Loading