diff --git a/nameko_grpc/__init__.py b/nameko_grpc/__init__.py index e69de29..3f534cf 100644 --- a/nameko_grpc/__init__.py +++ b/nameko_grpc/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +from nameko_grpc.h2_patch import patch_h2_transitions + + +patch_h2_transitions() diff --git a/nameko_grpc/channel.py b/nameko_grpc/channel.py index 496276b..625d573 100644 --- a/nameko_grpc/channel.py +++ b/nameko_grpc/channel.py @@ -39,6 +39,8 @@ def __init__(self, targets, ssl, spawn_thread): self.spawn_thread = spawn_thread self.connections = queue.Queue() + self.is_accepting = False + self.listening_socket = None def connect(self, target): sock = socket.create_connection( diff --git a/nameko_grpc/connection.py b/nameko_grpc/connection.py index b3ea751..5bd3d08 100644 --- a/nameko_grpc/connection.py +++ b/nameko_grpc/connection.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import itertools +import logging import select import sys from collections import deque @@ -8,7 +9,7 @@ from threading import Event from grpc import StatusCode -from h2.config import H2Configuration +from h2.config import DummyLogger, H2Configuration from h2.connection import H2Connection from h2.errors import ErrorCodes from h2.events import ( @@ -37,9 +38,30 @@ log = getLogger(__name__) +class H2Logger(DummyLogger): + """ + Provide logger to H2 matching required interface + """ + + def __init__(self, logger: logging.Logger): + super().__init__() + self.logger = logger + + def debug(self, *vargs, **kwargs): + self.logger.debug(*vargs, **kwargs) + + def trace(self, *vargs, **kwargs): + # log level below debug + self.logger.log(5, *vargs, **kwargs) + + SELECT_TIMEOUT = 0.01 +class ConnectionTerminatingError(Exception): + pass + + class ConnectionManager: """ Base class for managing a single GRPC HTTP/2 connection. 
@@ -52,7 +74,8 @@ class ConnectionManager: def __init__(self, sock, client_side): self.sock = sock - config = H2Configuration(client_side=client_side) + h2_logger = H2Logger(log.getChild("h2")) + config = H2Configuration(client_side=client_side, logger=h2_logger) self.conn = H2Connection(config=config) self.receive_streams = {} @@ -60,10 +83,11 @@ def __init__(self, sock, client_side): self.run = True self.stopped = Event() + self.terminating = False @property def alive(self): - return not self.stopped.is_set() + return not self.stopped.is_set() and not self.terminating @contextmanager def cleanup_on_exit(self): @@ -83,7 +107,7 @@ def cleanup_on_exit(self): if send_stream.closed: continue # stream.close() is idemponent but this prevents the log log.info( - f"Terminating send stream {send_stream}" + f"Terminating send stream {send_stream.stream_id}" f"{f' with error {error}' if error else ''}." ) send_stream.close(error) @@ -91,15 +115,17 @@ def cleanup_on_exit(self): if receive_stream.closed: continue # stream.close() is idemponent but this prevents the log log.info( - f"Terminating receive stream {receive_stream}" + f"Terminating receive stream {receive_stream.stream_id}" f"{f' with error {error}' if error else ''}." 
) receive_stream.close(error) self.sock.close() self.stopped.set() + log.debug(f"connection terminated {self}") def run_forever(self): """Event loop.""" + log.debug(f"connection initiated {self}") self.conn.initiate_connection() with self.cleanup_on_exit(): @@ -108,6 +134,9 @@ def run_forever(self): self.on_iteration() + if not self.run: + break + self.sock.sendall(self.conn.data_to_send()) ready = select.select([self.sock], [], [], SELECT_TIMEOUT) if not ready[0]: @@ -142,8 +171,10 @@ def run_forever(self): self.connection_terminated(event) def stop(self): - self.run = False - self.stopped.wait() + self.conn.close_connection() + self.terminating = True + log.debug("waiting for connection to terminate (Timeout 5s)") + self.stopped.wait(5) def on_iteration(self): """Called on every iteration of the event loop. @@ -155,6 +186,16 @@ def on_iteration(self): self.send_headers(stream_id) self.send_data(stream_id) + if self.terminating: + send_streams_closed = all( + stream.exhausted for stream in self.send_streams.values() + ) + receive_streams_closed = all( + stream.exhausted for stream in self.receive_streams.values() + ) + if send_streams_closed and receive_streams_closed: + self.run = False + def request_received(self, event): """Called when a request is received on a stream. @@ -199,6 +240,7 @@ def window_updated(self, event): Any data waiting to be sent on the stream may fit in the window now. 
""" log.debug("window updated, stream %s", event.stream_id) + self.send_headers(event.stream_id) self.send_data(event.stream_id) def stream_ended(self, event): @@ -238,9 +280,22 @@ def trailers_received(self, event): receive_stream.trailers.set(*event.headers, from_wire=True) - def connection_terminated(self, event): - log.debug("connection terminated") - self.run = False + def connection_terminated(self, event: ConnectionTerminated): + """H2 signals a connection terminated event after receiving a GOAWAY frame + + If no error has occurred, flag termination and initiate a graceful termination + allowing existing streams to finish sending/receiving. + + If an error has occurred then close down immediately. + """ + log.debug(f"received GOAWAY with error code {event.error_code}") + if event.error_code not in (ErrorCodes.NO_ERROR, ErrorCodes.ENHANCE_YOUR_CALM): + log.debug("connection terminating immediately") + self.terminating = True + self.run = False + else: + log.debug("connection terminating") + self.terminating = True def send_headers(self, stream_id, immediate=False): """Attempt to send any headers on a stream. @@ -272,6 +327,13 @@ def send_data(self, stream_id): # has been completely sent return + # When a stream is closed, a STREAM_END item or ERROR is placed in the queue. + # If we never read from the stream again, these are not consumed, and the + # stream is never exhausted which prevents a graceful termination. + # Because we return early if headers haven't been sent, we need to manually + # flush the queue (an operation that would otherwise occur during `stream.read`) + send_stream.flush_queue_to_buffer() + if not send_stream.headers_sent: # don't attempt to send any data until the headers have been sent return @@ -334,7 +396,20 @@ def send_request(self, request_headers): over the response. Invocations are queued and sent on the next iteration of the event loop. + + raises ConnectionTerminatingError if connection is terminating. 
Check + connection .alive before initiating send_request + + Note: + We are handling termination and raising ConnectionTerminatingError here as the + underlying library H2 doesn't do this. If H2 ever begins handling graceful + shutdowns, this logic will need altering. + https://github.com/python-hyper/h2/issues/1181 """ + if self.terminating: + raise ConnectionTerminatingError( + "Connection is terminating. No new streams can be initiated" + ) stream_id = next(self.counter) request_stream = SendStream(stream_id) diff --git a/nameko_grpc/h2_patch.py b/nameko_grpc/h2_patch.py new file mode 100644 index 0000000..dfdfd14 --- /dev/null +++ b/nameko_grpc/h2_patch.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +import logging + +from h2.connection import ( + ConnectionInputs, + ConnectionState, + H2Connection, + H2ConnectionStateMachine, +) + + +logger = logging.getLogger(__name__) + + +def patch_h2_transitions() -> None: + """ + H2 transitions immediately to closed state upon receiving a GOAWAY frame. + This is out of spec and results in errors when frames are received between + the GOAWAY and us (the client) terminating the connection gracefully. + We still terminate the connection following a GOAWAY. + https://github.com/python-hyper/h2/issues/1181 + + Instead of transitioning to CLOSED, remain in the current STATE and await + a graceful termination. + + We also need to patch (noop) clear_outbound_data_buffer which would clear any + outbound data. + + Fixes: + RPC terminated with: + code = StatusCode.UNAVAILABLE + message = "Invalid input ConnectionInputs.RECV_PING in state ConnectionState.CLOSED" + status = "code: 14 + """ + logger.info("H2 transitions patched for RECV_GOAWAY frame fix") + + patched_transitions = { + # State: idle + (ConnectionState.IDLE, ConnectionInputs.RECV_GOAWAY): ( + None, + ConnectionState.IDLE, + ), + # State: open, client side. 
+ (ConnectionState.CLIENT_OPEN, ConnectionInputs.RECV_GOAWAY): ( + None, + ConnectionState.CLIENT_OPEN, + ), + # State: open, server side. + (ConnectionState.SERVER_OPEN, ConnectionInputs.RECV_GOAWAY): ( + None, + ConnectionState.SERVER_OPEN, + ), + (ConnectionState.IDLE, ConnectionInputs.SEND_GOAWAY): ( + None, + ConnectionState.IDLE, + ), + (ConnectionState.CLIENT_OPEN, ConnectionInputs.SEND_GOAWAY): ( + None, + ConnectionState.CLIENT_OPEN, + ), + (ConnectionState.SERVER_OPEN, ConnectionInputs.SEND_GOAWAY): ( + None, + ConnectionState.SERVER_OPEN, + ), + } + + H2ConnectionStateMachine._transitions.update(patched_transitions) + + # no-op this method, which is called by h2 after receiving a GOAWAY frame + def clear_outbound_data_buffer(*args, **kwargs): + pass + + H2Connection.clear_outbound_data_buffer = clear_outbound_data_buffer diff --git a/nameko_grpc/streams.py b/nameko_grpc/streams.py index c8b9765..1802d48 100644 --- a/nameko_grpc/streams.py +++ b/nameko_grpc/streams.py @@ -72,7 +72,7 @@ def close(self, error=None): in race conditions between timeout threads, connection teardown, and the natural termination of streams. - SendStreams have an additional race condition beteen the end of the iterator + SendStreams have an additional race condition between the end of the iterator and the StreamEnded event received from the remote side. An error is only raised if the first invocation happened due to an error. @@ -168,7 +168,7 @@ def headers_to_send(self, defer_until_data=True): if self.headers_sent or len(self.headers) == 0: return False - if defer_until_data and self.queue.empty(): + if defer_until_data and self.queue.empty() and self.buffer.empty(): return False self.headers_sent = True