Skip to content

Commit

Permalink
Modify transport init parameters
Browse files Browse the repository at this point in the history
Adding connect_args parameter to be able to provide any argument to the ws_connect method

Removing the following parameters (they can now be provided in the connect_args dict):
 - autoclose
 - autoping
 - compress
 - max_msg_size
 - verify_ssl
 - method

Renaming protocols to subprotocols to be more similar to the websockets transport
  • Loading branch information
leszekhanusz committed Jul 15, 2024
1 parent 277fd5d commit a8b276d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 30 deletions.
51 changes: 24 additions & 27 deletions gql/transport/aiohttp_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)

import aiohttp
from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs
from aiohttp import BasicAuth, Fingerprint, WSMsgType
from aiohttp.typedefs import LooseHeaders, StrOrURL
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDictProxy
Expand Down Expand Up @@ -110,23 +110,17 @@ def __init__(
self,
url: StrOrURL,
*,
method: str = hdrs.METH_GET,
protocols: Collection[str] = (),
autoclose: bool = True,
autoping: bool = True,
subprotocols: Optional[Collection[str]] = None,
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
params: Optional[Mapping[str, str]] = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
proxy_headers: Optional[LooseHeaders] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
ssl_context: Optional[SSLContext] = None,
verify_ssl: Optional[bool] = True,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
websocket_close_timeout: float = 10.0,
receive_timeout: Optional[float] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
Expand All @@ -139,32 +133,31 @@ def __init__(
pong_timeout: Optional[Union[int, float]] = None,
answer_pings: bool = True,
client_session_args: Optional[Dict[str, Any]] = None,
connect_args: Dict[str, Any] = {},
) -> None:
self.url: StrOrURL = url
self.headers: Optional[LooseHeaders] = headers
self.auth: Optional[BasicAuth] = auth
self.autoclose: bool = autoclose
self.autoping: bool = autoping
self.compress: int = compress
self.heartbeat: Optional[float] = heartbeat
self.max_msg_size: int = max_msg_size
self.method: str = method
self.auth: Optional[BasicAuth] = auth
self.origin: Optional[str] = origin
self.params: Optional[Mapping[str, str]] = params
self.protocols: Collection[str] = protocols
self.headers: Optional[LooseHeaders] = headers

self.proxy: Optional[StrOrURL] = proxy
self.proxy_auth: Optional[BasicAuth] = proxy_auth
self.proxy_headers: Optional[LooseHeaders] = proxy_headers
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout

self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl
self.ssl_context: Optional[SSLContext] = ssl_context

self.websocket_close_timeout: float = websocket_close_timeout
self.receive_timeout: Optional[float] = receive_timeout

self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.connect_timeout: Optional[Union[int, float]] = connect_timeout
self.close_timeout: Optional[Union[int, float]] = close_timeout
self.ack_timeout: Optional[Union[int, float]] = ack_timeout
self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout
self.verify_ssl: Optional[bool] = verify_ssl

self.init_payload: Dict[str, Any] = init_payload

# We need to set an event loop here if there is none
Expand Down Expand Up @@ -221,12 +214,15 @@ def __init__(
"""pong_received is an asyncio Event which will fire each time
a pong is received with the graphql-ws protocol"""

self.supported_subprotocols: Collection[str] = protocols or (
self.supported_subprotocols: Collection[str] = subprotocols or (
self.APOLLO_SUBPROTOCOL,
self.GRAPHQLWS_SUBPROTOCOL,
)

self.close_exception: Optional[Exception] = None

self.client_session_args = client_session_args
self.connect_args = connect_args

def _parse_answer_graphqlws(
self, answer: Dict[str, Any]
Expand Down Expand Up @@ -782,28 +778,29 @@ async def connect(self) -> None:
if self.websocket is None and not self._connecting:
self._connecting = True

connect_args: Dict[str, Any] = {}

# Adding custom parameters passed from init
if self.connect_args:
connect_args.update(self.connect_args)

try:
self.websocket = await self.session.ws_connect(
method=self.method,
url=self.url,
headers=self.headers,
auth=self.auth,
autoclose=self.autoclose,
autoping=self.autoping,
compress=self.compress,
heartbeat=self.heartbeat,
max_msg_size=self.max_msg_size,
origin=self.origin,
params=self.params,
protocols=self.supported_subprotocols,
proxy=self.proxy,
proxy_auth=self.proxy_auth,
proxy_headers=self.proxy_headers,
timeout=self.websocket_close_timeout,
receive_timeout=self.receive_timeout,
ssl=self.ssl,
ssl_context=None,
timeout=self.websocket_close_timeout,
verify_ssl=self.verify_ssl,
**connect_args,
)
finally:
self._connecting = False
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server):
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}"
sample_transport = AIOHTTPWebsocketsTransport(
url=url,
protocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
)

async with Client(transport=sample_transport) as session:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_aiohttp_websocket_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,12 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, ser

url = f"ws://{server.hostname}:{server.port}/graphql"

# Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions
# Increase max payload size
transport = AIOHTTPWebsocketsTransport(
url=url,
max_msg_size=(2**21),
connect_args={
"max_msg_size": 2**21,
},
)

query = gql(query1_str)
Expand Down

0 comments on commit a8b276d

Please sign in to comment.