This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Factor out MultiWriter token from RoomStreamToken #16427

Merged: 6 commits on Oct 5, 2023
1 change: 1 addition & 0 deletions changelog.d/16427.misc
@@ -0,0 +1 @@
Factor out `MultiWriter` token from `RoomStreamToken`.
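
The change at each call site is mechanical: `RoomStreamToken`'s fields become keyword-only, so stream-only tokens no longer pass an explicit `None` placeholder for the topological part. A minimal sketch of the before/after, assuming a Synapse checkout with this PR on the import path:

```python
from synapse.types import RoomStreamToken

# Before this PR: positional construction, with None standing in for
# "no topological part".
#   token = RoomStreamToken(None, 42)

# After this PR: keyword-only construction; the topological part now
# defaults to None.
token = RoomStreamToken(stream=42)
historical = RoomStreamToken(topological=7, stream=42)
```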
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
@@ -171,8 +171,8 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = RoomStreamToken(0, 0)
-            to_key = RoomStreamToken(None, stream_ordering)
+            from_key = RoomStreamToken(topological=0, stream=0)
+            to_key = RoomStreamToken(stream=stream_ordering)
 
             # Events that we've processed in this room
             written_events: Set[str] = set()
3 changes: 1 addition & 2 deletions synapse/handlers/initial_sync.py
@@ -192,8 +192,7 @@ async def handle_room(event: RoomsForUser) -> None:
                 )
             elif event.membership == Membership.LEAVE:
                 room_end_token = RoomStreamToken(
-                    None,
-                    event.stream_ordering,
+                    stream=event.stream_ordering,
                 )
                 deferred_room_state = run_in_background(
                     self._state_storage_controller.get_state_for_events,
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
@@ -1708,7 +1708,7 @@ async def get_new_events(

         if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = RoomStreamToken(None, from_key.stream)
+            from_key = RoomStreamToken(stream=from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
2 changes: 1 addition & 1 deletion synapse/handlers/sync.py
@@ -2333,7 +2333,7 @@ async def _get_room_changes_for_initial_sync(
                 continue
 
             leave_token = now_token.copy_and_replace(
-                StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
+                StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
             )
             room_entries.append(
                 RoomSyncResultBuilder(
2 changes: 1 addition & 1 deletion synapse/rest/admin/__init__.py
@@ -146,7 +146,7 @@ async def on_POST(
         # RoomStreamToken expects [int] not Optional[int]
         assert event.internal_metadata.stream_ordering is not None
         room_token = RoomStreamToken(
-            event.depth, event.internal_metadata.stream_ordering
+            topological=event.depth, stream=event.internal_metadata.stream_ordering
         )
         token = await room_token.to_string(self.store)

22 changes: 13 additions & 9 deletions synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
         # when we are going backwards so we subtract one from the
         # stream part.
         last_stream_ordering -= 1
-    return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+    return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)


 def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ def get_room_max_token(self) -> RoomStreamToken:
             if p > min_pos
         }
 
-        return RoomStreamToken(None, min_pos, immutabledict(positions))
+        return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))

     async def get_room_events_stream_for_rooms(
         self,
@@ -708,7 +708,7 @@ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
         ret.reverse()
 
         if rows:
-            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+            key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -969,7 +969,7 @@ async def get_current_room_stream_token_for_room_id(
         topo = await self.db_pool.runInteraction(
             "_get_max_topological_txn", self._get_max_topological_txn, room_id
         )
-        return RoomStreamToken(topo, stream_ordering)
+        return RoomStreamToken(topological=topo, stream=stream_ordering)

     @overload
     def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+        return RoomStreamToken(
+            topological=row["topological_ordering"], stream=row["stream_ordering"]
+        )

     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream
@@ -1114,8 +1116,8 @@ def _set_before_and_after(
         else:
             topo = None
         internal = event.internal_metadata
-        internal.before = RoomStreamToken(topo, stream - 1)
-        internal.after = RoomStreamToken(topo, stream)
+        internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+        internal.after = RoomStreamToken(topological=topo, stream=stream)
         internal.order = (int(topo) if topo else 0, int(stream))

     async def get_events_around(
@@ -1191,11 +1193,13 @@ def _get_events_around_txn(
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            results["topological_ordering"] - 1, results["stream_ordering"]
+            topological=results["topological_ordering"] - 1,
+            stream=results["stream_ordering"],
         )
 
         after_token = RoomStreamToken(
-            results["topological_ordering"], results["stream_ordering"]
+            topological=results["topological_ordering"],
+            stream=results["stream_ordering"],
         )
 
         rows, start_token = self._paginate_room_events_txn(
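A hedged illustration of the `before_token`/`after_token` asymmetry above, with invented orderings: stepping the topological part back by one keeps the anchor event out of the backwards page, while the forwards page can use the anchor's own position unchanged.

```python
from synapse.types import RoomStreamToken

# Hypothetical anchor event at topological=10, stream=100.
results = {"topological_ordering": 10, "stream_ordering": 100}

# Backwards pagination includes the event at the token, so start just
# before the anchor's position to avoid returning the anchor itself.
before_token = RoomStreamToken(
    topological=results["topological_ordering"] - 1,
    stream=results["stream_ordering"],
)

# Forwards pagination excludes the event at the token, so the anchor's
# own position can be used as-is.
after_token = RoomStreamToken(
    topological=results["topological_ordering"],
    stream=results["stream_ordering"],
)
```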
132 changes: 91 additions & 41 deletions synapse/types/__init__.py
@@ -60,6 +60,8 @@
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
Member Author: This requires typing_extensions >= 4.0. Not sure how we feel about bumping that? Or just leaving it behind TYPE_CHECKING?

Contributor: I'm happy as long as the packagers are. https://pkgs.org/search/?q=typing-extensions and https://repology.org/project/python:typing-extensions/versions. Debian buster and Ubuntu focal + jammy have 3.x; I can't remember if we've dropped support for those.

Contributor: Oh ISWYM. It'd be nice to be able to just import it, but I don't mind the conditional import and referencing it in "quotes" if that helps the packagers too. (ISTR we have to quote "defer.Deferred[blah]" anyway...)

Member: I think putting it behind TYPE_CHECKING is fine.
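
For reference, a minimal sketch of the pattern the thread converges on: the import only runs under type checking, and the annotation is quoted so nothing needs `typing_extensions` at runtime. The `Token` class here is hypothetical.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, so no runtime dependency on
    # typing_extensions >= 4.0.
    from typing_extensions import Self


class Token:
    def copy(self) -> "Self":
        # "Self" is quoted: it is never looked up at runtime, only
        # resolved by the type checker.
        return self
```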


+
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
@@ -436,7 +438,78 @@ def f2(m: Match[bytes]) -> bytes:


 @attr.s(frozen=True, slots=True, order=False)
-class RoomStreamToken:
+class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
+    """An abstract stream token class for streams that support multiple
+    writers.
+
+    This works by keeping track of the stream position of each writer,
+    represented by a default `stream` attribute and a map of instance name to
+    stream position of any writers that are ahead of the default stream
+    position.
+    """
+
+    stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
+
+    instance_map: "immutabledict[str, int]" = attr.ib(
+        factory=immutabledict,
+        validator=attr.validators.deep_mapping(
+            key_validator=attr.validators.instance_of(str),
+            value_validator=attr.validators.instance_of(int),
+            mapping_validator=attr.validators.instance_of(immutabledict),
+        ),
+        kw_only=True,
Member: OMG can we use kw_only everywhere?!?! 😍
+    )
+
+    @classmethod
+    @abc.abstractmethod
+    async def parse(cls, store: "DataStore", string: str) -> "Self":
+        """Parse the string representation of the token."""
+        ...
+
+    @abc.abstractmethod
+    async def to_string(self, store: "DataStore") -> str:
+        """Serialize the token into its string representation."""
+        ...
+
+    def copy_and_advance(self, other: "Self") -> "Self":
+        """Return a new token such that if an event is after both this token and
+        the other token, then it's after the returned token too.
+        """
+
+        max_stream = max(self.stream, other.stream)
+
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return attr.evolve(
+            self, stream=max_stream, instance_map=immutabledict(instance_map)
+        )
+
+    def get_max_stream_pos(self) -> int:
+        """Get the maximum stream position referenced in this token.
+
+        The corresponding "min" position is, by definition, just `self.stream`.
+
+        This is used to handle tokens that have non-empty `instance_map`, and so
+        reference stream positions after the `self.stream` position.
+        """
+        return max(self.instance_map.values(), default=self.stream)
+
+    def get_stream_pos_for_instance(self, instance_name: str) -> int:
+        """Get the stream position that the given writer was at, at this token."""
+
+        # If we don't have an entry for the instance we can assume that it was
+        # at `self.stream`.
+        return self.instance_map.get(instance_name, self.stream)
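
A small sketch of the merged-position semantics, using the concrete `RoomStreamToken` subclass defined below and invented writer names; it assumes a Synapse checkout with this PR applied.

```python
from immutabledict import immutabledict

from synapse.types import RoomStreamToken

# instance_map only records writers that are ahead of the default
# `stream` position.
t1 = RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 8}))
t2 = RoomStreamToken(stream=6, instance_map=immutabledict({"writer2": 9}))

merged = t1.copy_and_advance(t2)

assert merged.stream == 6                                   # max of the defaults
assert merged.get_stream_pos_for_instance("writer1") == 8   # kept from t1
assert merged.get_stream_pos_for_instance("writer2") == 9   # kept from t2
assert merged.get_stream_pos_for_instance("other") == 6     # falls back to stream
assert merged.get_max_stream_pos() == 9
```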


+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken(AbstractMultiWriterStreamToken):
"""Tokens are positions between events. The token "s1" comes after event 1.

s0 s1
@@ -513,16 +586,8 @@ class RoomStreamToken:

     topological: Optional[int] = attr.ib(
         validator=attr.validators.optional(attr.validators.instance_of(int)),
-    )
-    stream: int = attr.ib(validator=attr.validators.instance_of(int))
-
-    instance_map: "immutabledict[str, int]" = attr.ib(
-        factory=immutabledict,
-        validator=attr.validators.deep_mapping(
-            key_validator=attr.validators.instance_of(str),
-            value_validator=attr.validators.instance_of(int),
-            mapping_validator=attr.validators.instance_of(immutabledict),
-        ),
-        kw_only=True,
+        kw_only=True,
+        default=None,
     )

     def __attrs_post_init__(self) -> None:
@@ -582,17 +647,7 @@ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
         if self.topological or other.topological:
             raise Exception("Can't advance topological tokens")
 
-        max_stream = max(self.stream, other.stream)
-
-        instance_map = {
-            instance: max(
-                self.instance_map.get(instance, self.stream),
-                other.instance_map.get(instance, other.stream),
-            )
-            for instance in set(self.instance_map).union(other.instance_map)
-        }
-
-        return RoomStreamToken(None, max_stream, immutabledict(instance_map))
+        return super().copy_and_advance(other)

     def as_historical_tuple(self) -> Tuple[int, int]:
         """Returns a tuple of `(topological, stream)` for historical tokens.
@@ -618,16 +673,6 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int:
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)

def get_max_stream_pos(self) -> int:
"""Get the maximum stream position referenced in this token.

The corresponding "min" position is, by definition just `self.stream`.

This is used to handle tokens that have non-empty `instance_map`, and so
reference stream positions after the `self.stream` position.
"""
return max(self.instance_map.values(), default=self.stream)

     async def to_string(self, store: "DataStore") -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
@@ -809,23 +854,28 @@ def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
         return attr.evolve(self, **{key: new_value})
 
 
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class PersistedEventPosition:
-    """Position of a newly persisted event with instance that persisted it.
-
-    This can be used to test whether the event is persisted before or after a
-    RoomStreamToken.
-    """
+class PersistedPosition:
+    """Position of a newly persisted row with instance that persisted it."""
 
     instance_name: str
     stream: int
 
-    def persisted_after(self, token: RoomStreamToken) -> bool:
+    def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
         return token.get_stream_pos_for_instance(self.instance_name) < self.stream
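
A small sketch of `persisted_after` with invented writer names and positions: a row counts as after the token only if it is past the stream position the token records for that row's own writer. Assumes a Synapse checkout with this PR applied.

```python
from immutabledict import immutabledict

from synapse.types import PersistedPosition, RoomStreamToken

pos = PersistedPosition(instance_name="writer1", stream=7)

# The token says writer1 had only reached position 6, so row 7 is newer.
token = RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 6}))
assert pos.persisted_after(token)

# Here writer1 had already reached position 7, so row 7 is covered.
caught_up = RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 7}))
assert not pos.persisted_after(caught_up)
```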


+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition(PersistedPosition):
+    """Position of a newly persisted event with instance that persisted it.
+
+    This can be used to test whether the event is persisted before or after a
+    RoomStreamToken.
+    """
+
     def to_room_stream_token(self) -> RoomStreamToken:
         """Converts the position to a room stream token such that events
         persisted in the same room after this position will be after the
@@ -836,7 +886,7 @@ def to_room_stream_token(self) -> RoomStreamToken:
"""
# Doing the naive thing satisfies the desired properties described in
# the docstring.
return RoomStreamToken(None, self.stream)
return RoomStreamToken(stream=self.stream)


 @attr.s(slots=True, frozen=True, auto_attribs=True)
8 changes: 4 additions & 4 deletions tests/handlers/test_appservice.py
@@ -86,7 +86,7 @@ def test_notify_interested_services(self) -> None:
                 [event],
             ]
         )
-        self.handler.notify_interested_services(RoomStreamToken(None, 1))
+        self.handler.notify_interested_services(RoomStreamToken(stream=1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, events=[event]
@@ -107,7 +107,7 @@ def test_query_user_exists_unknown_user(self) -> None:
             ]
         )
         self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)

@@ -126,7 +126,7 @@ def test_query_user_exists_known_user(self) -> None:
             ]
         )
 
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.assertFalse(
             self.mock_as_api.query_user.called,
@@ -441,7 +441,7 @@ def _notify_interested_services(self) -> None:
         self.get_success(
             self.hs.get_application_service_handler()._notify_interested_services(
                 RoomStreamToken(
-                    None, self.hs.get_application_service_handler().current_max
+                    stream=self.hs.get_application_service_handler().current_max
                 )
             )
         )