Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Port handlers.account_validity to async/await. #6504

Merged
merged 3 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6504.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Port handlers.account_data and handlers.account_validity to async/await.
14 changes: 5 additions & 9 deletions synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer


class AccountDataEventSource(object):
def __init__(self, hs):
Expand All @@ -23,15 +21,14 @@ def __init__(self, hs):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()

@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key

current_stream_id = yield self.store.get_max_account_data_stream_id()
current_stream_id = self.store.get_max_account_data_stream_id()

results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
tags = await self.store.get_updated_tags(user_id, last_stream_id)

for room_id, room_tags in tags.items():
results.append(
Expand All @@ -41,7 +38,7 @@ def get_new_events(self, user, from_key, **kwargs):
(
account_data,
room_account_data,
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)

for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
Expand All @@ -54,6 +51,5 @@ def get_new_events(self, user, from_key, **kwargs):

return results, current_stream_id

@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
async def get_pagination_rows(self, user, config, key):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this even used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You raise a good point, seems like only /initialSync uses it for presence and receipts. Given it's kind of part of the *EventSource API I'll leave it to another PR to remove it from everywhere.

return [], config.to_id
86 changes: 40 additions & 46 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

from twisted.internet import defer
from typing import List

from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
Expand Down Expand Up @@ -78,42 +77,39 @@ def send_emails():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self.send_renewal_emails
"send_renewals", self._send_renewal_emails
)

self.clock.looping_call(send_emails, 30 * 60 * 1000)

@defer.inlineCallbacks
def send_renewal_emails(self):
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
expiring_users = yield self.store.get_users_expiring_soon()
expiring_users = await self.store.get_users_expiring_soon()

if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)

@defer.inlineCallbacks
def send_renewal_email_to_user(self, user_id):
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
yield self._send_renewal_email(user_id, expiration_ts)
async def send_renewal_email_to_user(self, user_id: str):
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
await self._send_renewal_email(user_id, expiration_ts)

@defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts):
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.

Args:
user_id (str): ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of
user_id: ID of the user to send email(s) to.
expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
addresses = yield self._get_email_addresses_for_user(user_id)
addresses = await self._get_email_addresses_for_user(user_id)

# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
Expand All @@ -125,15 +121,15 @@ def _send_renewal_email(self, user_id, expiration_ts):
return

try:
user_display_name = yield self.store.get_profile_displayname(
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id

renewal_token = yield self._get_renewal_token(user_id)
renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
Expand Down Expand Up @@ -165,7 +161,7 @@ def _send_renewal_email(self, user_id, expiration_ts):

logger.info("Sending renewal email to %s", address)

yield make_deferred_yieldable(
await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
Expand All @@ -180,19 +176,18 @@ def _send_renewal_email(self, user_id, expiration_ts):
)
)

yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)

@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.

Args:
user_id (str): ID of the user to lookup email addresses for.
user_id: ID of the user to lookup email addresses for.

Returns:
defer.Deferred[list[str]]: Email addresses for this account.
Email addresses for this account.
"""
threepids = yield self.store.user_get_threepids(user_id)
threepids = await self.store.user_get_threepids(user_id)

addresses = []
for threepid in threepids:
Expand All @@ -201,16 +196,15 @@ def _get_email_addresses_for_user(self, user_id):

return addresses

@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.

Args:
user_id (str): ID of the user to generate a string for.
user_id: ID of the user to generate a string for.

Returns:
defer.Deferred[str]: The generated string.
The generated string.

Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
Expand All @@ -219,52 +213,52 @@ def _get_renewal_token(self, user_id):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")

@defer.inlineCallbacks
def renew_account(self, renewal_token):
async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.

Args:
renewal_token (str): Token sent with the renewal request.
renewal_token: Token sent with the renewal request.
Returns:
bool: Whether the provided token is valid.
Whether the provided token is valid.
"""
try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
defer.returnValue(False)
return False

logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
await self.renew_account_for_user(user_id)

defer.returnValue(True)
return True

@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
async def renew_account_for_user(
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.

Args:
renewal_token (str): Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period.
renewal_token: Token sent with the renewal request.
expiration_ts: New expiration date. Defaults to now + validity period.
email_sen: Whether an email has been sent for this validity period.
Defaults to False.

Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp
in milliseconds since epoch.
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period

yield self.store.set_account_validity_for_user(
await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)

Expand Down
3 changes: 1 addition & 2 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,8 @@ def make_homeserver(self, reactor, clock):
# Email config.
self.email_attempts = []

def sendmail(*args, **kwargs):
async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
return

config["email"] = {
"enable_notifs": True,
Expand Down