Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to profile and base handlers. (#8609)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Oct 21, 2020
1 parent 9e0f228 commit de5cafe
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 41 deletions.
1 change: 1 addition & 0 deletions changelog.d/8609.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to profile and base handler.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
synapse/handlers/appservice.py,
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
Expand All @@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
Expand Down
20 changes: 10 additions & 10 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional

import synapse.state
import synapse.storage
Expand All @@ -22,6 +23,9 @@
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


Expand All @@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers.
"""

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
Expand All @@ -56,7 +56,7 @@ def __init__(self, hs):
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
) # type: Optional[Ratelimiter]
else:
self.admin_redaction_ratelimiter = None

Expand Down Expand Up @@ -127,15 +127,15 @@ async def maybe_kick_guest_users(self, event, context=None):
if guest_access != "can_join":
if context:
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
current_state_dict = await self.store.get_events(
list(current_state_ids.values())
)
current_state = list(current_state_dict.values())
else:
current_state = await self.state_handler.get_current_state(
current_state_map = await self.state_handler.get_current_state(
event.room_id
)

current_state = list(current_state.values())
current_state = list(current_state_map.values())

logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state)
Expand Down
8 changes: 6 additions & 2 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ async def room_initial_sync(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
# to leave.
assert member_event_id

result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
Expand All @@ -315,7 +319,7 @@ async def _room_initial_sync_parted(
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
Expand Down Expand Up @@ -367,7 +371,7 @@ async def _room_initial_sync_joined(
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
Expand Down
74 changes: 49 additions & 25 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import (
AuthError,
Expand All @@ -25,10 +25,19 @@
SynapseError,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id
from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)

from ._base import BaseHandler

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

MAX_DISPLAYNAME_LEN = 256
Expand All @@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.federation = hs.get_federation_client()
Expand All @@ -60,7 +69,7 @@ def __init__(self, hs):
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)

async def get_profile(self, user_id):
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)

if self.hs.is_mine(target_user):
Expand Down Expand Up @@ -91,7 +100,7 @@ async def get_profile(self, user_id):
except HttpResponseException as e:
raise e.to_synapse_error()

async def get_profile_from_cache(self, user_id):
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
Expand All @@ -115,7 +124,7 @@ async def get_profile_from_cache(self, user_id):
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}

async def get_displayname(self, target_user):
async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
Expand Down Expand Up @@ -143,15 +152,19 @@ async def get_displayname(self, target_user):
return result["displayname"]

async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
self,
target_user: UserID,
requester: Requester,
new_displayname: str,
by_admin: bool = False,
) -> None:
"""Set the displayname of a user
Args:
target_user (UserID): the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change.
new_displayname (str): The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose displayname is to be changed.
requester: The user attempting to make this change.
new_displayname: The displayname to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand All @@ -176,16 +189,19 @@ async def set_displayname(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)

displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "":
new_displayname = None
displayname_to_set = None

# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
requester = create_requester(target_user)

await self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
)

if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
Expand All @@ -195,7 +211,7 @@ async def set_displayname(

await self._update_join_states(requester, target_user)

async def get_avatar_url(self, target_user):
async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
Expand All @@ -222,15 +238,19 @@ async def get_avatar_url(self, target_user):
return result["avatar_url"]

async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
self,
target_user: UserID,
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
):
"""Set a new avatar URL for a user.
Args:
target_user (UserID): the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator.
target_user: the user whose avatar URL is to be changed.
requester: The user attempting to make this change.
new_avatar_url: The avatar URL to give this user.
by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand Down Expand Up @@ -267,7 +287,7 @@ async def set_avatar_url(

await self._update_join_states(requester, target_user)

async def on_profile_query(self, args):
async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand All @@ -292,7 +312,9 @@ async def on_profile_query(self, args):

return response

async def _update_join_states(self, requester, target_user):
async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
if not self.hs.is_mine(target_user):
return

Expand Down Expand Up @@ -323,15 +345,17 @@ async def _update_join_states(self, requester, target_user):
"Failed to update join event for room %s - %s", room_id, str(e)
)

async def check_profile_query_allowed(self, target_user, requester=None):
async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None
) -> None:
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
Args:
target_user (UserID): The owner of the queried profile.
requester (None|UserID): The user querying for the profile.
target_user: The owner of the queried profile.
requester: The user querying for the profile.
Raises:
SynapseError(403): The two users share no room, or ne user couldn't
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -72,7 +72,7 @@ async def create_profile(self, user_localpart: str) -> None:
)

async def set_profile_displayname(
self, user_localpart: str, new_displayname: str
self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
Expand Down Expand Up @@ -144,7 +144,7 @@ async def is_subscribed_remote_profile_for_user(self, user_id):

async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""

Expand Down

0 comments on commit de5cafe

Please sign in to comment.