Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert some util functions to async (#8035)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Aug 6, 2020
1 parent d4a7829 commit fe6cfc8
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 61 deletions.
1 change: 1 addition & 0 deletions changelog.d/8035.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
39 changes: 21 additions & 18 deletions synapse/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
from functools import wraps

from prometheus_client import Counter

from twisted.internet import defer

from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge

Expand Down Expand Up @@ -62,25 +59,31 @@


def measure_func(name=None):
def wrapper(func):
block_name = func.__name__ if name is None else name
"""
Used to decorate an async function with a `Measure` context manager.
Usage:
if inspect.iscoroutinefunction(func):
@measure_func()
async def foo(...):
...
@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
Which is analogous to:
else:
async def foo(...):
with Measure(...):
...
"""

def wrapper(func):
block_name = func.__name__ if name is None else name

@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r

return measured_func

Expand Down
16 changes: 6 additions & 10 deletions synapse/util/retryutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import logging
import random

from twisted.internet import defer

import synapse.logging.context
from synapse.api.errors import CodeMessageException

Expand Down Expand Up @@ -54,8 +52,7 @@ def __init__(self, retry_last_ts, retry_interval, destination):
self.destination = destination


@defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
Expand All @@ -73,17 +70,17 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
Example usage:
try:
limiter = yield get_retry_limiter(destination, clock, store)
limiter = await get_retry_limiter(destination, clock, store)
with limiter:
response = yield do_request()
response = await do_request()
except NotRetryingDestination:
# We aren't ready to retry that destination.
raise
"""
failure_ts = None
retry_last_ts, retry_interval = (0, 0)

retry_timings = yield store.get_destination_retry_timings(destination)
retry_timings = await store.get_destination_retry_timings(destination)

if retry_timings:
failure_ts = retry_timings["failure_ts"]
Expand Down Expand Up @@ -222,10 +219,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if self.failure_ts is None:
self.failure_ts = retry_last_ts

@defer.inlineCallbacks
def store_retry_timings():
async def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
await self.store.set_destination_retry_timings(
self.destination,
self.failure_ts,
retry_last_ts,
Expand Down
44 changes: 11 additions & 33 deletions tests/util/test_retryutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

# advance the clock a bit before making the request
self.pump(1)

with limiter:
pass

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)

def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()

d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
try:
Expand All @@ -58,29 +52,22 @@ def test_limiter(self):
except AssertionError:
pass

# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)

# now if we try again we should get a failure
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
self.failureResultOf(d, NotRetryingDestination)
self.get_failure(
get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
)

#
# advance the clock and try again
#

self.pump(MIN_RETRY_INTERVAL)
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
try:
Expand All @@ -91,12 +78,7 @@ def test_limiter(self):
except AssertionError:
pass

# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
Expand All @@ -110,9 +92,7 @@ def test_limiter(self):
# one more go, with success
#
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
d = get_retry_limiter("test_dest", self.clock, store)
self.pump()
limiter = self.successResultOf(d)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))

self.pump(1)
with limiter:
Expand All @@ -121,7 +101,5 @@ def test_limiter(self):
# wait for the update to land
self.pump()

d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)

0 comments on commit fe6cfc8

Please sign in to comment.