diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b6cc2..d868db410876 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,58 @@ def test_cancellation(self): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) + + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_room_got_partial_state"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_room_got_partial_state.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_room_got_partial_state.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_room_got_partial_state(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_room_got_partial_state.side_effect = ( + is_room_got_partial_state + ) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_room_got_partial_state.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2)