Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests.replication #14987

Merged
merged 6 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14987.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ disallow_untyped_defs = True
[mypy-tests.push.*]
disallow_untyped_defs = True

[mypy-tests.replication.*]
disallow_untyped_defs = True

[mypy-tests.rest.*]
disallow_untyped_defs = True

Expand Down
70 changes: 40 additions & 30 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from typing import Any, Dict, List, Optional, Set, Tuple

from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

from synapse.app.generic_worker import GenericWorkerServer
Expand All @@ -30,6 +32,7 @@
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.server import FakeTransport
Expand All @@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
Expand Down Expand Up @@ -92,8 +95,8 @@ def prepare(self, reactor, clock, hs):
repl_handler,
)

self._client_transport = None
self._server_transport = None
self._client_transport: Optional[FakeTransport] = None
self._server_transport: Optional[FakeTransport] = None

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
Expand All @@ -107,10 +110,10 @@ def _get_worker_hs_config(self) -> dict:
config["worker_replication_http_port"] = "8765"
return config

def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)

def reconnect(self):
def reconnect(self) -> None:
if self._client_transport:
self.client.close()

Expand All @@ -123,7 +126,7 @@ def reconnect(self):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)

def disconnect(self):
def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
Expand All @@ -132,7 +135,7 @@ def disconnect(self):
self._server_transport = None
self.server.close()

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
Expand Down Expand Up @@ -168,7 +171,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory

def request_factory(*args, **kwargs):
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
Expand Down Expand Up @@ -202,7 +205,7 @@ def request_factory(*args, **kwargs):

def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
Expand Down Expand Up @@ -244,7 +247,7 @@ def default_config(self) -> Dict[str, Any]:
base["redis"] = {"enabled": True}
return base

def setUp(self):
def setUp(self) -> None:
super().setUp()

# build a replication server
Expand Down Expand Up @@ -287,7 +290,7 @@ def setUp(self):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)

def create_test_resource(self):
def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
Expand All @@ -301,7 +304,7 @@ def create_test_resource(self):
return resource

def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Expand Down Expand Up @@ -385,14 +388,14 @@ def _get_worker_hs_config(self) -> dict:
config["worker_replication_http_port"] = "8765"
return config

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()

def _handle_http_replication_attempt(self, hs, repl_port):
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
Expand Down Expand Up @@ -429,7 +432,7 @@ def _handle_http_replication_attempt(self, hs, repl_port):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.

def connect_any_redis_attempts(self):
def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
Expand All @@ -440,8 +443,11 @@ def connect_any_redis_attempts(self):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
client_protocol = client_factory.buildProtocol(client_address)

server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
Comment on lines -443 to +450
Copy link
Contributor

@DMRobertson DMRobertson Feb 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went for a different approach in #14988 (comment) of passing in a dummy address. I should probably just go for something simpler like this.


client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
Expand All @@ -463,7 +469,9 @@ def __init__(self, hs: HomeServer):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []

async def on_rdata(self, stream_name, instance_name, token, rows):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
Expand All @@ -472,28 +480,30 @@ async def on_rdata(self, stream_name, instance_name, token, rows):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""

def __init__(self):
def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)

def add_subscriber(self, conn, channel: bytes):
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)

def remove_subscriber(self, conn):
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)

def publish(self, conn, channel: bytes, msg) -> int:
def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])

return len(self._subscribers_by_channel)

def buildProtocol(self, addr):
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)


Expand All @@ -506,7 +516,7 @@ def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()

def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)

# We might get multiple messages in one packet.
Expand All @@ -523,7 +533,7 @@ def dataReceived(self, data):

self.handle_command(msg[0], *msg[1:])

def handle_command(self, command, *args):
def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""

# We currently only support pub/sub.
Expand All @@ -548,9 +558,9 @@ def handle_command(self, command, *args):
self.send("PONG")

else:
raise Exception(f"Unknown command: {command}")
raise Exception(f"Unknown command: {command!r}")

def send(self, msg):
def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None

Expand All @@ -559,7 +569,7 @@ def send(self, msg):
self.transport.write(raw)
self.transport.flush()

def encode(self, obj):
def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.

Supports: strings/bytes, integers and list/tuples.
Expand All @@ -581,5 +591,5 @@ def encode(self, obj):

raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)

def connectionLost(self, reason):
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
2 changes: 1 addition & 1 deletion tests/replication/http/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _handle_request( # type: ignore[override]
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""

def create_test_resource(self):
def create_test_resource(self) -> JsonResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)

Expand Down
25 changes: 16 additions & 9 deletions tests/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Optional
from unittest.mock import Mock

from tests.replication._base import BaseStreamTestCase
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.util import Clock

class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
from tests.replication._base import BaseStreamTestCase

hs = self.setup_test_homeserver(federation_client=Mock())

return hs
class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

self.reconnect()

self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence

def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump(0.1)

def check(self, method, args, expected_result=None):
def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
Expand Down
Loading