Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use inline type hints in handlers/ and rest/ #10382

Merged
merged 5 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10382.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.
8 changes: 4 additions & 4 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class BaseHandler:
"""

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store: synapse.storage.DataStore = hs.get_datastore()
ShadowJonathan marked this conversation as resolved.
Show resolved Hide resolved
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.state_handler: synapse.state.StateHandler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
Expand All @@ -55,12 +55,12 @@ def __init__(self, hs: "HomeServer"):
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
) # type: Optional[Ratelimiter]
)
else:
self.admin_redaction_ratelimiter = None

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
to_key = RoomStreamToken(None, stream_ordering)

# Events that we've processed in this room
written_events = set() # type: Set[str]
written_events: Set[str] = set()

# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
Expand All @@ -152,7 +152,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
unseen_to_child_events = {} # type: Dict[str, Set[str]]
unseen_to_child_events: Dict[str, Set[str]] = {}

# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken):
self.current_max, limit
)

events_by_room = {} # type: Dict[str, List[EventBase]]
events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)

Expand Down Expand Up @@ -275,7 +275,7 @@ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events = [] # type: List[JsonDict]
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
Expand Down Expand Up @@ -375,7 +375,7 @@ async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
protocols: Dict[str, List[JsonDict]] = {}

# Collect up all the individual protocol responses out of the ASes
for s in services:
Expand Down
16 changes: 8 additions & 8 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
Expand Down Expand Up @@ -296,7 +296,7 @@ def __init__(self, hs: "HomeServer"):

# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}

async def validate_user_via_ui_auth(
self,
Expand Down Expand Up @@ -500,7 +500,7 @@ async def check_ui_auth(
all the stages in any of the permitted flows.
"""

sid = None # type: Optional[str]
sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
Expand Down Expand Up @@ -588,9 +588,9 @@ async def check_ui_auth(
)

# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
errordict: Dict[str, Any] = {}
if "type" in authdict:
login_type = authdict["type"] # type: str
login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
Expand Down Expand Up @@ -766,7 +766,7 @@ def _auth_dict_for_flows(
LoginType.TERMS: self._get_params_terms,
}

params = {} # type: Dict[str, Any]
params: Dict[str, Any] = {}

for f in public_flows:
for stage in f:
Expand Down Expand Up @@ -1530,9 +1530,9 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

user_id_to_verify = await self.get_session_data(
user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
)

idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:

# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {} # type: Dict[str, List[Optional[str]]]
attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
Expand Down
14 changes: 7 additions & 7 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ async def notify_device_update(
user_id
)

hosts = set() # type: Set[str]
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
Expand Down Expand Up @@ -613,20 +613,20 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self._remote_edu_linearizer = Linearizer(name="remote_device_list")

# user_id -> list of updates waiting to be handled.
self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
self._pending_updates: Dict[
str, List[Tuple[str, str, Iterable[str], JsonDict]]
] = {}

# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
) # type: ExpiringCache[str, Set[str]]
)

# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
Expand Down Expand Up @@ -755,7 +755,7 @@ async def _need_to_do_resync(
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
seen_updates: Set[str] = self._seen_updates.get(user_id, set())

extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def send_device_message(
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ async def _delete_association(self, room_alias: RoomAlias) -> str:
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias(
room_alias
) # type: Optional[RoomAliasMapping]
result: Optional[
RoomAliasMapping
] = await self.get_association_from_room_alias(room_alias)

if result:
room_id = result.room_id
Expand Down
40 changes: 19 additions & 21 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ async def query_devices(
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
)

# separate users by domain.
# make a map from domain to user_id to device_ids
Expand All @@ -136,7 +136,7 @@ async def query_devices(

# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
Expand All @@ -151,11 +151,9 @@ async def query_devices(

# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
Expand Down Expand Up @@ -362,9 +360,9 @@ async def query_local_devices(
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = [] # type: List[Tuple[str, Optional[str]]]
local_query: List[Tuple[str, Optional[str]]] = []

result_dict = {} # type: Dict[str, Dict[str, dict]]
result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
Expand Down Expand Up @@ -402,9 +400,9 @@ async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
device_keys_query = query_body.get(
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
)
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}

Expand All @@ -421,8 +419,8 @@ async def on_federation_query_client_keys(
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}

for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
Expand All @@ -439,8 +437,8 @@ async def claim_one_time_keys(
results = await self.store.claim_e2e_one_time_keys(local_query)

# A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
Expand Down Expand Up @@ -768,8 +766,8 @@ async def _process_self_signatures(
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures

Expand Down Expand Up @@ -930,8 +928,8 @@ async def _process_other_signatures(
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures

Expand Down Expand Up @@ -1300,7 +1298,7 @@ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}

async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
Expand Down Expand Up @@ -1349,7 +1347,7 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
# This can happen since we batch updates
return

device_ids = [] # type: List[str]
device_ids: List[str] = []

logger.info("pending updates: %r", pending_updates)

Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def get_stream(

# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = [] # type: List[JsonDict]
to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
Expand All @@ -103,9 +103,9 @@ async def get_stream(
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = await self.store.get_users_in_room(
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
)
else:
users = [event.state_key]

Expand Down
Loading