Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Optionally include account validity in MSC3720 account status responses #12266

Merged
merged 3 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12266.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses.
4 changes: 4 additions & 0 deletions synapse/config/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,10 @@ def read_config(self, config, **kwargs):
):
raise ConfigError("'custom_template_directory' must be a string")

self.use_account_validity_in_account_status: bool = (
config.get("use_account_validity_in_account_status") or False
)

babolivier marked this conversation as resolved.
Show resolved Hide resolved
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)

Expand Down
11 changes: 11 additions & 0 deletions synapse/handlers/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
self._use_account_validity_in_account_status = (
hs.config.server.use_account_validity_in_account_status
)
self._account_validity_handler = hs.get_account_validity_handler()

async def get_account_statuses(
self,
Expand Down Expand Up @@ -106,6 +110,13 @@ async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
"deactivated": userinfo.is_deactivated,
}

if self._use_account_validity_in_account_status:
status[
"org.matrix.expired"
] = await self._account_validity_handler.is_user_expired(
user_id.to_string()
)

return status

async def _get_remote_account_statuses(
Expand Down
58 changes: 57 additions & 1 deletion tests/rest/client/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
Expand Down Expand Up @@ -1222,6 +1222,62 @@ async def post_json(
expected_failures=[users[2]],
)

@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_no_account_validity(self) -> None:
"""Tests that if we decide to include account validity in the response but no
account validity 'is_user_expired' callback is provided, we default to marking all
users as not expired.
"""
user = self.register_user("someuser", "password")

self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": False,
},
},
expected_failures=[],
)

@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_account_validity_expired(self) -> None:
"""Test that if we decide to include account validity in the response and the user
is expired, we return the correct info.
"""
user = self.register_user("someuser", "password")

async def is_expired(user_id: str) -> bool:
# We can't blindly say everyone is expired, otherwise the request to get the
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"

self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
is_expired
)

self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": True,
},
},
expected_failures=[],
)

def _test_status(
self,
users: Optional[List[str]],
Expand Down