Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make it possible to use dmypy #9692

Merged
merged 2 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9692.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make it possible to use `dmypy`.
3 changes: 2 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
follow_imports = normal
check_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
warn_unreachable = True
local_partial_types = True

# To find all folders that pass mypy you run:
#
Expand Down
5 changes: 5 additions & 0 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,9 @@ def has_access_token(request: Request):
Returns:
bool: False if no access_token was given, True otherwise.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None

query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
Expand All @@ -574,6 +577,8 @@ def get_access_token_from_request(request: Request):
MissingClientTokenError: If there isn't a single access_token in the
request
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None

auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
Expand Down
6 changes: 4 additions & 2 deletions synapse/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"

# Map from canonicalised cache name to cache.
_CACHES = {}
_CACHES = {} # type: Dict[str, Callable[[float], None]]

# a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock()
Expand Down Expand Up @@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str:
return cache_name.lower()


def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None]
):
"""Register a cache that's size can dynamically change

Args:
Expand Down
3 changes: 3 additions & 0 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
Args:
request: the incoming request from the browser.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None

# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
Expand Down
2 changes: 1 addition & 1 deletion synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def report_span(self, span):
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
_homeserver_whitelist = None
_homeserver_whitelist = None # type: Optional[re.Pattern[str]]

# Util methods

Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@

# A list of all connected protocols. This allows us to send metrics about the
# connections.
connected_connections = []
connected_connections = [] # type: List[BaseReplicationStreamProtocol]


logger = logging.getLogger(__name__)
Expand Down
3 changes: 3 additions & 0 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None

requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)

Expand Down
3 changes: 3 additions & 0 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,9 @@ def __init__(self, hs: "HomeServer"):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None

await assert_requester_is_admin(self.auth, request)

if not self.is_mine(UserID.from_string(user_id)):
Expand Down
3 changes: 3 additions & 0 deletions synapse/rest/client/v2_alpha/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(self, hs: "HomeServer"):
self._event_serializer = hs.get_event_client_serializer()

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None

if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
Expand Down
2 changes: 2 additions & 0 deletions synapse/rest/media/v1/preview_url_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)

async def _async_render_GET(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None

# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
Expand Down
3 changes: 3 additions & 0 deletions synapse/rest/synapse/client/pick_username.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ async def _async_render_GET(self, request: Request) -> None:
respond_with_html(request, 200, html)

async def _async_render_POST(self, request: SynapseRequest):
# This will always be set by the time Twisted calls us.
assert request.args is not None

try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
Expand Down
4 changes: 2 additions & 2 deletions synapse/util/caches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

logger = logging.getLogger(__name__)

caches_by_name = {}
collectors_by_name = {} # type: Dict
caches_by_name = {} # type: Dict[str, Sized]
collectors_by_name = {} # type: Dict[str, CacheMetric]

cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
Expand Down
1 change: 1 addition & 0 deletions tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_typing(self):
self.assert_request_is_get_repl_stream_updates(request, "typing")

# The from token should be the token from the last RDATA we got.
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)

self.test_handler.on_rdata.assert_called_once()
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
import os
from binascii import unhexlify
from typing import Tuple
from typing import Optional, Tuple

from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
Expand All @@ -32,7 +32,7 @@

logger = logging.getLogger(__name__)

test_server_connection_factory = None
test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]


class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
Expand Down
28 changes: 20 additions & 8 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import deque
from io import SEEK_END, BytesIO
from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union

import attr
from typing_extensions import Deque
Expand All @@ -13,8 +13,11 @@
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IHostnameResolver,
IProtocol,
IPullProducer,
IPushProducer,
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
ITransport,
)
Expand Down Expand Up @@ -45,11 +48,11 @@ class FakeChannel:
wire).
"""

site = attr.ib(type=Site)
site = attr.ib(type=Union[Site, "FakeSite"])
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None
_producer = None # type: Optional[Union[IPullProducer, IPushProducer]]

@property
def json_body(self):
Expand Down Expand Up @@ -159,7 +162,11 @@ def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:

Any cookines found are added to the given dict
"""
for h in self.headers.getRawHeaders("Set-Cookie"):
headers = self.headers.getRawHeaders("Set-Cookie")
if not headers:
return

for h in headers:
parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v
Expand Down Expand Up @@ -311,8 +318,8 @@ def __init__(self):

self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
lookups = self.lookups = {} # type: Dict[str, str]
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]

@implementer(IResolverSimple)
class FakeResolver:
Expand All @@ -324,6 +331,9 @@ def getHostByName(self, name, timeout=None):
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()

def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()

def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
Expand Down Expand Up @@ -621,7 +631,9 @@ def flush(self, maxbytes=None):
self.disconnected = True


def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
def connect_client(
reactor: ThreadedMemoryReactorClock, client_id: int
) -> Tuple[IProtocol, AccumulatingProtocol]:
"""
Connect a client to a fake TCP transport.

Expand Down