Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Track ongoing event fetches correctly (again) #11376

Merged
merged 15 commits into from
Nov 26, 2021
1 change: 1 addition & 0 deletions changelog.d/11376.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, for real this time. Also fix a race condition introduced in the previous insufficient fix in 1.47.0.
154 changes: 112 additions & 42 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
logger = logging.getLogger(__name__)


# These values are used in the `enqueus_event` and `_do_fetch` methods to
# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
Expand Down Expand Up @@ -602,7 +602,7 @@ async def _get_events_from_cache_or_db(
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
richvdh marked this conversation as resolved.
Show resolved Hide resolved
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred

Expand Down Expand Up @@ -736,35 +736,118 @@ async def get_stripped_room_state_from_event_context(
for e in state_to_include.values()
]

def _do_fetch(self, conn: Connection) -> None:
def _maybe_start_fetch_thread(self) -> None:
"""Starts an event fetch thread if we are not yet at the maximum number."""
with self._event_fetch_lock:
if (
self._event_fetch_list
and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
):
self._event_fetch_ongoing += 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# `_event_fetch_ongoing` is decremented in `_fetch_thread`.
should_start = True
else:
should_start = False

if should_start:
run_as_background_process("fetch_events", self._fetch_thread)

async def _fetch_thread(self) -> None:
"""Services requests for events from `_event_fetch_list`."""
exc = None
try:
await self.db_pool.runWithConnection(self._fetch_loop)
except BaseException as e:
exc = e
raise
finally:
should_restart = False
event_fetches_to_fail = []
with self._event_fetch_lock:
self._event_fetch_ongoing -= 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)

# There may still be work remaining in `_event_fetch_list` if we
# failed, or it was added in between us deciding to exit and
# decrementing `_event_fetch_ongoing`.
if self._event_fetch_list:
if exc is None:
# We decided to exit, but then some more work was added
# before `_event_fetch_ongoing` was decremented.
# If a new event fetch thread was not started, we should
# restart ourselves since the remaining event fetch threads
# may take a while to get around to the new work.
#
# Unfortunately it is not possible to tell whether a new
# event fetch thread was started, so we restart
# unconditionally. If we are unlucky, we will end up with
# an idle fetch thread, but it will time out after
# `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
# in any case.
#
# Note that multiple fetch threads may run down this path at
# the same time.
should_restart = True
elif isinstance(exc, Exception):
if self._event_fetch_ongoing == 0:
# We were the last remaining fetcher and failed.
# Fail any outstanding fetches since no one else will
# handle them.
event_fetches_to_fail = self._event_fetch_list
self._event_fetch_list = []
else:
# We weren't the last remaining fetcher, so another
# fetcher will pick up the work. This will either happen
# after their existing work, however long that takes,
# or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
# they are idle.
pass
else:
# The exception is a `SystemExit`, `KeyboardInterrupt` or
# `GeneratorExit`. Don't try to do anything clever here.
pass

if should_restart:
# We exited cleanly but noticed more work.
self._maybe_start_fetch_thread()

if event_fetches_to_fail:
# We were the last remaining fetcher and failed.
# Fail any outstanding fetches since no one else will handle them.
assert exc is not None
with PreserveLoggingContext():
for _, deferred in event_fetches_to_fail:
deferred.errback(exc)

def _fetch_loop(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
try:
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []

if not event_list:
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
break
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []

if not event_list:
richvdh marked this conversation as resolved.
Show resolved Hide resolved
# There are no requests waiting. If we haven't yet reached the
# maximum iteration limit, wait for some more requests to turn up.
# Otherwise, bail out.
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
richvdh marked this conversation as resolved.
Show resolved Hide resolved
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
return

self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0

self._fetch_event_list(conn, event_list)
finally:
self._event_fetch_ongoing -= 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
self._fetch_event_list(conn, event_list)

def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
Expand Down Expand Up @@ -806,9 +889,7 @@ def fire():
# We only want to resolve deferreds from the main thread
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(exc)
d.errback(exc)

with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
Expand Down Expand Up @@ -983,20 +1064,9 @@ async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))

self._event_fetch_lock.notify()

if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False

if should_start:
run_as_background_process(
"fetch_events", self.db_pool.runWithConnection, self._do_fetch
)
self._maybe_start_fetch_thread()

logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
Expand Down
139 changes: 138 additions & 1 deletion tests/storage/databases/main/test_events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from contextlib import contextmanager
from typing import Generator

from twisted.enterprise.adbapi import ConnectionPool
from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.room_versions import EventFormatVersions, RoomVersions
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.server import HomeServer
from synapse.storage.databases.main.events_worker import (
EVENT_QUEUE_THREADS,
EventsWorkerStore,
)
from synapse.storage.types import Connection
from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results

from tests import unittest
Expand Down Expand Up @@ -144,3 +157,127 @@ def test_dedupe(self):

# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)


class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
self.store: EventsWorkerStore = hs.get_datastore()

self.room_id = f"!room:{hs.hostname}"
self.event_ids = [f"event{i}" for i in range(20)]

self._populate_events()

def _populate_events(self) -> None:
"""Ensure that there are test events in the database.

When testing with the in-memory SQLite database, all the events are lost during
the simulated outage.

To ensure consistency between `room_id`s and `event_id`s before and after the
outage, rows are built and inserted manually.

Upserts are used to handle the non-SQLite case where events are not lost.
"""
self.get_success(
self.store.db_pool.simple_upsert(
"rooms",
{"room_id": self.room_id},
{"room_version": RoomVersions.V4.identifier},
)
)

self.event_ids = [f"event{i}" for i in range(20)]
for idx, event_id in enumerate(self.event_ids):
self.get_success(
self.store.db_pool.simple_upsert(
"events",
{"event_id": event_id},
{
"event_id": event_id,
"room_id": self.room_id,
"topological_ordering": idx,
"stream_ordering": idx,
"type": "test",
"processed": True,
"outlier": False,
},
)
)
self.get_success(
self.store.db_pool.simple_upsert(
"event_json",
{"event_id": event_id},
{
"room_id": self.room_id,
"json": json.dumps({"type": "test", "room_id": self.room_id}),
"internal_metadata": "{}",
"format_version": EventFormatVersions.V3,
},
)
)

@contextmanager
def _outage(self) -> Generator[None, None, None]:
"""Simulate a database outage.

Returns:
A context manager. While the context is active, any attempts to connect to
the database will fail.
"""
connection_pool = self.store.db_pool._db_pool

# Close all connections and shut down the database `ThreadPool`.
connection_pool.close()

# Restart the database `ThreadPool`.
connection_pool.start()

original_connection_factory = connection_pool.connectionFactory

def connection_factory(_pool: ConnectionPool) -> Connection:
raise Exception("Could not connect to the database.")

connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
try:
yield
finally:
connection_pool.connectionFactory = original_connection_factory

# If the in-memory SQLite database is being used, all the events are gone.
# Restore the test data.
self._populate_events()

def test_failure(self) -> None:
"""Test that event fetches do not get stuck during a database outage."""
with self._outage():
failure = self.get_failure(
self.store.get_event(self.event_ids[0]), Exception
)
self.assertEqual(str(failure.value), "Could not connect to the database.")

def test_recovery(self) -> None:
"""Test that event fetchers recover after a database outage."""
with self._outage():
# Kick off a bunch of event fetches but do not pump the reactor
event_deferreds = []
for event_id in self.event_ids:
event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))

# We should have maxed out on event fetcher threads
self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)

# All the event fetchers will fail
self.pump()
self.assertEqual(self.store._event_fetch_ongoing, 0)

for event_deferred in event_deferreds:
failure = self.get_failure(event_deferred, Exception)
self.assertEqual(
str(failure.value), "Could not connect to the database."
)

# This next event fetch should succeed
self.get_success(self.store.get_event(self.event_ids[0]))