Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add tests for database transaction callbacks #12198

Merged
merged 4 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12198.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tests for database transaction callbacks.
113 changes: 112 additions & 1 deletion tests/storage/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.storage.database import make_tuple_comparison_clause
from typing import Callable, NoReturn, Tuple
from unittest.mock import Mock, call

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.util import Clock

from tests import unittest

Expand All @@ -22,3 +33,103 @@ def test_native_tuple_comparison(self):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])


class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool

def _run_interaction(
self, func: Callable[[LoggingTransaction], object]
) -> Tuple[Mock, Mock]:
"""Run the given function in a database transaction, with callbacks registered.

Args:
func: The function to be run in a transaction. The transaction will be
retried if `func` raises an `OperationalError`.

Returns:
Two mocks, which were registered as an `after_callback` and an
`exception_callback` respectively, on every transaction attempt.
"""
after_callback = Mock()
exception_callback = Mock()

def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
func(txn)

try:
self.get_success_or_raise(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
except Exception:
pass

return after_callback, exception_callback

def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
after_callback, exception_callback = self._run_interaction(lambda txn: None)

after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()

def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
after_callback, exception_callback = self._run_interaction(lambda txn: 1 / 0)

after_callback.assert_not_called()
exception_callback.assert_called_once_with(987, 654, extra=321)

def test_failed_retry(self) -> None:
"""Test that the exception callback is called for every failed attempt."""

def _test_txn(txn: LoggingTransaction) -> NoReturn:
"""Simulate a retryable failure on every attempt."""
raise self.db_pool.engine.module.OperationalError()

after_callback, exception_callback = self._run_interaction(_test_txn)

after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL that there can be additional calls beyond what you give...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this surprising too. I was hoping there would be an assert-these-calls-and-only-these-calls method.


def test_successful_retry(self) -> None:
"""Test callbacks for a failed transaction followed by a successful attempt."""
first_attempt = True

def _test_txn(txn: LoggingTransaction) -> None:
"""Simulate a retryable failure on the first attempt only."""
nonlocal first_attempt
if first_attempt:
first_attempt = False
raise self.db_pool.engine.module.OperationalError()
else:
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably be simplified with Mock(side_effect=[self.db_pool.engine.module.OperationalError, None])

Similarly for test_failed_retry it could just be Mock(side_effect=self.db_pool.engine.module.OperationalError). Fine either way though IMO!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know that was possible. That's a lot nicer, thank you!


after_callback, exception_callback = self._run_interaction(_test_txn)

# Calling both `after_callback`s when the first attempt failed is rather
# dubious (#12184). But let's document the behaviour in a test.
after_callback.assert_has_calls(
[
call(123, 456, extra=789),
call(123, 456, extra=789),
]
)
self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called()