Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #6504 from matrix-org/erikj/account_validity_async…
Browse files Browse the repository at this point in the history
…_await

Port handlers.account_validity to async/await.
  • Loading branch information
erikjohnston committed Dec 11, 2019
2 parents 7c429f9 + 3f97b4c commit 31905a5
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 57 deletions.
1 change: 1 addition & 0 deletions changelog.d/6504.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Port handlers.account_data and handlers.account_validity to async/await.
14 changes: 5 additions & 9 deletions synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer


class AccountDataEventSource(object):
def __init__(self, hs):
Expand All @@ -23,15 +21,14 @@ def __init__(self, hs):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()

@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key

current_stream_id = yield self.store.get_max_account_data_stream_id()
current_stream_id = self.store.get_max_account_data_stream_id()

results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
tags = await self.store.get_updated_tags(user_id, last_stream_id)

for room_id, room_tags in tags.items():
results.append(
Expand All @@ -41,7 +38,7 @@ def get_new_events(self, user, from_key, **kwargs):
(
account_data,
room_account_data,
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)

for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
Expand All @@ -54,6 +51,5 @@ def get_new_events(self, user, from_key, **kwargs):

return results, current_stream_id

@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
async def get_pagination_rows(self, user, config, key):
return [], config.to_id
86 changes: 40 additions & 46 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

from twisted.internet import defer
from typing import List

from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
Expand Down Expand Up @@ -78,42 +77,39 @@ def send_emails():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self.send_renewal_emails
"send_renewals", self._send_renewal_emails
)

self.clock.looping_call(send_emails, 30 * 60 * 1000)

@defer.inlineCallbacks
def send_renewal_emails(self):
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
expiring_users = yield self.store.get_users_expiring_soon()
expiring_users = await self.store.get_users_expiring_soon()

if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)

@defer.inlineCallbacks
def send_renewal_email_to_user(self, user_id):
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
yield self._send_renewal_email(user_id, expiration_ts)
async def send_renewal_email_to_user(self, user_id: str):
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
await self._send_renewal_email(user_id, expiration_ts)

@defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts):
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
user_id (str): ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of
user_id: ID of the user to send email(s) to.
expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
addresses = yield self._get_email_addresses_for_user(user_id)
addresses = await self._get_email_addresses_for_user(user_id)

# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
Expand All @@ -125,15 +121,15 @@ def _send_renewal_email(self, user_id, expiration_ts):
return

try:
user_display_name = yield self.store.get_profile_displayname(
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id

renewal_token = yield self._get_renewal_token(user_id)
renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
Expand Down Expand Up @@ -165,7 +161,7 @@ def _send_renewal_email(self, user_id, expiration_ts):

logger.info("Sending renewal email to %s", address)

yield make_deferred_yieldable(
await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
Expand All @@ -180,19 +176,18 @@ def _send_renewal_email(self, user_id, expiration_ts):
)
)

yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)

@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
user_id (str): ID of the user to lookup email addresses for.
user_id: ID of the user to lookup email addresses for.
Returns:
defer.Deferred[list[str]]: Email addresses for this account.
Email addresses for this account.
"""
threepids = yield self.store.user_get_threepids(user_id)
threepids = await self.store.user_get_threepids(user_id)

addresses = []
for threepid in threepids:
Expand All @@ -201,16 +196,15 @@ def _get_email_addresses_for_user(self, user_id):

return addresses

@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
user_id (str): ID of the user to generate a string for.
user_id: ID of the user to generate a string for.
Returns:
defer.Deferred[str]: The generated string.
The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
Expand All @@ -219,52 +213,52 @@ def _get_renewal_token(self, user_id):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")

@defer.inlineCallbacks
def renew_account(self, renewal_token):
async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
renewal_token (str): Token sent with the renewal request.
renewal_token: Token sent with the renewal request.
Returns:
bool: Whether the provided token is valid.
Whether the provided token is valid.
"""
try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
defer.returnValue(False)
return False

logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
await self.renew_account_for_user(user_id)

defer.returnValue(True)
return True

@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
async def renew_account_for_user(
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
renewal_token (str): Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period.
renewal_token: Token sent with the renewal request.
expiration_ts: New expiration date. Defaults to now + validity period.
email_sen: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp
in milliseconds since epoch.
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period

yield self.store.set_account_validity_for_user(
await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)

Expand Down
3 changes: 1 addition & 2 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,8 @@ def make_homeserver(self, reactor, clock):
# Email config.
self.email_attempts = []

def sendmail(*args, **kwargs):
async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
return

config["email"] = {
"enable_notifs": True,
Expand Down

0 comments on commit 31905a5

Please sign in to comment.