diff --git a/changelog.d/11938.misc b/changelog.d/11938.misc
new file mode 100644
index 000000000000..1d3a0030f77f
--- /dev/null
+++ b/changelog.d/11938.misc
@@ -0,0 +1 @@
+Add missing type hints to replication code.
diff --git a/mypy.ini b/mypy.ini
index 2884078d0ad3..cd28ac0dd2fe 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -169,6 +169,9 @@ disallow_untyped_defs = True
 [mypy-synapse.push.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.replication.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.rest.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index fa132d10b414..8f3f953ed474 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -40,7 +40,7 @@ def __init__(
         for table, column in extra_tables:
             self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name: Optional[str], new_id: int):
+    def advance(self, instance_name: Optional[str], new_id: int) -> None:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index bc888ce1a871..b5b84c09ae41 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -37,7 +37,9 @@ def __init__(
             cache_name="client_ip_last_seen", max_size=50000
         )
 
-    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(
+        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
+    ) -> None:
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
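Note: the `[mypy-synapse.replication.*]` stanza above is what drives every signature change in this patch: with `disallow_untyped_defs = True`, mypy rejects any `def` in that package whose parameters or return type are unannotated. A minimal sketch of what the flag enforces (the function names here are invented for illustration, not part of the patch):

```python
# Illustrative only: what disallow_untyped_defs accepts and rejects.
from typing import Optional


def advance(instance_name: Optional[str], new_id: int) -> None:
    """Fully annotated, so mypy accepts it under disallow_untyped_defs."""


def get_current_token():  # mypy: "Function is missing a return type annotation"
    return 0
```

This is why even trivial methods below gain an explicit `-> None`.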
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index a2aff75b7075..0ffd34f1dad0 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ def __init__(
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _invalidate_caches_for_devices(self, token, rows):
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
         for row in rows:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 9d90e26375f0..d6f37d7479ba 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ def __init__(
             self._group_updates_id_gen.get_current_token(),
         )
 
-    def get_group_stream_token(self):
+    def get_group_stream_token(self) -> int:
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:
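Note: `rows: Iterable[Any]` looks loose, but each replication stream delivers its own row type, so `Any` is what keeps every store's override compatible with the shared base signature. A self-contained sketch of the pattern, with invented class names:

```python
# Sketch of the shared override pattern; BaseStore/ExampleSlavedStore are
# illustrative stand-ins, not the real Synapse classes.
from typing import Any, Iterable


class BaseStore:
    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        """Base hook; subclasses narrow behaviour, not the signature."""


class ExampleSlavedStore(BaseStore):
    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        for row in rows:
            print(stream_name, token, row)
        super().process_replication_rows(stream_name, instance_name, token, rows)
```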
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 7541e21de9dd..52ee3f7e58ea 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Iterable
 
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cea90c0f1bf4..de642bba71b0 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(
-        self, stream_name: str, instance_name: str, token, rows
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
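Note: the pushers.py hunk removes a `# type: ignore` rather than adding one: once `token` is declared as `int`, the call lines up with the `advance(..., new_id: int)` signature seen earlier and the suppression becomes dead weight. A reduced illustration (class and function names invented):

```python
# Hypothetical reduction of the pushers.py situation.
from typing import Optional


class IdTracker:
    _current: int = 0

    def advance(self, instance_name: Optional[str], new_id: int) -> None:
        self._current = max(self._current, new_id)


def handle_token(tracker: IdTracker, instance_name: str, token: int) -> None:
    # With `token: int` the call matches `advance` exactly, so the old
    # `# type: ignore` would now be flagged as unused under warn_unused_ignores.
    tracker.advance(instance_name, token)
```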
""" import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory +from twisted.python.failure import Failure from synapse.api.constants import EventTypes from synapse.federation import send_queue @@ -79,10 +81,10 @@ def __init__( hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) - def startedConnecting(self, connector): + def startedConnecting(self, connector: IConnector) -> None: logger.info("Connecting to replication: %r", connector.getDestination()) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol: logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( self.hs, @@ -92,11 +94,11 @@ def buildProtocol(self, addr): self.command_handler, ) - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.error("Lost replication conn: %r", reason) ReconnectingClientFactory.clientConnectionLost(self, connector, reason) - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.error("Failed to connect to replication: %r", reason) ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) @@ -131,7 +133,7 @@ def __init__(self, hs: "HomeServer"): async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -252,14 +254,16 @@ async def on_rdata( # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] - async def on_position(self, stream_name: str, instance_name: str, token: int): + async def on_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: await self.on_rdata(stream_name, instance_name, token, []) # We poke the generic "replication" notifier to wake anything up that # may be streaming. self.notifier.notify_replication() - def on_remote_server_up(self, server: str): + def on_remote_server_up(self, server: str) -> None: """Called when get a new REMOTE_SERVER_UP command.""" # Let's wake up the transaction queue for the server in case we have @@ -269,7 +273,7 @@ def on_remote_server_up(self, server: str): async def wait_for_stream_position( self, instance_name: str, stream_name: str, position: int - ): + ) -> None: """Wait until this instance has received updates up to and including the given stream position. 
""" @@ -304,7 +308,7 @@ async def wait_for_stream_position( "Finished waiting for repl stream %r to reach %s", stream_name, position ) - def stop_pusher(self, user_id, app_id, pushkey): + def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return @@ -316,13 +320,13 @@ def stop_pusher(self, user_id, app_id, pushkey): logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id, app_id, pushkey): + async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) class FederationSenderHandler: @@ -353,10 +357,12 @@ def __init__(self, hs: "HomeServer"): self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - def wake_destination(self, server: str): + def wake_destination(self, server: str) -> None: self.federation_sender.wake_destination(server) - async def process_replication_rows(self, stream_name, token, rows): + async def process_replication_rows( + self, stream_name: str, token: int, rows: list + ) -> None: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": @@ -384,11 +390,12 @@ async def process_replication_rows(self, stream_name, token, rows): for host in hosts: self.federation_sender.send_device_messages(host) - async def _on_new_receipts(self, rows): + async def _on_new_receipts( + self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] + ) -> None: """ Args: - rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]): - new receipts to be processed + rows: new receipts to be processed """ for receipt in rows: # we only want to send on receipts for our own users @@ -408,7 +415,7 @@ async def _on_new_receipts(self, rows): ) await self.federation_sender.send_read_receipt(receipt_info) - async def update_token(self, token): + async def update_token(self, token: int) -> None: """Update the record of where we have processed to in the federation stream. Called after we have processed a an update received over replication. Sends @@ -428,7 +435,7 @@ async def update_token(self, token): run_as_background_process("_save_and_send_ack", self._save_and_send_ack) - async def _save_and_send_ack(self): + async def _save_and_send_ack(self) -> None: """Save the current federation position in the database and send an ACK to master with where we're up to. """ diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 1311b013dae7..3654f6c03c7e 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,12 +18,15 @@ """ import abc import logging -from typing import Tuple, Type +from typing import Optional, Tuple, Type, TypeVar +from synapse.replication.tcp.streams._base import StreamRow from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) +T = TypeVar("T", bound="Command") + class Command(metaclass=abc.ABCMeta): """The base command class. @@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: """Deserialises a line from the wire into this command. `line` does not include the command. 
""" @@ -49,21 +52,24 @@ def to_line(self) -> str: prefix. """ - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: """Get a suitable string for the logcontext when processing this command""" # by default, we just use the command name. return self.NAME +SC = TypeVar("SC", bound="_SimpleCommand") + + class _SimpleCommand(Command): """An implementation of Command whose argument is just a 'data' string.""" - def __init__(self, data): + def __init__(self, data: str): self.data = data @classmethod - def from_line(cls, line): + def from_line(cls: Type[SC], line: str) -> SC: return cls(line) def to_line(self) -> str: @@ -109,14 +115,16 @@ class RdataCommand(Command): NAME = "RDATA" - def __init__(self, stream_name, instance_name, token, row): + def __init__( + self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow + ): self.stream_name = stream_name self.instance_name = instance_name self.token = token self.row = row @classmethod - def from_line(cls, line): + def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand": stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( stream_name, @@ -125,7 +133,7 @@ def from_line(cls, line): json_decoder.decode(row_json), ) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -135,7 +143,7 @@ def to_line(self): ) ) - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: return "RDATA-" + self.stream_name @@ -164,18 +172,20 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, prev_token, new_token): + def __init__( + self, stream_name: str, instance_name: str, prev_token: int, new_token: int + ): self.stream_name = stream_name self.instance_name = instance_name self.prev_token = prev_token self.new_token = new_token @classmethod - def from_line(cls, line): + def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand": stream_name, instance_name, prev_token, new_token = line.split(" ", 3) return cls(stream_name, instance_name, int(prev_token), int(new_token)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -218,14 +228,14 @@ class ReplicateCommand(Command): NAME = "REPLICATE" - def __init__(self): + def __init__(self) -> None: pass @classmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: return cls() - def to_line(self): + def to_line(self) -> str: return "" @@ -247,14 +257,16 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" - def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + def __init__( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): @@ -262,7 +274,7 @@ def from_line(cls, line): return cls(instance_id, user_id, state == "start", int(last_sync_ms)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.instance_id, @@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command): NAME = "CLEAR_USER_SYNC" - def __init__(self, instance_id): + def __init__(self, instance_id: str): self.instance_id = instance_id @classmethod - def from_line(cls, line): + def from_line( + cls: Type["ClearUserSyncsCommand"], line: 
str + ) -> "ClearUserSyncsCommand": return cls(line) - def to_line(self): + def to_line(self) -> str: return self.instance_id @@ -316,7 +330,9 @@ def __init__(self, instance_name: str, token: int): self.token = token @classmethod - def from_line(cls, line: str) -> "FederationAckCommand": + def from_line( + cls: Type["FederationAckCommand"], line: str + ) -> "FederationAckCommand": instance_name, token = line.split(" ") return cls(instance_name, int(token)) @@ -334,7 +350,15 @@ class UserIpCommand(Command): NAME = "USER_IP" - def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): + def __init__( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): self.user_id = user_id self.access_token = access_token self.ip = ip @@ -343,14 +367,14 @@ def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): self.last_seen = last_seen @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand": user_id, jsn = line.split(" ", 1) access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) - def to_line(self): + def to_line(self) -> str: return ( self.user_id + " " diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 21293038ef84..1f00155c7236 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -261,7 +261,7 @@ def _add_command_to_stream_queue( "process-replication-data", self._unsafe_process_queue, stream_name ) - async def _unsafe_process_queue(self, stream_name: str): + async def _unsafe_process_queue(self, stream_name: str) -> None: """Processes the command queue for the given stream, until it is empty Does not check if there is already a thread processing the queue, hence "unsafe" @@ -294,7 +294,7 @@ async def _process_command( # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs: "HomeServer"): + def start_replication(self, hs: "HomeServer") -> None: """Helper method to start a replication connection to the remote server using TCP. """ @@ -345,10 +345,10 @@ def get_streams_to_replicate(self) -> List[Stream]: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None: self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: IReplicationConnection): + def send_positions_to_connection(self, conn: IReplicationConnection) -> None: """Send current position of all streams this process is source of to the connection. 
""" @@ -392,7 +392,7 @@ def on_CLEAR_USER_SYNC( def on_FEDERATION_ACK( self, conn: IReplicationConnection, cmd: FederationAckCommand - ): + ) -> None: federation_ack_counter.inc() if self._federation_sender: @@ -408,7 +408,7 @@ def on_USER_IP( else: return None - async def _handle_user_ip(self, cmd: UserIpCommand): + async def _handle_user_ip(self, cmd: UserIpCommand) -> None: await self._store.insert_client_ip( cmd.user_id, cmd.access_token, @@ -421,7 +421,7 @@ async def _handle_user_ip(self, cmd: UserIpCommand): assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -497,7 +497,7 @@ async def _process_rdata( async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. Args: @@ -512,7 +512,7 @@ async def on_rdata( stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -581,7 +581,7 @@ async def _process_position( def on_REMOTE_SERVER_UP( self, conn: IReplicationConnection, cmd: RemoteServerUpCommand - ): + ) -> None: """Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -604,7 +604,7 @@ def on_REMOTE_SERVER_UP( # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: IReplicationConnection): + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -631,7 +631,7 @@ def new_connection(self, connection: IReplicationConnection): UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: IReplicationConnection): + def lost_connection(self, connection: IReplicationConnection) -> None: """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) @@ -653,7 +653,7 @@ def connected(self) -> bool: def send_command( self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None - ): + ) -> None: """Send a command to all connected connections. Args: @@ -680,7 +680,7 @@ def send_command( else: logger.warning("Dropping command as not connected: %r", cmd.NAME) - def send_federation_ack(self, token: int): + def send_federation_ack(self, token: int) -> None: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. 
""" @@ -688,7 +688,7 @@ def send_federation_ack(self, token: int): def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int - ): + ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) @@ -702,15 +702,15 @@ def send_user_ip( user_agent: str, device_id: str, last_seen: int, - ): + ) -> None: """Tell the master that the user made a request.""" cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) - def send_remote_server_up(self, server: str): + def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) - def stream_update(self, stream_name: str, token: str, data: Any): + def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None: """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7bae36db169b..7763ffb2d0c7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,7 +49,7 @@ import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, Collection, List, Optional +from typing import TYPE_CHECKING, Any, Collection, List, Optional from prometheus_client import Counter from zope.interface import Interface, implementer @@ -123,7 +123,7 @@ class ConnectionStates: class IReplicationConnection(Interface): """An interface for replication connections.""" - def send_command(cmd: Command): + def send_command(cmd: Command) -> None: """Send the command down the connection""" @@ -190,7 +190,7 @@ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): "replication-conn", self.conn_id ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("[%s] Connection established", self.id()) self.state = ConnectionStates.ESTABLISHED @@ -207,11 +207,11 @@ def connectionMade(self): # Always send the initial PING so that the other side knows that they # can time us out. - self.send_command(PingCommand(self.clock.time_msec())) + self.send_command(PingCommand(str(self.clock.time_msec()))) self.command_handler.new_connection(self) - def send_ping(self): + def send_ping(self) -> None: """Periodically sends a ping and checks if we should close the connection due to the other side timing out. 
""" @@ -226,7 +226,7 @@ def send_ping(self): self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: - self.send_command(PingCommand(now)) + self.send_command(PingCommand(str(now))) if ( self.received_ping @@ -239,12 +239,12 @@ def send_ping(self): ) self.send_error("ping timeout") - def lineReceived(self, line: bytes): + def lineReceived(self, line: bytes) -> None: """Called when we've received a line""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_line(line) - def _parse_and_dispatch_line(self, line: bytes): + def _parse_and_dispatch_line(self, line: bytes) -> None: if line.strip() == "": # Ignore blank lines return @@ -309,24 +309,24 @@ def handle_command(self, cmd: Command) -> None: if not handled: logger.warning("Unhandled command: %r", cmd) - def close(self): + def close(self) -> None: logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() - def send_error(self, error_string, *args): + def send_error(self, error_string: str, *args: Any) -> None: """Send an error to remote and close the connection.""" self.send_command(ErrorCommand(error_string % args)) self.close() - def send_command(self, cmd, do_buffer=True): + def send_command(self, cmd: Command, do_buffer: bool = True) -> None: """Send a command if connection has been established. Args: - cmd (Command) - do_buffer (bool): Whether to buffer the message or always attempt + cmd + do_buffer: Whether to buffer the message or always attempt to send the command. This is mostly used to send an error message if we're about to close the connection due our buffers becoming full. @@ -357,7 +357,7 @@ def send_command(self, cmd, do_buffer=True): self.last_sent_command = self.clock.time_msec() - def _queue_command(self, cmd): + def _queue_command(self, cmd: Command) -> None: """Queue the command until the connection is ready to write to again.""" logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) @@ -370,20 +370,20 @@ def _queue_command(self, cmd): self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) self.close() - def _send_pending_commands(self): + def _send_pending_commands(self) -> None: """Send any queued commandes""" pending = self.pending_commands self.pending_commands = [] for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + def on_PING(self, cmd: PingCommand) -> None: self.received_ping = True - def on_ERROR(self, cmd): + def on_ERROR(self, cmd: ErrorCommand) -> None: logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) - def pauseProducing(self): + def pauseProducing(self) -> None: """This is called when both the kernel send buffer and the twisted tcp connection send buffers have become full. 
@@ -394,26 +394,26 @@ def pauseProducing(self): logger.info("[%s] Pause producing", self.id()) self.state = ConnectionStates.PAUSED - def resumeProducing(self): + def resumeProducing(self) -> None: """The remote has caught up after we started buffering!""" logger.info("[%s] Resume producing", self.id()) self.state = ConnectionStates.ESTABLISHED self._send_pending_commands() - def stopProducing(self): + def stopProducing(self) -> None: """We're never going to send any more data (normally because either we or the remote has closed the connection) """ logger.info("[%s] Stop producing", self.id()) self.on_connection_closed() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.labels(reason.__class__.__name__).inc() + connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable] try: # Remove us from list of connections to be monitored @@ -427,7 +427,7 @@ def connectionLost(self, reason): self.on_connection_closed() - def on_connection_closed(self): + def on_connection_closed(self) -> None: logger.info("[%s] Connection was closed", self.id()) self.state = ConnectionStates.CLOSED @@ -445,7 +445,7 @@ def on_connection_closed(self): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def __str__(self): + def __str__(self) -> str: addr = None if self.transport: addr = str(self.transport.getPeer()) @@ -455,10 +455,10 @@ def __str__(self): addr, ) - def id(self): + def id(self) -> str: return "%s-%s" % (self.name, self.conn_id) - def lineLengthExceeded(self, line): + def lineLengthExceeded(self, line: str) -> None: """Called when we receive a line that is above the maximum line length""" self.send_error("Line length exceeded") @@ -474,11 +474,11 @@ def __init__( self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(ServerCommand(self.server_name)) super().connectionMade() - def on_NAME(self, cmd): + def on_NAME(self, cmd: NameCommand) -> None: logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -500,19 +500,19 @@ def __init__( self.client_name = client_name self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(NameCommand(self.client_name)) super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - def on_SERVER(self, cmd): + def on_SERVER(self, cmd: ServerCommand) -> None: if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def replicate(self): + def replicate(self) -> None: """Send the subscription request to the server""" logger.info("[%s] Subscribing to replication streams", self.id()) @@ -529,7 +529,7 @@ def replicate(self): ) -def transport_buffer_size(protocol): +def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int: if protocol.transport: size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen return size @@ -544,7 +544,9 @@ def transport_buffer_size(protocol): ) -def transport_kernel_read_buffer_size(protocol, read=True): +def transport_kernel_read_buffer_size( + protocol: 
BaseReplicationStreamProtocol, read: bool = True +) -> int: SIOCINQ = 0x541B SIOCOUTQ = 0x5411 diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 8d28bd3f3fcc..13c0c46d758e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,7 +14,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast import attr import txredisapi @@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]): def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant - def __set__(self, obj: Optional[T], value: V): + def __set__(self, obj: Optional[T], value: V) -> None: pass @@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_stream_name: str synapse_outbound_redis_connection: txredisapi.RedisProtocol - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # a logcontext which we use for processing incoming commands. We declare it as a @@ -108,12 +108,12 @@ def __init__(self, *args, **kwargs): "replication_command_handler" ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() run_as_background_process("subscribe-replication", self._send_subscribe) - async def _send_subscribe(self): + async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. @@ -131,12 +131,12 @@ async def _send_subscribe(self): # otherside won't know we've connected and so won't issue a REPLICATE. self.synapse_handler.send_positions_to_connection(self) - def messageReceived(self, pattern: str, channel: str, message: str): + def messageReceived(self, pattern: str, channel: str, message: str) -> None: """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) - def _parse_and_dispatch_message(self, message: str): + def _parse_and_dispatch_message(self, message: str) -> None: if message.strip() == "": # Ignore blank lines return @@ -181,7 +181,7 @@ def handle_command(self, cmd: Command) -> None: "replication-" + cmd.get_logcontext_id(), lambda: res ) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("Lost connection to redis") super().connectionLost(reason) self.synapse_handler.lost_connection(self) @@ -193,17 +193,17 @@ def connectionLost(self, reason): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def send_command(self, cmd: Command): + def send_command(self, cmd: Command) -> None: """Send a command if connection has been established. 
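Note: `connectionLost` gains a targeted `# type: ignore[override]` because annotating `reason` as `Failure` is stricter than the signature declared on the Twisted base class, which mypy treats as a Liskov violation. A reduced sketch of the same situation (the base class here is a stand-in, not Twisted's real signature):

```python
# Illustrative only: why narrowing a parameter in an override trips mypy.
from twisted.python.failure import Failure


class TransportEventReceiver:
    def connectionLost(self, reason: object) -> None:
        """Stand-in base class; accepts any disconnect reason."""


class ReplicationLikeProtocol(TransportEventReceiver):
    # Narrowing `object` to `Failure` is unsound in general (callers holding
    # the base type may pass anything), so the ignore is scoped to exactly
    # this error code rather than suppressing the whole line.
    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
        print("closed:", reason.getErrorMessage())
```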
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 8d28bd3f3fcc..13c0c46d758e 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,7 +14,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
 
 import attr
 import txredisapi
@@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
     def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
 
-    def __set__(self, obj: Optional[T], value: V):
+    def __set__(self, obj: Optional[T], value: V) -> None:
         pass
 
 
@@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
     synapse_stream_name: str
     synapse_outbound_redis_connection: txredisapi.RedisProtocol
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
 
         # a logcontext which we use for processing incoming commands. We declare it as a
@@ -108,12 +108,12 @@ def __init__(self, *args, **kwargs):
             "replication_command_handler"
         )
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("Connected to redis")
         super().connectionMade()
         run_as_background_process("subscribe-replication", self._send_subscribe)
 
-    async def _send_subscribe(self):
+    async def _send_subscribe(self) -> None:
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
@@ -131,12 +131,12 @@ async def _send_subscribe(self):
         # otherside won't know we've connected and so won't issue a REPLICATE.
         self.synapse_handler.send_positions_to_connection(self)
 
-    def messageReceived(self, pattern: str, channel: str, message: str):
+    def messageReceived(self, pattern: str, channel: str, message: str) -> None:
         """Received a message from redis."""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_message(message)
 
-    def _parse_and_dispatch_message(self, message: str):
+    def _parse_and_dispatch_message(self, message: str) -> None:
         if message.strip() == "":
             # Ignore blank lines
             return
@@ -181,7 +181,7 @@ def handle_command(self, cmd: Command) -> None:
                 "replication-" + cmd.get_logcontext_id(), lambda: res
             )
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
         self.synapse_handler.lost_connection(self)
@@ -193,17 +193,17 @@ def connectionLost(self, reason):
     # the sentinel context is now active, which may not be correct.
     # PreserveLoggingContext() will restore the correct logging context.
 
-    def send_command(self, cmd: Command):
+    def send_command(self, cmd: Command) -> None:
         """Send a command if connection has been established.
 
         Args:
-            cmd (Command)
+            cmd: The command to send
         """
         run_as_background_process(
             "send-cmd", self._async_send_command, cmd, bg_start_span=False
         )
 
-    async def _async_send_command(self, cmd: Command):
+    async def _async_send_command(self, cmd: Command) -> None:
         """Encode a replication command and send it over our outbound connection"""
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
@@ -259,7 +259,7 @@ def __init__(
         hs.get_clock().looping_call(self._send_ping, 30 * 1000)
 
     @wrap_as_background_process("redis_ping")
-    async def _send_ping(self):
+    async def _send_ping(self) -> None:
         for connection in self.pool:
             try:
                 await make_deferred_yieldable(connection.ping())
@@ -269,13 +269,13 @@ async def _send_ping(self):
 
     # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
     # it's rubbish. We add our own here.
-    def startedConnecting(self, connector: IConnector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info(
             "Connecting to redis server %s", format_address(connector.getDestination())
         )
         super().startedConnecting(connector)
 
-    def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s failed: %s",
             format_address(connector.getDestination()),
@@ -283,7 +283,7 @@ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
         )
         super().clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector: IConnector, reason: Failure):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s lost: %s",
             format_address(connector.getDestination()),
@@ -330,7 +330,7 @@ def __init__(
 
         self.synapse_outbound_redis_connection = outbound_redis_connection
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
         p = super().buildProtocol(addr)
         p = cast(RedisSubscriber, p)
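Note: the `*args: Any, **kwargs: Any` spelling in `RedisSubscriber.__init__` is the standard way to satisfy `disallow_untyped_defs` in a wrapper that forwards everything to an untyped third-party base class (txredisapi here). A standalone sketch with a stand-in base:

```python
# Illustrative only; ThirdPartyProtocol stands in for txredisapi's protocol.
from typing import Any


class ThirdPartyProtocol:
    """Stand-in for an untyped dependency."""

    def __init__(self, *args: Any, **kwargs: Any):
        pass


class Subscriber(ThirdPartyProtocol):
    def __init__(self, *args: Any, **kwargs: Any):
        # Typed enough for mypy, while still forwarding arbitrary arguments.
        super().__init__(*args, **kwargs)
        self.connected = False
```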
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index a9d85f4f6cd0..ecd6190f5b9b 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -16,16 +16,18 @@
 
 import logging
 import random
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from prometheus_client import Counter
 
+from twisted.internet.interfaces import IAddress
 from twisted.internet.protocol import ServerFactory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
 from synapse.replication.tcp.streams import EventsStream
+from synapse.replication.tcp.streams._base import StreamRow, Token
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -56,7 +58,7 @@ def __init__(self, hs: "HomeServer"):
         # listener config again or always starting a `ReplicationStreamer`.)
         hs.get_replication_streamer()
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
         return ServerReplicationStreamProtocol(
             self.server_name, self.clock, self.command_handler
         )
 
@@ -105,7 +107,7 @@ def __init__(self, hs: "HomeServer"):
         if any(EventsStream.NAME == s.NAME for s in self.streams):
             self.clock.looping_call(self.on_notifier_poke, 1000)
 
-    def on_notifier_poke(self):
+    def on_notifier_poke(self) -> None:
         """Checks if there is actually any new data and sends it to the
         connections if there are.
 
@@ -137,7 +139,7 @@ def on_notifier_poke(self):
 
         run_as_background_process("replication_notifier", self._run_notifier_loop)
 
-    async def _run_notifier_loop(self):
+    async def _run_notifier_loop(self) -> None:
         self.is_looping = True
 
         try:
@@ -238,7 +240,9 @@ async def _run_notifier_loop(self):
         self.is_looping = False
 
 
-def _batch_updates(updates):
+def _batch_updates(
+    updates: List[Tuple[Token, StreamRow]]
+) -> List[Tuple[Optional[Token], StreamRow]]:
     """Takes a list of updates of form [(token, row)] and sets the token to
     None for all rows where the next row has the same token. This is used to
     implement batching.
@@ -254,7 +258,7 @@ def _batch_updates(updates):
     if not updates:
         return []
 
-    new_updates = []
+    new_updates: List[Tuple[Optional[Token], StreamRow]] = []
     for i, update in enumerate(updates[:-1]):
         if update[0] == updates[i + 1][0]:
             new_updates.append((None, update[1]))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 5a2d90c5309f..914b9eae8492 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -90,7 +90,7 @@ class Stream:
     ROW_TYPE: Any = None
 
     @classmethod
-    def parse_row(cls, row: StreamRow):
+    def parse_row(cls, row: StreamRow) -> Any:
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -139,7 +139,7 @@ def __init__(
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)
 
-    def discard_updates_and_advance(self):
+    def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
@@ -200,7 +200,7 @@ def current_token_without_instance(
     return lambda instance_name: current_token()
 
 
-def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 4f4f1ad45378..50c4a5ba03a1 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    StreamRow,
+    StreamUpdateResult,
+    Token,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
     data: "BaseEventsStreamRow"
 
 
+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
 class BaseEventsStreamRow:
     """Base class for rows to be sent in the events stream.
 
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
     TypeId: str
 
     @classmethod
-    def from_data(cls, data):
+    def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
         """Parse the data from the replication stream into a row.
 
         By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ async def _update_function(
         return updates, upper_limit, limited
 
     @classmethod
-    def parse_row(cls, row):
-        (typ, data) = row
-        data = TypeToRow[typ].from_data(data)
-        return EventsStreamRow(typ, data)
+    def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+        (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+        event_stream_row_data = TypeToRow[typ].from_data(data)
+        return EventsStreamRow(typ, event_stream_row_data)
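Note: `parse_row` in events.py shows the `cast` idiom: the row arrives as loosely-typed decoded JSON, and `cast` records the assumed shape so the rest of the function type-checks. It performs no runtime conversion or checking. A minimal standalone version of the same move (the simplified `StreamRow` alias is assumed here):

```python
# Illustrative only; StreamRow is a simplified stand-in for the real alias.
from typing import Iterable, Optional, Tuple, cast

StreamRow = Tuple


def parse_event_row(row: StreamRow) -> Tuple[str, Iterable[Optional[str]]]:
    # cast() does nothing at runtime; it only tells mypy what shape we
    # believe the decoded JSON has.
    typ, data = cast(Tuple[str, Iterable[Optional[str]]], row)
    return typ, data


print(parse_event_row(("state", ["$event_id", "m.room.member", "@user:test"])))
```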
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index ea1032b4fcfb..b26546aecdb7 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,8 +16,7 @@
 import re
 import secrets
 import string
-from collections.abc import Iterable
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 from netaddr import valid_ipv6
 
@@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
 
-    Otherwise, return the stringification of a a list with the first maxitems items,
+    Otherwise, return the stringification of a list with the first maxitems items,
     followed by "...".
 
     Args:
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index cb02eddf0797..9fc50f885227 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,6 +14,7 @@
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
+from twisted.internet.address import IPv4Address
 from twisted.internet.protocol import Protocol
 from twisted.web.resource import Resource
 
@@ -53,7 +54,7 @@ def prepare(self, reactor, clock, hs):
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
         self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
-            None
+            IPv4Address("TCP", "127.0.0.1", 0)
         )
 
         # Make a new HomeServer object for the worker
@@ -345,7 +346,9 @@ def make_worker_hs(
             self.clock,
             repl_handler,
         )
-        server = self.server_factory.buildProtocol(None)
+        server = self.server_factory.buildProtocol(
+            IPv4Address("TCP", "127.0.0.1", 0)
+        )
 
         client_transport = FakeTransport(server, self.reactor)
         client.makeConnection(client_transport)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 262c35cef3c6..545f11acd1bf 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -14,6 +14,7 @@
 
 from typing import Tuple
 
+from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
 from twisted.test.proto_helpers import StringTransport
 
@@ -29,7 +30,7 @@ def prepare(self, reactor, clock, hs):
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
         """Create a new direct TCP replication connection"""
-        proto = self.factory.buildProtocol(("127.0.0.1", 0))
+        proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
 
         transport = StringTransport()
         proto.makeConnection(transport)
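Note: the test changes replace `None` and a bare tuple with a real `IAddress` value, since `buildProtocol` is now annotated to take one. For reference, the positional constructor used above:

```python
from twisted.internet.address import IPv4Address

# Arguments are (type, host, port); port 0 is fine for a fake in-memory
# connection that never binds a real socket.
addr = IPv4Address("TCP", "127.0.0.1", 0)
print(addr.type, addr.host, addr.port)
```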