Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Additional type hints for REST servlets (part 2) #10674

Merged
merged 15 commits into from
Aug 26, 2021
1 change: 1 addition & 0 deletions changelog.d/10674.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
5 changes: 5 additions & 0 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ async def send_full_presence_to_users(self, user_ids: Collection[str]):
# otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True)

async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
raise NotImplementedError(
"Attempting to check presence on a non-presence worker."
)


class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
Expand Down
11 changes: 7 additions & 4 deletions synapse/rest/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING

from twisted.web.server import Request

from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import respond_with_html
from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest

from ._base import client_patterns

Expand Down Expand Up @@ -49,7 +52,7 @@ def __init__(self, hs: "HomeServer"):
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template

async def on_GET(self, request, stagetype):
async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
Expand Down Expand Up @@ -88,7 +91,7 @@ async def on_GET(self, request, stagetype):
respond_with_html(request, 200, html)
return None

async def on_POST(self, request, stagetype):
async def on_POST(self, request: Request, stagetype: str) -> None:

session = parse_string(request, "session")
if not session:
Expand Down Expand Up @@ -172,5 +175,5 @@ async def on_POST(self, request, stagetype):
return None


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)
48 changes: 26 additions & 22 deletions synapse/rest/client/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,36 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api import errors
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict

from ._base import client_patterns, interactive_auth_handler

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
Expand All @@ -57,15 +59,15 @@ class DeleteDevicesRestServlet(RestServlet):

PATTERNS = client_patterns("/delete_devices")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()

@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

try:
Expand Down Expand Up @@ -100,26 +102,26 @@ async def on_POST(self, request):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()

async def on_GET(self, request, device_id):
async def on_GET(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
return 200, device

@interactive_auth_handler
async def on_DELETE(self, request, device_id):
async def on_DELETE(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

try:
Expand All @@ -146,7 +148,9 @@ async def on_DELETE(self, request, device_id):
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}

async def on_PUT(self, request, device_id):
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

body = parse_json_object_from_request(request)
Expand Down Expand Up @@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):

PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

async def on_GET(self, request: SynapseRequest):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string()
Expand All @@ -211,7 +215,7 @@ async def on_GET(self, request: SynapseRequest):
else:
raise errors.NotFoundError("No dehydrated device available")

async def on_PUT(self, request: SynapseRequest):
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)

Expand Down Expand Up @@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

async def on_POST(self, request: SynapseRequest):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

submission = parse_json_object_from_request(request)
Expand All @@ -292,7 +296,7 @@ async def on_POST(self, request: SynapseRequest):
return (200, result)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
Expand Down
38 changes: 23 additions & 15 deletions synapse/rest/client/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@

"""This module contains REST servlets to do with event streaming, /events."""
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand All @@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):

DEFAULT_LONGPOLL_TIME_MS = 30000

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
if b"room_id" not in request.args:
if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
room_id = parse_string(request, "room_id")

pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
if b"timeout" in args:
try:
timeout = int(request.args[b"timeout"][0])
timeout = int(args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")

as_client_event = b"raw" not in request.args
as_client_event = b"raw" not in args

chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
Expand All @@ -70,25 +76,27 @@ async def on_GET(self, request):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()

async def on_GET(self, request, event_id):
async def on_GET(
self, request: SynapseRequest, event_id: str
) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)

time_now = self.clock.time_msec()
if event:
event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
result = await self._event_serializer.serialize_event(event, time_now)
return 200, result
else:
return 404, "Event not found."


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)
26 changes: 18 additions & 8 deletions synapse/rest/client/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,34 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID

from ._base import client_patterns, set_timeline_upper_limit

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()

async def on_GET(self, request, user_id, filter_id):
async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)

Expand All @@ -43,13 +51,13 @@ async def on_GET(self, request, user_id, filter_id):
raise AuthError(403, "Can only get filters for local users")

try:
filter_id = int(filter_id)
filter_id_int = int(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter_id")

try:
filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
user_localpart=target_user.localpart, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
Expand All @@ -62,13 +70,15 @@ async def on_GET(self, request, user_id, filter_id):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()

async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:

target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
Expand All @@ -89,6 +99,6 @@ async def on_POST(self, request, user_id):
return 200, {"filter_id": str(filter_id)}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)
Loading