Skip to content

Commit

Permalink
Merge pull request #32 from antoinetran/2023-11-29-issue31-websocketh…
Browse files Browse the repository at this point in the history
…ttpproxy

Added websocket client handler, and websocket-client module
  • Loading branch information
orweis authored Jan 25, 2024
2 parents d77bd3b + ed2ce12 commit 60497a3
Show file tree
Hide file tree
Showing 15 changed files with 225 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
pytest -v --capture=tee-sys
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,26 @@ from fastapi_websocket_rpc.logger import logging_config, LoggingModes
logging_config.set_mode(LoggingModes.UVICORN)
```

## HTTP(S) Proxy
By default, fastapi-websocket-rpc uses websockets module as websocket client handler. This does not support HTTP(S) Proxy, see https://github.com/python-websockets/websockets/issues/364 . If the ability to use a proxy is important to, another websocket client implementation can be used, e.g. websocket-client (https://websocket-client.readthedocs.io). Here is how to use it. Installation:

```
pip install websocket-client
```

Then use websocket_client_handler_cls parameter:

```python
import asyncio
from fastapi_websocket_rpc import RpcMethodsBase, WebSocketRpcClient, ProxyEnabledWebSocketClientHandler

async def run_client(uri):
async with WebSocketRpcClient(uri, RpcMethodsBase(), websocket_client_handler_cls = ProxyEnabledWebSocketClientHandler) as client:
```

Just set standard environment variables (lowercase and uppercase works): http_proxy, https_proxy, and no_proxy before running python script.


## Pull requests - welcome!
- Please include tests for new features

Expand Down
Empty file removed __init__.py
Empty file.
2 changes: 2 additions & 0 deletions fastapi_websocket_rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .rpc_methods import RpcMethodsBase, RpcUtilityMethods
from .websocket_rpc_client import WebSocketRpcClient
from .websocket_rpc_client import ProxyEnabledWebSocketClientHandler
from .websocket_rpc_client import WebSocketsClientHandler
from .websocket_rpc_endpoint import WebsocketRPCEndpoint
from .rpc_channel import RpcChannel
from .schemas import WebSocketFrameType
12 changes: 11 additions & 1 deletion fastapi_websocket_rpc/simplewebsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ class SimpleWebSocket(ABC):
Abstract base class for all websocket related wrappers.
"""

@abstractmethod
def connect(self, uri: str, **connect_kwargs):
pass

@abstractmethod
def send(self, msg):
pass

# If return None, then it means Connection is closed, and we stop receiving and close.
@abstractmethod
def recv(self):
pass
Expand All @@ -26,6 +31,9 @@ class JsonSerializingWebSocket(SimpleWebSocket):
def __init__(self, websocket: SimpleWebSocket):
self._websocket = websocket

async def connect(self, uri: str, **connect_kwargs):
await self._websocket.connect(uri, **connect_kwargs)

def _serialize(self, msg):
return pydantic_serialize(msg)

Expand All @@ -37,8 +45,10 @@ async def send(self, msg):

async def recv(self):
msg = await self._websocket.recv()

if msg is None:
return None
return self._deserialize(msg)

async def close(self, code: int = 1000):
await self._websocket.close(code)

195 changes: 164 additions & 31 deletions fastapi_websocket_rpc/websocket_rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,146 @@

logger = get_logger("RPC_CLIENT")

try:
import websocket
except ImportError:
# Websocket-client optional module not installed.
pass

class ProxyEnabledWebSocketClientHandler (SimpleWebSocket):
"""
Handler that use https://websocket-client.readthedocs.io/en/latest module.
This implementation supports HTTP proxy, though HTTP_PROXY and HTTPS_PROXY environment variable.
This is not documented, but in code, see https://github.com/websocket-client/websocket-client/blob/master/websocket/_url.py#L163
The module is not written as coroutine: https://websocket-client.readthedocs.io/en/latest/threading.html#asyncio-library-usage, so
as a workaround, the send/recv are called in "run_in_executor" method.
TODO: remove this implementation after https://github.com/python-websockets/websockets/issues/364 is fixed and use WebSocketsClientHandler instead.
Note: the connect timeout, if not specified, is the default socket connect timeout, which could be around 2min, so a bit longer than WebSocketsClientHandler.
"""
def __init__(self):
self._websocket = None

"""
Args:
**kwargs: Additional args passed to connect
https://websocket-client.readthedocs.io/en/latest/examples.html#connection-options
https://websocket-client.readthedocs.io/en/latest/core.html#websocket._core.WebSocket.connect
"""
async def connect(self, uri: str, **connect_kwargs):
try:
self._websocket = await asyncio.get_event_loop().run_in_executor(None, websocket.create_connection, uri, **connect_kwargs)
# See https://websocket-client.readthedocs.io/en/latest/exceptions.html
except websocket._exceptions.WebSocketAddressException:
logger.info("websocket address info cannot be found")
raise
except websocket._exceptions.WebSocketBadStatusException:
logger.info("bad handshake status code")
raise
except websocket._exceptions.WebSocketConnectionClosedException:
logger.info("remote host closed the connection or some network error happened")
raise
except websocket._exceptions.WebSocketPayloadException:
logger.info(
f"WebSocket payload is invalid")
raise
except websocket._exceptions.WebSocketProtocolException:
logger.info(f"WebSocket protocol is invalid")
raise
except websocket._exceptions.WebSocketProxyException:
logger.info(f"proxy error occurred")
raise
except OSError as err:
logger.info("RPC Connection failed - %s", err)
raise
except Exception as err:
logger.exception("RPC Error")
raise

async def send(self, msg):
if self._websocket is None:
# connect must be called before.
logging.error("Websocket connect() must be called before.")
await asyncio.get_event_loop().run_in_executor(None, self._websocket.send, msg)

async def recv(self):
if self._websocket is None:
# connect must be called before.
logging.error("Websocket connect() must be called before.")
try:
msg = await asyncio.get_event_loop().run_in_executor(None, self._websocket.recv)
except websocket._exceptions.WebSocketConnectionClosedException as err:
logger.debug("Connection closed.", exc_info=True)
# websocket.WebSocketConnectionClosedException means remote host closed the connection or some network error happened
# Returning None to ensure we get out of the loop, with no Exception.
return None
return msg

async def close(self, code: int = 1000):
if self._websocket is not None:
# Case opened, we have something to close.
self._websocket.close(code)

class WebSocketsClientHandler(SimpleWebSocket):
"""
Handler that use https://websockets.readthedocs.io/en/stable module.
This implementation does not support HTTP proxy (see https://github.com/python-websockets/websockets/issues/364).
"""
def __init__(self):
self._websocket = None

"""
Args:
**kwargs: Additional args passed to connect
https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html#opening-a-connection
"""
async def connect(self, uri: str, **connect_kwargs):
try:
self._websocket = await websockets.connect(uri, **connect_kwargs)
except ConnectionRefusedError:
logger.info("RPC connection was refused by server")
raise
except ConnectionClosedError:
logger.info("RPC connection lost")
raise
except ConnectionClosedOK:
logger.info("RPC connection closed")
raise
except InvalidStatusCode as err:
logger.info(
f"RPC Websocket failed - with invalid status code {err.status_code}")
raise
except WebSocketException as err:
logger.info(f"RPC Websocket failed - with {err}")
raise
except OSError as err:
logger.info("RPC Connection failed - %s", err)
raise
except Exception as err:
logger.exception("RPC Error")
raise

async def send(self, msg):
if self._websocket is None:
# connect must be called before.
logging.error("Websocket connect() must be called before.")
await self._websocket.send(msg)

async def recv(self):
if self._websocket is None:
# connect must be called before.
logging.error("Websocket connect() must be called before.")
try:
msg = await self._websocket.recv()
except websockets.exceptions.ConnectionClosed:
logger.debug("Connection closed.", exc_info=True)
return None
return msg

async def close(self, code: int = 1000):
if self._websocket is not None:
# Case opened, we have something to close.
self._websocket.close(code)

def isNotInvalidStatusCode(value):
return not isinstance(value, InvalidStatusCode)
Expand Down Expand Up @@ -59,6 +199,7 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None,
on_disconnect: List[OnDisconnectCallback] = None,
keep_alive: float = 0,
serializing_socket_cls: Type[SimpleWebSocket] = JsonSerializingWebSocket,
websocket_client_handler_cls: Type[SimpleWebSocket] = WebSocketsClientHandler,
**kwargs):
"""
Args:
Expand All @@ -71,8 +212,7 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None,
on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the channel)
keep_alive(float): interval in seconds to send a keep-alive ping, Defaults to 0, which means keep alive is disabled.
**kwargs: Additional args passed to connect (@see class Connect at websockets/client.py)
https://websockets.readthedocs.io/en/stable/api.html#websockets.client.connect
**kwargs: Additional args passed to connect, depends on websocket_client_handler_cls
usage:
Expand Down Expand Up @@ -105,15 +245,24 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None,
self._on_connect = on_connect
# serialization
self._serializing_socket_cls = serializing_socket_cls
# websocket client implementation
self._websocket_client_handler_cls = websocket_client_handler_cls

async def __connect__(self):
logger.info(f"Trying server - {self.uri}")
try:
raw_ws = self._websocket_client_handler_cls()
# Wrap socket in our serialization class
self.ws = self._serializing_socket_cls(raw_ws)
except:
logger.exception("Class instantiation error.")
raise
# No try/catch for connect() to avoid double error logging. Any exception from the method must be handled by
# itself for logging, then raised and caught outside of connect() (e.g.: for retry purpose).
# Start connection
await self.ws.connect(self.uri, **self.connect_kwargs)
try:
try:
logger.info(f"Trying server - {self.uri}")
# Start connection
raw_ws = await websockets.connect(self.uri, **self.connect_kwargs)
# Wrap socket in our serialization class
self.ws = self._serializing_socket_cls(raw_ws)
# Init an RPC channel to work on-top of the connection
self.channel = RpcChannel(
self.methods, self.ws, default_response_timeout=self.default_response_timeout)
Expand All @@ -137,25 +286,6 @@ async def __connect__(self):
await self.channel.close()
self.cancel_tasks()
raise
except ConnectionRefusedError:
logger.info("RPC connection was refused by server")
raise
except ConnectionClosedError:
logger.info("RPC connection lost")
raise
except ConnectionClosedOK:
logger.info("RPC connection closed")
raise
except InvalidStatusCode as err:
logger.info(
f"RPC Websocket failed - with invalid status code {err.status_code}")
raise
except WebSocketException as err:
logger.info(f"RPC Websocket failed - with {err}")
raise
except OSError as err:
logger.info("RPC Connection failed - %s", err)
raise
except Exception as err:
logger.exception("RPC Error")
raise
Expand Down Expand Up @@ -200,15 +330,18 @@ async def reader(self):
try:
while True:
raw_message = await self.ws.recv()
await self.channel.on_message(raw_message)
if raw_message is None:
# None is a special case where connection is closed.
logger.info("Connection was terminated.")
await self.close()
break
else:
await self.channel.on_message(raw_message)
# Graceful external termination options
# task was canceled
except asyncio.CancelledError:
pass
except websockets.exceptions.ConnectionClosed:
logger.info("Connection was terminated.")
await self.close()
except:
except Exception as err:
logger.exception("RPC Reader task failed")
raise

Expand Down
4 changes: 4 additions & 0 deletions fastapi_websocket_rpc/websocket_rpc_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def __init__(self, websocket: WebSocket, frame_type: WebSocketFrameType = WebSoc
self.websocket = websocket
self.frame_type = frame_type

# This method is only useful on websocket_rpc_client. Here on endpoint file, it has nothing to connect to.
async def connect(self, uri: str, **connect_kwargs):
pass

@property
def send(self):
if self.frame_type == WebSocketFrameType.Binary:
Expand Down
6 changes: 6 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Handling DeprecationWarning 'asyncio_mode' default value
[pytest]
asyncio_mode = strict
pythonpath = .
log_cli = 1
log_cli_level = DEBUG
log_cli_format = %(asctime)s [%(levelname)s] (%(filename)s:%(lineno)s) %(message)s
log_date_format = %Y-%m-%d %H:%M:%S

9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from setuptools import setup, find_packages

import os

def get_requirements(env=""):
if env:
env = "-{}".format(env)
with open("requirements{}.txt".format(env)) as fp:
return [x.strip() for x in fp.read().split("\n") if not x.startswith("#")]

requirements = [x.strip() for x in fp.read().split("\n") if not x.startswith("#")]
withWebsocketClient = os.environ.get("WITH_WEBSOCKET_CLIENT", "False")
if bool(withWebsocketClient):
requirements.append("websocket-client>=1.1.0")
return requirements

with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
Expand Down
4 changes: 0 additions & 4 deletions tests/advanced_rpc_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os
import sys

# Add parent path to use local src as package for tests
sys.path.append(os.path.abspath(os.path.join(
os.path.dirname(__file__), os.path.pardir)))

import time
import asyncio
from multiprocessing import Process
Expand Down
6 changes: 1 addition & 5 deletions tests/basic_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
import os
import sys

# Add parent path to use local src as package for tests
sys.path.append(os.path.abspath(os.path.join(
os.path.dirname(__file__), os.path.pardir)))

import asyncio
from multiprocessing import Process

Expand All @@ -14,7 +10,7 @@
from fastapi import FastAPI

from fastapi_websocket_rpc.rpc_methods import RpcUtilityMethods
from fastapi_websocket_rpc.logger import logging_config, LoggingModes
from fastapi_websocket_rpc.logger import logging_config, LoggingModes, get_logger
from fastapi_websocket_rpc.websocket_rpc_client import WebSocketRpcClient
from fastapi_websocket_rpc.websocket_rpc_endpoint import WebsocketRPCEndpoint
from fastapi_websocket_rpc.utils import gen_uid
Expand Down
Loading

0 comments on commit 60497a3

Please sign in to comment.