Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add more type hints to the main state store. (#12267)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Mar 31, 2022
1 parent 5e88143 commit 11df4ec
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
1 change: 1 addition & 0 deletions changelog.d/12267.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints for storage.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ exclude = (?x)
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/state.py
|synapse/storage/schema/

|tests/api/test_auth.py
Expand Down
18 changes: 11 additions & 7 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple

from frozendict import frozendict

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
Expand All @@ -29,7 +30,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

Expand Down Expand Up @@ -132,7 +133,7 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:

return room_version

async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Expand All @@ -158,9 +159,10 @@ async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
predecessor = create_event.content.get("predecessor", None)

# Ensure the key is a dictionary
if not isinstance(predecessor, collections.abc.Mapping):
if not isinstance(predecessor, (dict, frozendict)):
return None

# The keys must be strings since the data is JSON.
return predecessor

async def get_create_event_for_room(self, room_id: str) -> EventBase:
Expand Down Expand Up @@ -306,7 +308,9 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
async def _get_state_group_for_events(
self, event_ids: Iterable[str]
) -> Dict[str, int]:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
Expand Down Expand Up @@ -521,7 +525,7 @@ def _background_remove_left_rooms_txn(
)

for user_id in potentially_left_users - joined_users:
await self.mark_remote_user_device_list_as_unsubscribed(user_id)
await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]

return batch_size

Expand Down
6 changes: 4 additions & 2 deletions synapse/util/caches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing
from enum import Enum, auto
from sys import intern
from typing import Any, Callable, Dict, List, Optional, Sized
from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar

import attr
from prometheus_client.core import Gauge
Expand Down Expand Up @@ -195,8 +195,10 @@ def register_cache(
)
}

T = TypeVar("T", Optional[str], str)

def intern_string(string: Optional[str]) -> Optional[str]:

def intern_string(string: T) -> T:
"""Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
Expand Down

0 comments on commit 11df4ec

Please sign in to comment.