Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to the admin API servlets #10105

Merged
merged 3 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10105.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to the admin API servlets.
45 changes: 24 additions & 21 deletions synapse/rest/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

import logging
import platform
from typing import TYPE_CHECKING, Optional, Tuple

import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
Expand Down Expand Up @@ -66,22 +68,25 @@
UserTokenRestServlet,
WhoisRestServlet,
)
from synapse.types import RoomStreamToken
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.versionstring import get_version_string

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class VersionServlet(RestServlet):
PATTERNS = admin_patterns("/server_version$")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.res = {
"server_version": get_version_string(synapse),
"python_version": platform.python_version(),
}

def on_GET(self, request):
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, self.res


Expand All @@ -90,17 +95,14 @@ class PurgeHistoryRestServlet(RestServlet):
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)

def __init__(self, hs):
"""

Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
self.store = hs.get_datastore()
self.auth = hs.get_auth()

async def on_POST(self, request, room_id, event_id):
async def on_POST(
self, request: SynapseRequest, room_id: str, event_id: Optional[str]
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

body = parse_json_object_from_request(request, allow_empty_body=True)
Expand All @@ -119,6 +121,8 @@ async def on_POST(self, request, room_id, event_id):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")

# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
room_token = RoomStreamToken(
event.depth, event.internal_metadata.stream_ordering
)
Expand Down Expand Up @@ -173,16 +177,13 @@ async def on_POST(self, request, room_id, event_id):
class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")

def __init__(self, hs):
"""

Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()

async def on_GET(self, request, purge_id):
async def on_GET(
self, request: SynapseRequest, purge_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

purge_status = self.pagination_handler.get_purge_status(purge_id)
Expand All @@ -203,12 +204,12 @@ async def on_GET(self, request, purge_id):
class AdminRestResource(JsonResource):
"""The REST resource which gets mounted at /_synapse/admin"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
register_servlets(hs, self)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Register all the admin servlets.
"""
Expand Down Expand Up @@ -242,7 +243,9 @@ def register_servlets(hs, http_server):
RateLimitRestServlet(hs).register(http_server)


def register_servlets_for_client_rest_resource(hs, http_server):
def register_servlets_for_client_rest_resource(
hs: "HomeServer", http_server: HttpServer
) -> None:
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/admin/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

import re
from typing import Iterable, Pattern

from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
from synapse.types import UserID


def admin_patterns(path_regex: str, version: str = "v1"):
def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
"""Returns the list of patterns for an admin endpoint

Args:
Expand Down
12 changes: 10 additions & 2 deletions synapse/rest/admin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand All @@ -25,12 +31,14 @@ class DeleteGroupAdminRestServlet(RestServlet):

PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()

async def on_POST(self, request, group_id):
async def on_POST(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)

Expand Down
12 changes: 6 additions & 6 deletions synapse/rest/admin/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
Expand All @@ -37,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet):
this server.
"""

PATTERNS = (
admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
PATTERNS = [
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine"),
# This path kept around for legacy reasons
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
]

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
Expand Down Expand Up @@ -283,7 +283,7 @@ async def on_POST(
return 200, {"deleted_media": deleted_media, "total": total}


def register_servlets_for_media_repo(hs: "HomeServer", http_server):
def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Media repo specific APIs.
"""
Expand Down
15 changes: 5 additions & 10 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,12 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

class WhoisRestServlet(RestServlet):
path_regex = "/whois/(?P<user_id>[^/]*)$"
PATTERNS = (
admin_patterns(path_regex)
+
PATTERNS = [
*admin_patterns(path_regex),
# URL for spec reason
# https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
client_patterns("/admin" + path_regex, v1=True)
)
*client_patterns("/admin" + path_regex, v1=True),
]

def __init__(self, hs: "HomeServer"):
self.hs = hs
Expand Down Expand Up @@ -553,11 +552,7 @@ async def on_POST(
class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
Expand Down