diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 312293d..b1fcb13 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.9 @@ -33,7 +33,7 @@ jobs: TOXENV: static test: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: matrix: python-version: @@ -44,10 +44,10 @@ jobs: steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -74,10 +74,10 @@ jobs: steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: setup python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.gitignore b/.gitignore index 269f432..d67acb7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ __pycache__ .tox *_pb2.py *_pb2_grpc.py -.coverage \ No newline at end of file +.coverage +.DS_Store +.idea \ No newline at end of file diff --git a/nameko_grpc/__init__.py b/nameko_grpc/__init__.py index e69de29..3f534cf 100644 --- a/nameko_grpc/__init__.py +++ b/nameko_grpc/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +from nameko_grpc.h2_patch import patch_h2_transitions + + +patch_h2_transitions() diff --git a/nameko_grpc/channel.py b/nameko_grpc/channel.py index 496276b..71603c5 100644 --- a/nameko_grpc/channel.py +++ b/nameko_grpc/channel.py @@ -111,14 +111,19 @@ class ServerConnectionPool: Just accepts new connections and allows them to run until close. """ - def __init__(self, host, port, ssl, spawn_thread, handle_request): + def __init__( + self, host, port, ssl, spawn_thread, handle_request, max_concurrent_streams=100 + ): self.host = host self.port = port self.ssl = ssl self.spawn_thread = spawn_thread self.handle_request = handle_request + self.max_concurrent_streams = max_concurrent_streams self.connections = queue.Queue() + self.is_accepting = False + self.listening_socket = None def listen(self): sock = eventlet.listen((self.host, self.port)) @@ -139,7 +144,9 @@ def run(self): sock, _ = self.listening_socket.accept() sock.settimeout(60) # XXX needed and/or correct value? 
- connection = ServerConnectionManager(sock, self.handle_request) + connection = ServerConnectionManager( + sock, self.handle_request, self.max_concurrent_streams + ) self.connections.put(weakref.ref(connection)) self.spawn_thread( target=connection.run_forever, name=f"grpc server connection [{sock}]" @@ -165,9 +172,11 @@ def stop(self): class ServerChannel: """Simple server channel encapsulating incoming connection management.""" - def __init__(self, host, port, ssl, spawn_thread, handle_request): + def __init__( + self, host, port, ssl, spawn_thread, handle_request, max_concurrent_streams=100 + ): self.conn_pool = ServerConnectionPool( - host, port, ssl, spawn_thread, handle_request + host, port, ssl, spawn_thread, handle_request, max_concurrent_streams ) def start(self): diff --git a/nameko_grpc/connection.py b/nameko_grpc/connection.py index fbd7b42..435e9a2 100644 --- a/nameko_grpc/connection.py +++ b/nameko_grpc/connection.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import itertools +import logging import select import sys from collections import deque @@ -8,7 +9,7 @@ from threading import Event from grpc import StatusCode -from h2.config import H2Configuration +from h2.config import DummyLogger, H2Configuration from h2.connection import H2Connection from h2.errors import ErrorCodes from h2.events import ( @@ -37,9 +38,30 @@ log = getLogger(__name__) +class H2Logger(DummyLogger): + """ + Provide logger to H2 matching required interface + """ + + def __init__(self, logger: logging.Logger): + super().__init__() + self.logger = logger + + def debug(self, *vargs, **kwargs): + self.logger.debug(*vargs, **kwargs) + + def trace(self, *vargs, **kwargs): + # log level below debug + self.logger.log(5, *vargs, **kwargs) + + SELECT_TIMEOUT = 0.01 +class ConnectionTerminatingError(Exception): + pass + + class ConnectionManager: """ Base class for managing a single GRPC HTTP/2 connection. @@ -49,21 +71,24 @@ class ConnectionManager: by subclasses. """ - def __init__(self, sock, client_side): + def __init__(self, sock, client_side, max_concurrent_streams=100): self.sock = sock - config = H2Configuration(client_side=client_side) + h2_logger = H2Logger(log.getChild("h2")) + config = H2Configuration(client_side=client_side, logger=h2_logger) self.conn = H2Connection(config=config) + self.conn.local_settings.max_concurrent_streams = max_concurrent_streams self.receive_streams = {} self.send_streams = {} self.run = True self.stopped = Event() + self.terminating = False @property def alive(self): - return not self.stopped.is_set() + return not self.stopped.is_set() and not self.terminating @contextmanager def cleanup_on_exit(self): @@ -83,7 +108,7 @@ def cleanup_on_exit(self): if send_stream.closed: continue # stream.close() is idemponent but this prevents the log log.info( - f"Terminating send stream {send_stream}" + f"Terminating send stream {send_stream.stream_id}" f"{f' with error {error}' if error else ''}." ) send_stream.close(error) @@ -91,15 +116,17 @@ def cleanup_on_exit(self): if receive_stream.closed: continue # stream.close() is idemponent but this prevents the log log.info( - f"Terminating receive stream {receive_stream}" + f"Terminating receive stream {receive_stream.stream_id}" f"{f' with error {error}' if error else ''}." 
) receive_stream.close(error) self.sock.close() self.stopped.set() + log.debug(f"connection terminated {self}") def run_forever(self): """Event loop.""" + log.debug(f"connection initiated {self}") self.conn.initiate_connection() with self.cleanup_on_exit(): @@ -108,6 +135,9 @@ def run_forever(self): self.on_iteration() + if not self.run: + break + self.sock.sendall(self.conn.data_to_send()) ready = select.select([self.sock], [], [], SELECT_TIMEOUT) if not ready[0]: @@ -142,8 +172,10 @@ def run_forever(self): self.connection_terminated(event) def stop(self): - self.run = False - self.stopped.wait() + self.conn.close_connection() + self.terminating = True + log.debug("waiting for connection to terminate (Timeout 5s)") + self.stopped.wait(5) def on_iteration(self): """Called on every iteration of the event loop. @@ -155,6 +187,16 @@ def on_iteration(self): self.send_headers(stream_id) self.send_data(stream_id) + if self.terminating: + send_streams_closed = all( + stream.exhausted for stream in self.send_streams.values() + ) + receive_streams_closed = all( + stream.exhausted for stream in self.receive_streams.values() + ) + if send_streams_closed and receive_streams_closed: + self.run = False + def request_received(self, event): """Called when a request is received on a stream. @@ -199,6 +241,7 @@ def window_updated(self, event): Any data waiting to be sent on the stream may fit in the window now. """ log.debug("window updated, stream %s", event.stream_id) + self.send_headers(event.stream_id) self.send_data(event.stream_id) def stream_ended(self, event): @@ -210,16 +253,22 @@ def stream_ended(self, event): receive_stream = self.receive_streams.pop(event.stream_id, None) if receive_stream: receive_stream.close() + # send_stream = self.send_streams.pop(event.stream_id, None) + # if send_stream: + # send_stream.close() def stream_reset(self, event): """Called when an incoming stream is reset. - Close any `ReceiveStream` that was opened for this stream. + Close any `ReceiveStream` or `SendStream` that was opened for this stream. """ log.debug("stream reset, stream %s", event.stream_id) receive_stream = self.receive_streams.pop(event.stream_id, None) if receive_stream: receive_stream.close() + send_stream = self.send_streams.pop(event.stream_id, None) + if send_stream: + send_stream.close() def settings_changed(self, event): log.debug("settings changed") @@ -235,9 +284,22 @@ def trailers_received(self, event): receive_stream.trailers.set(*event.headers, from_wire=True) - def connection_terminated(self, event): - log.debug("connection terminated") - self.run = False + def connection_terminated(self, event: ConnectionTerminated): + """H2 signals a connection terminated event after receiving a GOAWAY frame + + If no error has occurred, flag termination and initiate a graceful termination + allowing existing streams to finish sending/receiving. + + If an error has occurred then close down immediately. + """ + log.debug(f"received GOAWAY with error code {event.error_code}") + if event.error_code not in (ErrorCodes.NO_ERROR, ErrorCodes.ENHANCE_YOUR_CALM): + log.debug("connection terminating immediately") + self.terminating = True + self.run = False + else: + log.debug("connection terminating") + self.terminating = True def send_headers(self, stream_id, immediate=False): """Attempt to send any headers on a stream. @@ -269,6 +331,14 @@ def send_data(self, stream_id): # has been completely sent return + # When a stream is closed, a STREAM_END item or ERROR is placed in the queue. 
+        # If we never read from the stream again, these are not surfaced, and the
+        # stream is never exhausted.
+        # Because we shortcut sending data if headers haven't been set yet, we need
+        # to manually flush the queue, surfacing the end/error, and ensuring the
+        # queue exhausts (and we can terminate).
+        send_stream.flush_queue_to_buffer()
+
         if not send_stream.headers_sent:
             # don't attempt to send any data until the headers have been sent
             return
@@ -331,7 +401,19 @@ def send_request(self, request_headers):
         over the response.

         Invocations are queued and sent on the next iteration of the event loop.
+
+        Raises ConnectionTerminatingError if the connection is terminating; check
+        the connection's `alive` property before initiating send_request.
+
+        Note:
+            We handle termination and raise ConnectionTerminatingError here because
+            the underlying h2 library doesn't do this. If h2 ever begins handling
+            graceful shutdowns, this logic will need altering.
         """
+        if self.terminating:
+            raise ConnectionTerminatingError(
+                "Connection is terminating. No new streams can be initiated"
+            )
         stream_id = next(self.counter)

         request_stream = SendStream(stream_id)
@@ -427,8 +509,10 @@ class ServerConnectionManager(ConnectionManager):
     Extends the base `ConnectionManager` to handle incoming GRPC requests.
     """

-    def __init__(self, sock, handle_request):
-        super().__init__(sock, client_side=False)
+    def __init__(self, sock, handle_request, max_concurrent_streams=100):
+        super().__init__(
+            sock, client_side=False, max_concurrent_streams=max_concurrent_streams
+        )
         self.handle_request = handle_request

     def request_received(self, event):
diff --git a/nameko_grpc/entrypoint.py b/nameko_grpc/entrypoint.py
index db3d061..ca40eb3 100644
--- a/nameko_grpc/entrypoint.py
+++ b/nameko_grpc/entrypoint.py
@@ -1,7 +1,9 @@
 # -*- coding: utf-8 -*-
+import queue
 import sys
 import time
 import types
+import weakref
 from functools import partial
 from logging import getLogger

@@ -27,6 +29,7 @@ class GrpcServer(SharedExtension):
     def __init__(self):
         super(GrpcServer, self).__init__()
         self.entrypoints = {}
+        self.spawned_threads = queue.Queue()

     def register(self, entrypoint):
         self.entrypoints[entrypoint.method_path] = entrypoint
@@ -78,18 +81,27 @@ def setup(self):
         host = config.get("GRPC_BIND_HOST", "0.0.0.0")
         port = config.get("GRPC_BIND_PORT", 50051)
         ssl = SslConfig(config.get("GRPC_SSL"))
+        max_concurrent_streams = config.get("MAX_CONCURRENT_STREAMS", 100)

         def spawn_thread(target, args=(), kwargs=None, name=None):
-            self.container.spawn_managed_thread(
+            thread = self.container.spawn_managed_thread(
                 lambda: target(*args, **kwargs or {}), identifier=name
             )
+            self.spawned_threads.put(weakref.ref(thread))

-        self.channel = ServerChannel(host, port, ssl, spawn_thread, self.handle_request)
+        self.channel = ServerChannel(
+            host, port, ssl, spawn_thread, self.handle_request, max_concurrent_streams
+        )

     def start(self):
         self.channel.start()

     def stop(self):
+        while not self.spawned_threads.empty():
+            thread = self.spawned_threads.get()()
+            if thread:
+                thread.kill()
+
         self.channel.stop()
         super(GrpcServer, self).stop()

diff --git a/nameko_grpc/h2_patch.py b/nameko_grpc/h2_patch.py
new file mode 100644
index 0000000..dfdfd14
--- /dev/null
+++ b/nameko_grpc/h2_patch.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+import logging
+
+from h2.connection import (
+    ConnectionInputs,
+    ConnectionState,
+    H2Connection,
+    H2ConnectionStateMachine,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def patch_h2_transitions() -> None:
+    """
+    H2 transitions immediately to the closed state upon receiving a GOAWAY frame.
+    This is out of spec and results in errors when frames are received between
+    the GOAWAY and us (the client) terminating the connection gracefully.
+    We still terminate the connection following a GOAWAY.
+    https://github.com/python-hyper/h2/issues/1181
+
+    Instead of transitioning to CLOSED, remain in the current state and await
+    a graceful termination.
+
+    We also need to patch (no-op) clear_outbound_data_buffer, which would otherwise
+    clear any outbound data.
+
+    Fixes:
+        RPC terminated with:
+        code = StatusCode.UNAVAILABLE
+        message = "Invalid input ConnectionInputs.RECV_PING in state ConnectionState.CLOSED"
+        status = "code: 14
+    """
+    logger.info("H2 transitions patched for RECV_GOAWAY frame fix")
+
+    patched_transitions = {
+        # State: idle
+        (ConnectionState.IDLE, ConnectionInputs.RECV_GOAWAY): (
+            None,
+            ConnectionState.IDLE,
+        ),
+        # State: open, client side.
+        (ConnectionState.CLIENT_OPEN, ConnectionInputs.RECV_GOAWAY): (
+            None,
+            ConnectionState.CLIENT_OPEN,
+        ),
+        # State: open, server side.
+        (ConnectionState.SERVER_OPEN, ConnectionInputs.RECV_GOAWAY): (
+            None,
+            ConnectionState.SERVER_OPEN,
+        ),
+        (ConnectionState.IDLE, ConnectionInputs.SEND_GOAWAY): (
+            None,
+            ConnectionState.IDLE,
+        ),
+        (ConnectionState.CLIENT_OPEN, ConnectionInputs.SEND_GOAWAY): (
+            None,
+            ConnectionState.CLIENT_OPEN,
+        ),
+        (ConnectionState.SERVER_OPEN, ConnectionInputs.SEND_GOAWAY): (
+            None,
+            ConnectionState.SERVER_OPEN,
+        ),
+    }
+
+    H2ConnectionStateMachine._transitions.update(patched_transitions)
+
+    # No-op this method, which is called by h2 after receiving a GOAWAY frame.
+    def clear_outbound_data_buffer(*args, **kwargs):
+        pass
+
+    H2Connection.clear_outbound_data_buffer = clear_outbound_data_buffer
diff --git a/nameko_grpc/streams.py b/nameko_grpc/streams.py
index edf81b8..34906f6 100644
--- a/nameko_grpc/streams.py
+++ b/nameko_grpc/streams.py
@@ -57,6 +57,10 @@ def __init__(self, stream_id):
     def exhausted(self):
         """A stream is exhausted if it is closed and there are no more messages
         to be consumed or bytes to be read.
+
+        When a stream is closed we append a STREAM_END item or a GrpcError, so an
+        exhausted stream may still have one item left in the queue; we must account
+        for that here.
         """
         return self.closed and self.queue.empty() and self.buffer.empty()

@@ -70,7 +74,7 @@ def close(self, error=None):
         in race conditions between timeout threads, connection teardown, and the
         natural termination of streams.

-        SendStreams have an additional race condition beteen the end of the iterator
+        SendStreams have an additional race condition between the end of the iterator
         and the StreamEnded event received from the remote side.

         An error is only raised if the first invocation happened due to an error.
@@ -166,7 +170,7 @@ def headers_to_send(self, defer_until_data=True):
         if self.headers_sent or len(self.headers) == 0:
             return False

-        if defer_until_data and self.queue.empty():
+        if defer_until_data and self.queue.empty() and self.buffer.empty():
             return False

         self.headers_sent = True
diff --git a/test/test_errors.py b/test/test_errors.py
index a36e22b..d9d7cf8 100644
--- a/test/test_errors.py
+++ b/test/test_errors.py
@@ -7,7 +7,6 @@
 import pytest
 from grpc import StatusCode
-from h2.events import ConnectionTerminated
 from nameko import config

 from nameko_grpc.constants import Cardinality
@@ -360,19 +359,3 @@ def test_invalid_request(self, client, protobufs):
             client.unary_unary(protobufs.ExampleRequest(value="hello"))
         assert error.value.code == StatusCode.INTERNAL
         assert error.value.message == "Exception deserializing request!"
-
-
-class TestErrorStreamClosed:
-    @pytest.fixture(params=["client=nameko"])
-    def client_type(self, request):
-        return request.param[7:]
-
-    def test_response_stream_closed(self, client, protobufs):
-        with mock.patch(
-            "h2.connection.H2Connection.receive_data",
-            return_value=[ConnectionTerminated()],
-        ):
-            with pytest.raises(GrpcError) as error:
-                client.unary_unary(protobufs.ExampleRequest(value="hello"))
-            assert error.value.code == StatusCode.UNAVAILABLE
-            assert error.value.message == "Stream was closed mid-request"
diff --git a/tox.ini b/tox.ini
index ff1fd3f..c7323f3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,7 @@ envlist = static, {py3.6,py3.7,py3.8,py3.9}-test
 skipsdist = True

 [testenv]
-whitelist_externals = make
+allowlist_externals = make
 commands =
     static: pip install pre-commit
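
Configuration sketch (illustrative, not part of the patch): GrpcServer.setup() above reads GRPC_BIND_HOST, GRPC_BIND_PORT, GRPC_SSL and the new MAX_CONCURRENT_STREAMS key via nameko's config.get. A minimal sketch of the keys involved, assuming the mapping is supplied through whatever config mechanism the service already uses; the dict name is arbitrary:

# Keys read by GrpcServer.setup() via nameko's config.get, shown with the
# defaults used there. How this mapping reaches nameko config (YAML config
# file, test overrides, etc.) depends on your deployment.
GRPC_CONFIG = {
    "GRPC_BIND_HOST": "0.0.0.0",
    "GRPC_BIND_PORT": 50051,
    "GRPC_SSL": None,  # wrapped in SslConfig(...)
    # New in this change: forwarded through ServerChannel and
    # ServerConnectionPool to each ServerConnectionManager, where it is
    # applied to h2's local_settings.max_concurrent_streams.
    "MAX_CONCURRENT_STREAMS": 100,
}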
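
Logging sketch (illustrative, not part of the patch): the new H2Logger forwards h2's debug output through the standard logging module and maps h2's trace() calls to custom level 5 on a child logger created with log.getChild("h2") in connection.py. A minimal sketch of enabling that output, assuming the resulting logger name is "nameko_grpc.connection.h2":

import logging

# Level 5 is the sub-DEBUG level that H2Logger.trace() logs at; registering a
# name for it is optional but makes the output readable.
logging.addLevelName(5, "TRACE")

h2_log = logging.getLogger("nameko_grpc.connection.h2")
h2_log.setLevel(5)

handler = logging.StreamHandler()
handler.setLevel(5)
h2_log.addHandler(handler)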
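
Verification sketch (illustrative, not part of the patch): once patch_h2_transitions() has run (nameko_grpc/__init__.py above applies it on import), the GOAWAY inputs keep the connection in its current state instead of jumping to CLOSED. A minimal check against the patched state machine table, which the patch itself updates via the internal _transitions dict:

from h2.connection import ConnectionInputs, ConnectionState, H2ConnectionStateMachine

import nameko_grpc  # noqa: F401 - importing the package applies patch_h2_transitions()

# With the patch applied, receiving a GOAWAY while the client side is open keeps
# the connection in CLIENT_OPEN rather than transitioning straight to CLOSED.
side_effect, next_state = H2ConnectionStateMachine._transitions[
    (ConnectionState.CLIENT_OPEN, ConnectionInputs.RECV_GOAWAY)
]
assert side_effect is None
assert next_state is ConnectionState.CLIENT_OPEN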