Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make MultiWriterIDGenerator work for streams that use negative stream IDs #8203

Merged
merged 4 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8203.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `MultiWriterIDGenerator` work for streams that use negative values.
39 changes: 28 additions & 11 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""

def __init__(
Expand All @@ -196,13 +198,19 @@ def __init__(
instance_column: str,
id_column: str,
sequence_name: str,
positive: bool = True,
):
self._db = db
self._instance_name = instance_name
self._positive = positive
self._return_factor = 1 if positive else -1

# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()

# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
Expand Down Expand Up @@ -233,13 +241,16 @@ def __init__(
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
# If positive stream aggregate via MAX. For negative stream use MIN
# *and* negate the result to get a positive number.
sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
}

cur = db_conn.cursor()
Expand Down Expand Up @@ -269,15 +280,16 @@ async def get_next(self):
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id

with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id

self._unfinished_ids.add(next_id)

@contextlib.contextmanager
def manager():
try:
yield next_id
# Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)

Expand All @@ -296,15 +308,15 @@ async def get_next_mult(self, n: int):
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)

with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)

self._unfinished_ids.update(next_ids)

@contextlib.contextmanager
def manager():
try:
yield next_ids
yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
Expand All @@ -327,7 +339,7 @@ def get_next_txn(self, txn: LoggingTransaction):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)

return next_id
return self._return_factor * next_id

def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
Expand Down Expand Up @@ -359,20 +371,25 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
"""

with self._lock:
return self._current_positions.get(instance_name, 0)
return self._return_factor * self._current_positions.get(instance_name, 0)

def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current position map.
"""

with self._lock:
return dict(self._current_positions)
return {
name: self._return_factor * i
for name, i in self._current_positions.items()
}

def advance(self, instance_name: str, new_id: int):
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""

new_id *= self._return_factor

with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
Expand All @@ -390,7 +407,7 @@ def get_persisted_upto_position(self) -> int:
"""

with self._lock:
return self._persisted_upto_position
return self._return_factor * self._persisted_upto_position

def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
Expand Down
105 changes: 105 additions & 0 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,108 @@ def test_get_persisted_upto_position_get_next(self):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).


class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
    """Exercises MultiWriterIdGenerator instances built with ``positive=False``,
    i.e. generators that are expected to hand out *negative* stream IDs.
    """

    if not USE_POSTGRES_FOR_TESTS:
        skip = "Requires Postgres"

    def prepare(self, reactor, clock, hs):
        """Grab the datastore's database pool and create the test schema."""
        self.store = hs.get_datastore()
        self.db_pool = self.store.db_pool  # type: DatabasePool

        schema_setup = self.db_pool.runInteraction("_setup_db", self._setup_db)
        self.get_success(schema_setup)

    def _setup_db(self, txn):
        """Create the ``foobar_seq`` sequence and the ``foobar`` table."""
        txn.execute("CREATE SEQUENCE foobar_seq")
        txn.execute(
            """
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
        )

    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
        """Build a negative-ID generator for ``instance_name``.

        The generator is constructed inside ``db_pool.runWithConnection`` so
        that it is handed a connection, and is wired to the ``foobar`` table
        and ``foobar_seq`` sequence.
        """

        def _build(conn):
            return MultiWriterIdGenerator(
                conn,
                self.db_pool,
                instance_name=instance_name,
                table="foobar",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="foobar_seq",
                positive=False,
            )

        pending = self.db_pool.runWithConnection(_build)
        return self.get_success(pending)

    def _insert_row(self, instance_name: str, stream_id: int):
        """Write a single ``foobar`` row for ``instance_name`` carrying
        ``stream_id``.
        """

        # Column order in the table is (stream_id, instance_name, data).
        row = (stream_id, instance_name)

        def _do_insert(txn):
            txn.execute("INSERT INTO foobar VALUES (?, ?)", row)

        pending = self.db_pool.runInteraction("_insert_row", _do_insert)
        self.get_success(pending)

    def test_single_instance(self):
        """A lone writer allocating IDs should see its positions move
        through -1 and then -4, and a fresh generator loading from the
        table should report the same state.
        """
        id_gen = self._create_id_generator()

        # Allocate and persist one ID.
        with self.get_success(id_gen.get_next()) as allocated:
            self._insert_row("master", allocated)

        after_first = -1
        self.assertEqual(id_gen.get_positions(), {"master": after_first})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), after_first)
        self.assertEqual(id_gen.get_persisted_upto_position(), after_first)

        # Allocate and persist a batch of three more.
        with self.get_success(id_gen.get_next_mult(3)) as batch:
            for allocated in batch:
                self._insert_row("master", allocated)

        after_batch = -4
        self.assertEqual(id_gen.get_positions(), {"master": after_batch})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), after_batch)
        self.assertEqual(id_gen.get_persisted_upto_position(), after_batch)

        # A second generator has no in-memory history, so whatever it
        # reports must have been read back from the table.
        second_id_gen = self._create_id_generator()

        self.assertEqual(second_id_gen.get_positions(), {"master": after_batch})
        self.assertEqual(
            second_id_gen.get_current_token_for_writer("master"), after_batch
        )
        self.assertEqual(second_id_gen.get_persisted_upto_position(), after_batch)

    def test_multiple_instance(self):
        """Two writers that tell each other about newly allocated IDs via
        ``advance`` should end up agreeing on every position.
        """
        id_gen_1 = self._create_id_generator("first")
        id_gen_2 = self._create_id_generator("second")
        both = (id_gen_1, id_gen_2)

        # "first" allocates an ID and "second" is told about it.
        with self.get_success(id_gen_1.get_next()) as allocated:
            self._insert_row("first", allocated)
            id_gen_2.advance("first", allocated)

        for gen in both:
            self.assertEqual(gen.get_positions(), {"first": -1})
        for gen in both:
            self.assertEqual(gen.get_persisted_upto_position(), -1)

        # Now the other way round.
        with self.get_success(id_gen_2.get_next()) as allocated:
            self._insert_row("second", allocated)
            id_gen_1.advance("second", allocated)

        for gen in both:
            self.assertEqual(gen.get_positions(), {"first": -1, "second": -2})
        for gen in both:
            self.assertEqual(gen.get_persisted_upto_position(), -2)