Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix race in MultiWriterIdGenerator (#11045)
Browse files Browse the repository at this point in the history
The race allowed the current position to advance too far when stream IDs
are still being persisted.

This happened when a stream ID arrived from a remote writer in the
window between a new local stream ID being allocated and it being
added to the set of unpersisted stream IDs.

Fixes #9424.
  • Loading branch information
erikjohnston authored Oct 12, 2021
1 parent 5c35074 commit 333d6f4
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
1 change: 1 addition & 0 deletions changelog.d/11045.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race.
82 changes: 67 additions & 15 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)

import attr
from sortedcontainers import SortedSet
from sortedcontainers import SortedList, SortedSet

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import (
Expand Down Expand Up @@ -265,6 +265,15 @@ def __init__(
# should be less than the minimum of this set (if not empty).
self._unfinished_ids: SortedSet[int] = SortedSet()

# We also need to track when we've requested some new stream IDs but
# they haven't yet been added to the `_unfinished_ids` set. Every time
# we request a new stream ID we add the current max stream ID to the
# list, and remove it once we've added the newly allocated IDs to the
# `_unfinished_ids` set. This means that we *may* be allocated stream
# IDs above those in the list, and so we can't advance the local current
# position beyond the minimum stream ID in this list.
self._in_flight_fetches: SortedList[int] = SortedList()

# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
self._finished_ids: Set[int] = set()
Expand All @@ -290,6 +299,9 @@ def __init__(
)
self._known_persisted_positions: List[int] = []

# The maximum stream ID that we have seen allocated across any writer.
self._max_seen_allocated_stream_id = 1

self._sequence_gen = PostgresSequenceGenerator(sequence_name)

# We check that the table and sequence haven't diverged.
Expand All @@ -305,6 +317,10 @@ def __init__(
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)

self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)

def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
Expand Down Expand Up @@ -411,10 +427,32 @@ def _load_current_ids(
cur.close()

def _load_next_id_txn(self, txn: Cursor) -> int:
    """Allocate the next stream ID for this instance.

    Delegates to `_load_next_mult_id_txn` so that the in-flight
    allocation bookkeeping (which stops the current position from
    advancing past IDs that are still being persisted) lives in
    exactly one place.

    Args:
        txn: the database cursor to allocate the ID with.

    Returns:
        The newly allocated stream ID.
    """
    stream_ids = self._load_next_mult_id_txn(txn, 1)
    return stream_ids[0]

def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
    """Allocate the next `n` stream IDs for this instance.

    Before fetching from the sequence we record the current maximum
    allocated stream ID in `_in_flight_fetches`, and we only remove it
    after the newly allocated IDs have been added to `_unfinished_ids`.
    This closes the race where a remote writer's position update,
    arriving between allocation and registration, could advance the
    local current position past stream IDs that are still being
    persisted.

    Args:
        txn: the database cursor to allocate the IDs with.
        n: the number of stream IDs to allocate.

    Returns:
        The newly allocated stream IDs.
    """
    # Snapshot the max allocated ID *before* fetching: anything above
    # it must be treated as potentially unfinished until we have added
    # the new IDs to `_unfinished_ids`.
    with self._lock:
        current_max = self._max_seen_allocated_stream_id
        self._in_flight_fetches.add(current_max)

    try:
        stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)

        with self._lock:
            self._unfinished_ids.update(stream_ids)
            self._max_seen_allocated_stream_id = max(
                self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
            )
    finally:
        # Always drop the in-flight marker, even on failure, otherwise
        # the current position could never advance again.
        with self._lock:
            self._in_flight_fetches.remove(current_max)

    return stream_ids

def get_next(self) -> AsyncContextManager[int]:
"""
Expand Down Expand Up @@ -463,9 +501,6 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

next_id = self._load_next_id_txn(txn)

with self._lock:
self._unfinished_ids.add(next_id)

txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)

Expand Down Expand Up @@ -497,15 +532,27 @@ def _mark_id_as_finished(self, next_id: int) -> None:

new_cur: Optional[int] = None

if self._unfinished_ids:
if self._unfinished_ids or self._in_flight_fetches:
# If there are unfinished IDs then the new position will be the
# largest finished ID less than the minimum unfinished ID.
# largest finished ID strictly less than the minimum unfinished
# ID.

# The minimum unfinished ID needs to take account of both
# `_unfinished_ids` and `_in_flight_fetches`.
if self._unfinished_ids and self._in_flight_fetches:
# `_in_flight_fetches` stores the maximum safe stream ID, so
# we add one to make it equivalent to the minimum unsafe ID.
min_unfinished = min(
self._unfinished_ids[0], self._in_flight_fetches[0] + 1
)
elif self._in_flight_fetches:
min_unfinished = self._in_flight_fetches[0] + 1
else:
min_unfinished = self._unfinished_ids[0]

finished = set()

min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids:
if s < min_unfinshed:
if s < min_unfinished:
if new_cur is None or new_cur < s:
new_cur = s
else:
Expand Down Expand Up @@ -575,6 +622,10 @@ def advance(self, instance_name: str, new_id: int) -> None:
new_id, self._current_positions.get(instance_name, 0)
)

self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, new_id
)

self._add_persisted_position(new_id)

def get_persisted_upto_position(self) -> int:
Expand Down Expand Up @@ -605,7 +656,11 @@ def _add_persisted_position(self, new_id: int) -> None:
# to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name)
if our_current_position and not self._unfinished_ids:
if (
our_current_position
and not self._unfinished_ids
and not self._in_flight_fetches
):
self._current_positions[self._instance_name] = max(
our_current_position, new_id
)
Expand Down Expand Up @@ -697,9 +752,6 @@ async def __aenter__(self) -> Union[int, List[int]]:
db_autocommit=True,
)

with self.id_gen._lock:
self.id_gen._unfinished_ids.update(self.stream_ids)

if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor
else:
Expand Down

0 comments on commit 333d6f4

Please sign in to comment.