Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

ObservableDeferred: run observers in order #11229

Merged
merged 3 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11229.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`ObservableDeferred`: run registered observers in order.
34 changes: 18 additions & 16 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Generic,
Hashable,
Iterable,
List,
Optional,
Set,
TypeVar,
Expand Down Expand Up @@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
object.__setattr__(self, "_observers", list())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this dance? Do we override __setattr__ below? Oh yes, we do. I guess this is meant to be a fairly transparent wrapper around a Deferred.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for some reason someone decided that making ObservableDeferred a transparent wrapper for a Deferred was a good idea. I hate it.


def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
observer = self._observers.pop()

# once we have set _result, no more entries will be added to _observers,
# so it's safe to replace it with the empty tuple.
observers = self._observers
object.__setattr__(self, "_observers", tuple())

for observer in observers:
try:
observer.callback(r)
except Exception as e:
Expand All @@ -95,12 +100,16 @@ def callback(r):

def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:

# once we have set _result, no more entries will be added to _observers,
# so it's safe to replace it with the empty tuple.
observers = self._observers
object.__setattr__(self, "_observers", tuple())

for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f

observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
Expand All @@ -127,20 +136,13 @@ def observe(self) -> "defer.Deferred[_T]":
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()

def remove(r):
self._observers.discard(d)
return r

d.addBoth(remove)
Comment on lines -130 to -135
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code removed the observer from self._observers when it was run. I think it was redundant before, because the observer was pop()ed from self._observers anyway, but it's doubly-redundant now, since the whole of self._observers is thrown away.

(it was added in #190 - no real clues there as to why.)


self._observers.add(d)
self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)

def observers(self) -> "List[defer.Deferred[_T]]":
def observers(self) -> "Collection[defer.Deferred[_T]]":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit surprised that we're leaking internal state here!

Could only see one use though. Probably fine as it is.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup. Thought about changing it to a method that just returns the number of observers. Ran out of enthusiasm.

return self._observers

def has_called(self) -> bool:
Expand Down
4 changes: 1 addition & 3 deletions tests/util/caches/test_deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def check1(r):
self.assertTrue(set_d.called)
return r

# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
# maybe we should fix that?
# get_d.addCallback(check1)
Comment on lines -50 to -52
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment isn't terribly clear, but I think this is what we're fixing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_d.addCallback(check1)

# now fire off all the deferreds
origin_d.callback(99)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,78 @@
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred

from tests.unittest import TestCase


class ObservableDeferredTest(TestCase):
def test_succeed(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d)

observer1 = observable.observe()
observer2 = observable.observe()

self.assertFalse(observer1.called)
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]

def check_val(res, idx):
results[idx] = res
return res

observer1.addCallback(check_val, 0)
observer2.addCallback(check_val, 1)

origin_d.callback(123)
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")

def test_failure(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)

observer1 = observable.observe()
observer2 = observable.observe()

self.assertFalse(observer1.called)
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]

def check_val(res, idx):
results[idx] = res
return None

observer1.addErrback(check_val, 0)
observer2.addErrback(check_val, 1)

try:
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")


class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
Expand Down