Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix race in replication #7226

Merged
merged 7 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7226.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.
73 changes: 45 additions & 28 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,34 @@ async def on_RDATA(self, cmd: RdataCommand):
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise

if cmd.token is None or stream_name not in self._streams_connected:
# I.e. either this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token) or we're currently connecting so we queue up rows.
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
# We linearize here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried that we could get quite far behind (ie, have a long list of things waiting for the position linearizer) if the catchup is a bit slow and we get a few POSITION lines intermixed with lots of RDATA lines, all of which will end up getting processed in series.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm, though I'm not sure the solution to that is allowing RDATA to be processed in parallel. Perhaps we just want to add metrics for the queue size?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah maybe it's not worth worrying about for now. especially if we can mitigate it as per #7226 (comment).

if stream_name not in self._streams_connected:
# If the stream isn't marked as connected then we haven't seen a
# `POSITION` command yet, and so we may have missed some rows.
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.warning(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return

if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)

async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
Expand All @@ -124,12 +142,13 @@ async def on_POSITION(self, cmd: PositionCommand):
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time by removing stream from
# list of connected streams. We also clear any batched up RDATA from
# before we got the POSITION.
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
self._streams_connected.discard(cmd.stream_name)
Comment on lines +145 to 147
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this actually doing anything useful? once we've caught up for the first time, _streams_connected[cmd.stream_name] is always set outside the linearizer. Alternatively expressed: the linearizer ensures that no RDATA will be processed while we catch up, so there is no need for us to clear _streams_connected. I'd just get rid of these three lines.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about this a bit, the only difference is if an exception is raised. I don't think we want to handle RDATA for that stream if we fail to handle the position, but dropping everything doesn't sound like the right thing either

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmmm this sounds like a thing we need to improve, but ok let's punt it for now.

self._pending_batches.clear()

# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(cmd.stream_name, [])

# Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get(
Expand All @@ -142,12 +161,17 @@ async def on_POSITION(self, cmd: PositionCommand):
)
return

# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
(
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)

if updates:
await self.on_rdata(
cmd.stream_name,
Expand All @@ -158,13 +182,6 @@ async def on_POSITION(self, cmd: PositionCommand):
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)

# Handle any RDATA that came in while we were catching up.
rows = self._pending_batches.pop(cmd.stream_name, [])
if rows:
await self._replication_data_handler.on_rdata(
cmd.stream_name, rows[-1].token, rows
)

self._streams_connected.add(cmd.stream_name)

async def on_SYNC(self, cmd: SyncCommand):
Expand Down
3 changes: 2 additions & 1 deletion synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,13 @@ def make_http_update_function(
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
result = await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return result["updates"], result["upto_token"], result["limited"]

return update_function

Expand Down
40 changes: 20 additions & 20 deletions synapse/storage/data_stores/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,26 @@ def bulk_get_push_rules_enabled(self, user_ids):
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results

def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
if last_id == current_id:
return defer.succeed([])

def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()

return self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)


class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
Expand Down Expand Up @@ -685,26 +705,6 @@ def _insert_push_rules_update_txn(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)

def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
if last_id == current_id:
return defer.succeed([])

def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()

return self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)

def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
Expand Down