Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use inline type hints in tests/ #10350

Merged
merged 2 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10350.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.
6 changes: 3 additions & 3 deletions tests/events/test_presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_receiving_all_presence(self):
)
self.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_one_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "boop")
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_send_local_online_presence_to_with_module(self):
presence_updates, _ = sync_presence(self, self.other_user_id)
self.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "I'm online!")
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_send_local_online_presence_to_with_module(self):
)
for call in calls:
call_args = call[0]
federation_transaction = call_args[0] # type: Transaction
federation_transaction: Transaction = call_args[0]

# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
Expand Down
16 changes: 8 additions & 8 deletions tests/module_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def test_sending_events_into_room(self):
"content": content,
"sender": user_id,
}
event = self.get_success(
event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
) # type: EventBase
)
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.message")
self.assertEqual(event.room_id, room_id)
Expand Down Expand Up @@ -136,9 +136,9 @@ def test_sending_events_into_room(self):
"sender": user_id,
"state_key": "",
}
event = self.get_success(
event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
) # type: EventBase
)
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.power_levels")
self.assertEqual(event.room_id, room_id)
Expand Down Expand Up @@ -281,7 +281,7 @@ def test_send_local_online_presence_to_federation(self):
)
for call in calls:
call_args = call[0]
federation_transaction = call_args[0] # type: Transaction
federation_transaction: Transaction = call_args[0]

# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
Expand Down Expand Up @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand Down Expand Up @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand All @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)

presence_update = presence_updates[0] # type: UserPresenceState
presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")

Expand Down
12 changes: 6 additions & 6 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
None
) # type: ServerReplicationStreamProtocol
)

# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
Expand Down Expand Up @@ -195,7 +195,7 @@ def assert_request_is_get_repl_stream_updates(
fetching updates for given stream.
"""

path = request.path # type: bytes # type: ignore
path: bytes = request.path # type: ignore
self.assertRegex(
path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
Expand All @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""

servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
servlets: List[Callable[[HomeServer, JsonResource], None]] = []

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(self, hs: HomeServer):
super().__init__(hs)

# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
self.received_rdata_rows: List[Tuple[str, int, Any]] = []

async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
Expand Down Expand Up @@ -484,7 +484,7 @@ def buildProtocol(self, addr):
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""

transport = None # type: Optional[FakeTransport]
transport: Optional[FakeTransport] = None

def __init__(self, server: FakeRedisPubSubServer):
self._server = server
Expand Down
14 changes: 7 additions & 7 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_update_function_huge_state_change(self):
)

# this is the point in the DAG where we make a fork
fork_point = self.get_success(
fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
)

events = [
self._inject_state_event(sender=OTHER_USER)
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_update_function_huge_state_change(self):
self.assertEqual(row.data.event_id, pl_event.event_id)

# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
state_rows: List[EventsStreamCurrentStateRow] = []
for stream_name, _, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
Expand Down Expand Up @@ -290,11 +290,11 @@ def test_update_function_state_row_limit(self):
)

# this is the point in the DAG where we make a fork
fork_point = self.get_success(
fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
)

events = [] # type: List[EventBase]
events: List[EventBase] = []
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_update_function_state_row_limit(self):
self.assertEqual(row.data.event_id, pl_events[i].event_id)

# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
state_rows: List[EventsStreamCurrentStateRow] = []
for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_receipt(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_receipt(self):
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))

row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_typing(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_reset(self):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)

Expand Down
2 changes: 1 addition & 1 deletion tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

logger = logging.getLogger(__name__)

test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None


class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ def test_send_event(self):
"content": content,
"sender": self.user_id,
}
event = self.get_success(
event: EventBase = self.get_success(
current_rules_module().module_api.create_and_send_event_into_room(
event_dict
)
) # type: EventBase
)

self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
Expand Down
14 changes: 6 additions & 8 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def test_get_msc2858_login_flows(self):
self.assertEqual(channel.code, 200, channel.result)

# stick the flows results in a dict by type
flow_results = {} # type: Dict[str, Any]
flow_results: Dict[str, Any] = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_multi_sso_redirect(self):
p.close()

# there should be a link for each href
returned_idps = [] # type: List[str]
returned_idps: List[str] = []
for link in p.links:
path, query = link.split("?", 1)
self.assertEqual(path, "pick_idp")
Expand Down Expand Up @@ -582,7 +582,7 @@ def test_login_via_oidc(self):
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
cookies = {} # type: Dict[str, str]
cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
Expand Down Expand Up @@ -874,9 +874,7 @@ def make_homeserver(self, reactor, clock):

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(
payload, secret, self.jwt_algorithm
) # type: Union[str, bytes]
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
Expand Down Expand Up @@ -1084,7 +1082,7 @@ def make_homeserver(self, reactor, clock):

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
Expand Down Expand Up @@ -1272,7 +1270,7 @@ def test_username_picker(self):
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")

# ... with a username_mapping_session cookie
cookies = {} # type: Dict[str,str]
cookies: Dict[str, str] = {}
channel.extract_cookies(cookies)
self.assertIn("username_mapping_session", cookies)
session_id = cookies["username_mapping_session"]
Expand Down
8 changes: 5 additions & 3 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class FakeChannel:
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
_producer: Optional[Union[IPullProducer, IPushProducer]] = None

@property
def json_body(self):
Expand Down Expand Up @@ -316,8 +316,10 @@ def __init__(self):

self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {} # type: Dict[str, str]
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque()

lookups = self.lookups

@implementer(IResolverSimple)
class FakeResolver:
Expand Down
4 changes: 1 addition & 3 deletions tests/storage/test_background_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.updates = (
self.hs.get_datastore().db_pool.updates
) # type: BackgroundUpdater
self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
Expand Down
6 changes: 3 additions & 3 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down Expand Up @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down Expand Up @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):

def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_branch_no_conflict(self):

self.store.register_events(graph.walk())

context_store = {} # type: dict[str, EventContext]
context_store: dict[str, EventContext] = {}

for event in graph.walk():
context = yield defer.ensureDeferred(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils/html_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def __init__(self):
super().__init__()

# a list of links found in the doc
self.links = [] # type: List[str]
self.links: List[str] = []

# the values of any hidden <input>s: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
self.hiddens: Dict[str, Optional[str]] = {}

# the values of any radio buttons: map from name to list of values
self.radios = {} # type: Dict[str, List[Optional[str]]]
self.radios: Dict[str, List[Optional[str]]] = {}

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
Expand Down
2 changes: 1 addition & 1 deletion tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def get_success_or_raise(self, d, by=0.0):
if not isinstance(deferred, Deferred):
return d

results = [] # type: list
results: list = []
deferred.addBoth(results.append)

self.pump(by=by)
Expand Down
Loading