diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh index 54ec3c8b0dce..b2859f752261 100755 --- a/.ci/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive set -ex apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev +apt-get install -y \ + python3 python3-dev python3-pip python3-venv \ + libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev export LANG="C.UTF-8" diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000000..acb118c86e84 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +# TODO: incorporate this into pyproject.toml if flake8 supports it in the future. +# See https://github.com/PyCQA/flake8/issues/234 +[flake8] +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) +# E731: do not assign a lambda expression, use a def +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eb294f161939..eee3633d5043 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -7,7 +7,7 @@ on: # of things breaking (but only build one set of debs) pull_request: push: - branches: ["develop"] + branches: ["develop", "release-*"] # we do the full build on tags. tags: ["v*"] @@ -91,17 +91,7 @@ jobs: build-sdist: name: "Build pypi distribution files" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - run: pip install wheel - - run: | - python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: python-dist - path: dist/* + uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1" # if it's a tag, create a release and attach the artifacts to it attach-assets: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75ac1304bfbc..e9e427732239 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,12 +10,19 @@ concurrency: cancel-in-progress: true jobs: + check-sampleconfig: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - run: pip install -e . + - run: scripts-dev/generate_sample_config --check + lint: runs-on: ubuntu-latest strategy: matrix: toxenv: - - "check-sampleconfig" - "check_codestyle" - "check_isort" - "mypy" @@ -43,29 +50,15 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - uses: actions/setup-python@v2 - - run: pip install tox + - run: "pip install 'towncrier>=18.6.0rc1'" - run: scripts-dev/check-newsfragment env: PULL_REQUEST_NUMBER: ${{ github.event.number }} - lint-sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: "3.x" - - run: pip install wheel - - run: python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: Python Distributions - path: dist/* - # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile, lint-sdist] + needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig] runs-on: ubuntu-latest steps: - run: "true" @@ -397,7 +390,6 @@ jobs: - lint - lint-crlf - lint-newsfile - - lint-sdist - trial - trial-olddeps - sytest diff --git a/CHANGES.md b/CHANGES.md index 8d91a7921138..ef671e73f178 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,124 @@ +Synapse 1.54.0 (2022-03-08) +=========================== + +Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. +Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.54.0rc1 preventing the new module callbacks introduced in this release from being registered by modules. ([\#12141](https://github.com/matrix-org/synapse/issues/12141)) +- Fix a bug introduced in Synapse 1.54.0rc1 where runtime dependency version checks would mistakenly check development dependencies if they were present and would not accept pre-release versions of dependencies. ([\#12129](https://github.com/matrix-org/synapse/issues/12129), [\#12177](https://github.com/matrix-org/synapse/issues/12177)) + + +Internal Changes +---------------- + +- Update release script to insert the previous version when writing "No significant changes" line in the changelog. ([\#12127](https://github.com/matrix-org/synapse/issues/12127)) +- Relax the version guard for "packaging" added in [\#12088](https://github.com/matrix-org/synapse/issues/12088). ([\#12166](https://github.com/matrix-org/synapse/issues/12166)) + + +Synapse 1.54.0rc1 (2022-03-02) +============================== + + +Features +-------- + +- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Improve the generated URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) +- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) +- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067)) +- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009)) +- Advertise Matrix 1.1 and 1.2 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020), ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) +- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021)) +- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058)) +- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) +- Fix long-standing bug where the `get_rooms_for_user` cache was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) +- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) +- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) +- Fix occasional `Unhandled error in Deferred` error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) +- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) +- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100)) +- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105)) +- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. This reduces server load and load on the receiving device. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) + + +Updates to the Docker image +--------------------------- + +- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) + + +Improved Documentation +---------------------- + +- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599)) +- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003)) +- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004)) + + +Deprecations and Removals +------------------------- + +- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865)) +- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008)) +- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018)) +- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073)) + + +Internal Changes +---------------- + +- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) +- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) +- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) +- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) +- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) +- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) +- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) +- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) +- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) +- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) +- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) +- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) +- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) +- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) +- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) +- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) +- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) +- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) +- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) +- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) +- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) +- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) +- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) +- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111)) +- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119)) + + Synapse 1.53.0 (2022-02-22) =========================== -No significant changes. +No significant changes since 1.53.0rc1. Synapse 1.53.0rc1 (2022-02-15) @@ -11,7 +128,7 @@ Features -------- - Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). ([\#11215](https://github.com/matrix-org/synapse/issues/11215), [\#11966](https://github.com/matrix-org/synapse/issues/11966)) -- Remove account data (including client config, push rules and ignored users) upon user deactivation. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) +- Add a background database update to purge account data for deactivated users. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) - Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results. ([\#11837](https://github.com/matrix-org/synapse/issues/11837)) - Enable cache time-based expiry by default. The `expiry_time` config flag has been superseded by `expire_caches` and `cache_entry_ttl`. ([\#11849](https://github.com/matrix-org/synapse/issues/11849)) - Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account. ([\#11854](https://github.com/matrix-org/synapse/issues/11854)) @@ -92,7 +209,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. @@ -273,7 +390,7 @@ Bugfixes Synapse 1.50.0 (2022-01-18) =========================== -**This release contains a critical bug that may prevent clients from being able to connect. +**This release contains a critical bug that may prevent clients from being able to connect. As such, it is not recommended to upgrade to 1.50.0. Instead, please upgrade straight to to 1.50.1. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).** diff --git a/MANIFEST.in b/MANIFEST.in index c24786c3b371..76d14eb642b2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -45,6 +45,7 @@ include book.toml include pyproject.toml recursive-include changelog.d * +include .flake8 prune .circleci prune .github prune .ci diff --git a/debian/changelog b/debian/changelog index 574930c085cd..02136a0d606f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,15 @@ +matrix-synapse-py3 (1.54.0) stable; urgency=medium + + * New synapse release 1.54.0. + + -- Synapse Packaging team Tue, 08 Mar 2022 10:54:52 +0000 + +matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium + + * New synapse release 1.54.0~rc1. + + -- Synapse Packaging team Wed, 02 Mar 2022 10:43:22 +0000 + matrix-synapse-py3 (1.53.0) stable; urgency=medium * New synapse release 1.53.0. diff --git a/docker/Dockerfile b/docker/Dockerfile index 25c24296d908..3d9f42c9fe04 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,10 +11,10 @@ # There is an optional PYTHON_VERSION build argument which sets the # version of python to build against: for example: # -# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.10 . # -ARG PYTHON_VERSION=3.8 +ARG PYTHON_VERSION=3.9 ### ### Stage 0: builder @@ -89,8 +89,6 @@ COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py COPY ./docker/conf /conf -VOLUME ["/data"] - EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 1bbe23708055..4076fcab65f1 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -126,7 +126,8 @@ Body parameters: [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) section `sso` and `oidc_providers`. - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` - in homeserver configuration. + in the homeserver configuration. Note that no error is raised if the provided + value is not in the homeserver configuration. - `external_id` - string, user ID in the external identity provider. - `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). diff --git a/docs/manhole.md b/docs/manhole.md index 715ed840f2dd..a82fad0f0fe0 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database: ```pycon >>> from twisted.internet import defer ->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) +>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 88b59bb09e61..ec810fd292e5 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -162,10 +162,38 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_displayname_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_displayname_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback returns `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + ## `is_3pid_allowed` _First introduced in Synapse v1.53.0_ @@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` - + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2eb9032f4136..2b672b78f9ae 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_ async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str] ``` -Called when receiving an event from a client or via federation. The module can return -either a `bool` to indicate whether the event must be rejected because of spam, or a `str` -to indicate the event must be rejected because of spam and to give a rejection reason to -forward to clients. +Called when receiving an event from a client or via federation. The callback must return +either: +- an error message string, to indicate the event must be rejected because of spam and + give a rejection reason to forward to clients; +- the boolean `True`, to indicate that the event is spammy, but not provide further details; or +- the booelan `False`, to indicate that the event is not considered spammy. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first @@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool ``` Called when a user is trying to join a room. The module must return a `bool` to indicate -whether the user can join the room. The user is represented by their Matrix user ID (e.g. +whether the user can join the room. Return `False` to prevent the user from joining the +room; otherwise return `True` to permit the joining. + +The user is represented by their Matrix user ID (e.g. `@alice:example.com`) and the room is represented by its Matrix ID (e.g. `!room:example.com`). The module is also given a boolean to indicate whether the user currently has a pending invite in the room. @@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (e.g. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent +the invitation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -80,7 +86,8 @@ async def user_may_send_3pid_invite( Called when processing an invitation using a third-party identifier (also called a 3PID, e.g. an email address or a phone number). The module must return a `bool` indicating -whether the inviter can invite the invitee to the given room. +whether the inviter can invite the invitee to the given room. Return `False` to prevent +the invitation; otherwise return `True` to permit it. The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the invitee is represented by its medium (e.g. "email") and its address @@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +Return `False` to prevent room creation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA Called when trying to associate an alias with an existing room. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed -to set the given alias. +to set the given alias. Return `False` to prevent the alias creation; otherwise return +`True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool Called when trying to publish a room to the homeserver's public rooms directory. The module must return a `bool` indicating whether the given user (represented by their -Matrix user ID) is allowed to publish the given room. +Matrix user ID) is allowed to publish the given room. Return `False` to prevent the +room from being published; otherwise return `True` to permit its publication. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool ``` Called when computing search results in the user directory. The module must return a -`bool` indicating whether the given user profile can appear in search results. The profile -is represented as a dictionary with the following keys: +`bool` indicating whether the given user should be excluded from user directory +searches. Return `True` to indicate that the user is spammy and exclude them from +search results; otherwise return `False`. + +The profile is represented as a dictionary with the following keys: * `user_id`: The Matrix ID for this user. * `display_name`: The user's display name. @@ -225,8 +238,9 @@ async def check_media_file_for_spam( ) -> bool ``` -Called when storing a local or remote file. The module must return a boolean indicating -whether the given file can be stored in the homeserver's media store. +Called when storing a local or remote file. The module must return a `bool` indicating +whether the given file should be excluded from the homeserver's media store. Return +`True` to prevent this file from being stored; otherwise return `False`. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index a3a17096a8f5..09ac838107b8 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,62 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `on_profile_update` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_profile_update( + user_id: str, + new_profile: "synapse.module_api.ProfileInfo", + by_admin: bool, + deactivation: bool, +) -> None: +``` + +Called after updating a local user's profile. The update can be triggered either by the +user themselves or a server admin. The update can also be triggered by a user being +deactivated (in which case their display name is set to an empty string (`""`) and the +avatar URL is set to `None`). The module is passed the Matrix ID of the user whose profile +has been updated, their new profile, as well as a `by_admin` boolean that is `True` if the +update was triggered by a server admin (and `False` otherwise), and a `deactivated` +boolean that is `True` if the update is a result of the user being deactivated. + +Note that the `by_admin` boolean is also `True` if the profile change happens as a result +of the user logging in through Single Sign-On, or if a server admin updates their own +profile. + +Per-room profile changes do not trigger this callback to be called. Synapse administrators +wishing this callback to be called on every profile change are encouraged to disable +per-room profiles globally using the `allow_per_room_profiles` configuration setting in +Synapse's configuration file. +This callback is not called when registering a user, even when setting it through the +[`get_displayname_for_registration`](https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html#get_displayname_for_registration) +module callback. + +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_user_deactivation_status_changed` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_user_deactivation_status_changed( + user_id: str, deactivated: bool, by_admin: bool +) -> None: +``` + +Called after deactivating a local user, or reactivating them through the admin API. The +deactivation can be triggered either by the user themselves or a server admin. The module +is passed the Matrix ID of the user whose status is changed, as well as a `deactivated` +boolean that is `True` if the user is being deactivated and `False` if they're being +reactivated, and a `by_admin` boolean that is `True` if the deactivation was triggered by +a server admin (and `False` otherwise). This latter `by_admin` boolean is always `True` +if the user is being reactivated, as this operation can only be performed through the +admin API. + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d2bb3d42080a..6f3623c88ab9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -163,7 +163,7 @@ presence: # For example, for room version 1, default_room_version should be set # to "1". # -#default_room_version: "6" +#default_room_version: "9" # The GC threshold parameters to pass to `gc.set_threshold`, if defined # diff --git a/docs/structured_logging.md b/docs/structured_logging.md index 14db85f58728..805c86765340 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999. ## Upgrading from legacy structured logging configuration -Versions of Synapse prior to v1.23.0 included a custom structured logging -configuration which is deprecated. It used a `structured: true` flag and -configured `drains` instead of ``handlers`` and `formatters`. - -Synapse currently automatically converts the old configuration to the new -configuration, but this will be removed in a future version of Synapse. The -following reference can be used to update your configuration. Based on the drain -`type`, we can pick a new handler: +Versions of Synapse prior to v1.54.0 automatically converted the legacy +structured logging configuration, which was deprecated in v1.23.0, to the standard +library logging configuration. + +The following reference can be used to update your configuration. Based on the +drain `type`, we can pick a new handler: 1. For a type of `console`, `console_json`, or `console_json_terse`: a handler with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` diff --git a/docs/upgrade.md b/docs/upgrade.md index 477d7d0e81c8..f9be3ac6bc15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.54.0 + +## Legacy structured logging configuration removal + +This release removes support for the `structured: true` logging configuration +which was deprecated in Synapse v1.23.0. If your logging configuration contains +`structured: true` then it should be modified based on the +[structured logging documentation](structured_logging.md). + # Upgrading to v1.53.0 ## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location` @@ -157,7 +166,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. diff --git a/docs/workers.md b/docs/workers.md index dadde4d7265d..b0f8599ef062 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -178,8 +178,11 @@ recommend the use of `systemd` where available: for information on setting up ### `synapse.app.generic_worker` -This worker can handle API requests matching the following regular -expressions: +This worker can handle API requests matching the following regular expressions. +These endpoints can be routed to any worker. If a worker is set up to handle a +stream then, for maximum efficiency, additional endpoints should be routed to that +worker: refer to the [stream writers](#stream-writers) section below for further +information. # Sync requests ^/_matrix/client/(v2_alpha|r0|v3)/sync$ @@ -209,7 +212,6 @@ expressions: ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query - ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request @@ -222,22 +224,25 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/ + ^/_matrix/client/(r0|v3|unstable)/joined_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + # Encryption requests + ^/_matrix/client/(r0|v3|unstable)/keys/query$ + ^/_matrix/client/(r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/keys/claim$ + ^/_matrix/client/(r0|v3|unstable)/room_keys/ + # Registration/login requests ^/_matrix/client/(api/v1|r0|v3|unstable)/login$ ^/_matrix/client/(r0|v3|unstable)/register$ @@ -251,6 +256,20 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ + # Device requests + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ + + # Account data requests + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data + + # Receipts requests + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers + + # Presence requests + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + Additionally, the following REST endpoints can be handled for GET requests: @@ -330,12 +349,10 @@ Additionally, there is *experimental* support for moving writing of specific streams (such as events) off of the main process to a particular worker. (This is only supported with Redis-based replication.) -Currently supported streams are `events` and `typing`. - To enable this, the worker must have a HTTP replication listener configured, -have a `worker_name` and be listed in the `instance_map` config. For example to -move event persistence off to a dedicated worker, the shared configuration would -include: +have a `worker_name` and be listed in the `instance_map` config. The same worker +can handle multiple streams. For example, to move event persistence off to a +dedicated worker, the shared configuration would include: ```yaml instance_map: @@ -347,6 +364,12 @@ stream_writers: events: event_persister1 ``` +Some of the streams have associated endpoints which, for maximum efficiency, should +be routed to the workers handling that stream. See below for the currently supported +streams and the endpoints associated with them: + +##### The `events` stream + The `events` stream also experimentally supports having multiple writers, where work is sharded between them by room ID. Note that you *must* restart all worker instances when adding or removing event persisters. An example `stream_writers` @@ -359,6 +382,43 @@ stream_writers: - event_persister2 ``` +##### The `typing` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `typing` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing + +##### The `to_device` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `to_device` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + +##### The `account_data` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `account_data` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + +##### The `receipts` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `receipts` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + +##### The `presence` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `presence` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + #### Background tasks There is also *experimental* support for moving background tasks to a separate diff --git a/mypy.ini b/mypy.ini index 63848d664c57..38ff78760931 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,14 +31,11 @@ exclude = (?x) |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ |tests/api/test_auth.py @@ -78,16 +75,12 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_account.py - |tests/rest/client/test_events.py |tests/rest/client/test_filter.py - |tests/rest/client/test_groups.py - |tests/rest/client/test_register.py |tests/rest/client/test_report_event.py |tests/rest/client/test_rooms.py |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py |tests/rest/client/test_typing.py - |tests/rest/client/utils.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py @@ -256,7 +249,7 @@ disallow_untyped_defs = True [mypy-tests.rest.admin.*] disallow_untyped_defs = True -[mypy-tests.rest.client.test_directory] +[mypy-tests.rest.client.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/pyproject.toml b/pyproject.toml index 963f149c6af6..c9cd0cf6ec10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,15 @@ exclude = ''' )/ ) ''' + +[tool.isort] +line_length = 88 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"] +default_section = "THIRDPARTY" +known_first_party = ["synapse"] +known_tests = ["tests"] +known_twisted = ["twisted", "OpenSSL"] +multi_line_output = 3 +include_trailing_comma = true +combine_as_imports = true + diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index c764011d6adb..493558ad651b 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -35,7 +35,7 @@ CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing y https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit -tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) +python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) echo echo "--------------------------" diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e08ffedaf33a..0aecb3daf158 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -5,7 +5,7 @@ # It makes a Synapse image which represents the current checkout, # builds a synapse-complement image on top, then runs tests with it. # -# By default the script will fetch the latest Complement master branch and +# By default the script will fetch the latest Complement main branch and # run tests with that. This can be overridden to use a custom Complement # checkout by setting the COMPLEMENT_DIR environment variable to the # filepath of a local Complement checkout or by setting the COMPLEMENT_REF @@ -32,7 +32,7 @@ cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then - COMPLEMENT_REF=${COMPLEMENT_REF:-master} + COMPLEMENT_REF=${COMPLEMENT_REF:-main} echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..." wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz tar -xzf ${COMPLEMENT_REF}.tar.gz diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 4e1f99fee4e1..046453e65ff8 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -17,6 +17,8 @@ """An interactive script for doing a release. See `cli()` below. """ +import glob +import os import re import subprocess import sys @@ -209,8 +211,8 @@ def prepare(): with open("synapse/__init__.py", "w") as f: f.write(parsed_synapse_ast.dumps()) - # Generate changelogs - run_until_successful("python3 -m towncrier", shell=True) + # Generate changelogs. + generate_and_write_changelog(current_version) # Generate debian changelogs if parsed_new_version.pre is not None: @@ -523,5 +525,29 @@ class VersionSection: return "\n".join(version_changelog) +def generate_and_write_changelog(current_version: version.Version): + # We do this by getting a draft so that we can edit it before writing to the + # changelog. + result = run_until_successful( + "python3 -m towncrier --draft", shell=True, capture_output=True + ) + new_changes = result.stdout.decode("utf-8") + new_changes = new_changes.replace( + "No significant changes.", f"No significant changes since {current_version}." + ) + + # Prepend changes to changelog + with open("CHANGES.md", "r+") as f: + existing_content = f.read() + f.seek(0, 0) + f.write(new_changes) + f.write("\n") + f.write(existing_content) + + # Remove all the news fragments + for f in glob.iglob("changelog.d/*.*"): + os.remove(f) + + if __name__ == "__main__": cli() diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 5c6453d77ffb..f43676afaaac 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -44,7 +44,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs): - store = hs.get_datastore() + store = hs.get_datastores().main async def run_background_updates(): await store.db_pool.updates.run_background_updates(sleep=False) diff --git a/setup.cfg b/setup.cfg index e5ceb7ed1933..6213f3265b59 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[trial] -test_suite = tests - [check-manifest] ignore = .git-blame-ignore-revs @@ -10,23 +7,3 @@ ignore = pylint.cfg tox.ini -[flake8] -# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes -# for error codes. The ones we ignore are: -# W503: line break before binary operator -# W504: line break after binary operator -# E203: whitespace before ':' (which is contrary to pep8?) -# E731: do not assign a lambda expression, use a def -# E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 - -[isort] -line_length = 88 -sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER -default_section=THIRDPARTY -known_first_party = synapse -known_tests=tests -known_twisted=twisted,OpenSSL -multi_line_output=3 -include_trailing_comma=true -combine_as_imports=true diff --git a/setup.py b/setup.py index d0511c767f98..26f4650348d5 100755 --- a/setup.py +++ b/setup.py @@ -103,8 +103,8 @@ def exec_file(path_segments): ] CONDITIONAL_REQUIREMENTS["mypy"] = [ - "mypy==0.910", - "mypy-zope==0.3.2", + "mypy==0.931", + "mypy-zope==0.3.5", "types-bleach>=4.1.0", "types-jsonschema>=3.2.0", "types-opentracing>=2.4.2", @@ -165,6 +165,7 @@ def exec_file(path_segments): "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 0eaef0049860..344d55cce118 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]): def __copy__(self: _SD) -> _SD: ... @classmethod @overload - def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + def fromkeys( + cls, seq: Iterable[_T_h], value: None = ... + ) -> SortedDict[_T_h, None]: ... @classmethod @overload def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... - def keys(self) -> SortedKeysView[_KT]: ... - def items(self) -> SortedItemsView[_KT, _VT]: ... - def values(self) -> SortedValuesView[_VT]: ... + # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so + # `Sorted{Keys,Items,Values}View` are no longer compatible with them. + # See https://github.com/python/typeshed/issues/6837 + def keys(self) -> SortedKeysView[_KT]: ... # type: ignore[override] + def items(self) -> SortedItemsView[_KT, _VT]: ... # type: ignore[override] + def values(self) -> SortedValuesView[_VT]: ... # type: ignore[override] @overload def pop(self, key: _KT) -> _VT: ... @overload diff --git a/synapse/__init__.py b/synapse/__init__.py index 903f2e815dd9..c6727024f026 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ except ImportError: pass -__version__ = "1.53.0" +__version__ = "1.54.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9bfb0a0a44cd..bd32cb4813a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -60,7 +60,7 @@ class Auth: def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 08fe160c98ec..22348d2d8630 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -28,7 +28,7 @@ class AuthBlocking: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._hs_disabled = hs.config.server.hs_disabled diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index d087c816db57..cb532d723828 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -22,6 +22,7 @@ Dict, Iterable, List, + Mapping, Optional, Set, TypeVar, @@ -150,7 +151,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): self._hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) @@ -294,7 +295,7 @@ def blocks_all_room_timeline(self) -> bool: class Filter: def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._hs = hs - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.filter_json = filter_json self.limit = filter_json.get("limit", 10) @@ -361,10 +362,10 @@ def _check(self, event: FilterEvent) -> bool: return self._check_fields(field_matchers) else: content = event.get("content") - # Content is assumed to be a dict below, so ensure it is. This should + # Content is assumed to be a mapping below, so ensure it is. This should # always be true for events, but account_data has been allowed to # have non-dict content. - if not isinstance(content, dict): + if not isinstance(content, Mapping): content = {} sender = event.get("sender", None) diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index ee51480a9e64..334c3d2c178d 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -15,13 +15,13 @@ import sys from typing import Container -from synapse import python_dependencies # noqa: E402 +from synapse.util import check_dependencies logger = logging.getLogger(__name__) try: - python_dependencies.check_requirements() -except python_dependencies.DependencyException as e: + check_dependencies.check_requirements() +except check_dependencies.DependencyException as e: sys.stderr.writelines( e.message # noqa: B306, DependencyException.message is a property ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 452c0c09d526..3e59805baa90 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -448,7 +448,7 @@ def run_sighup(*args: Any, **kwargs: Any) -> None: # It is now safe to start your Synapse. hs.start_listening() - hs.get_datastore().db_pool.start_profiling() + hs.get_datastores().main.db_pool.start_profiling() hs.get_pusherpool().start() # Log when we start the shut down process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aadc882bf898..1536a4272333 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -142,7 +142,7 @@ def __init__(self, hs: HomeServer): """ super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bfb30003c2d3..a6789a840ede 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -59,7 +59,6 @@ from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource @@ -70,6 +69,7 @@ from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.util.check_dependencies import check_requirements from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module @@ -372,7 +372,7 @@ async def start() -> None: await _base.start(hs) - hs.get_datastore().db_pool.updates.start_doing_background_updates() + hs.get_datastores().main.db_pool.updates.start_doing_background_updates() register_start(start) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 899dba5c3d76..40dbdace8ed1 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -82,7 +82,7 @@ async def phone_stats_home( # General statistics # - store = hs.get_datastore() + store = hs.get_datastores().main stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server.server_context @@ -170,18 +170,22 @@ def performance_stats_init() -> None: # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + clock.looping_call( + hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + ) # monthly active user limiting functionality - clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) - hs.get_datastore().reap_monthly_active_users() + clock.looping_call( + hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastores().main.reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} reserved_users: Sized = () - store = hs.get_datastore() + store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 2069cd8f8606..87f36531f454 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -31,6 +31,14 @@ logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -72,6 +80,7 @@ def __init__( rate_limited: bool = True, ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -84,6 +93,7 @@ def __init__( self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -344,12 +354,16 @@ def __init__( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -364,6 +378,8 @@ async def send(self, as_api: "ApplicationServiceApi") -> bool: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 73be7ff3d4a1..a0ea958af62a 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,11 @@ from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -26,7 +31,6 @@ from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -219,6 +223,8 @@ async def push_bulk( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -252,7 +258,7 @@ async def push_bulk( uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it - body: Dict[str, List[JsonDict]] = {"events": serialized_events} + body: JsonDict = {"events": serialized_events} if service.supports_ephemeral: body.update( { @@ -262,6 +268,16 @@ async def push_bulk( } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 1e93439e1eb8..75ad1aabdcb7 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -54,12 +54,19 @@ Callable, Collection, Dict, + Iterable, List, Optional, Set, + Tuple, ) -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background @@ -92,11 +99,11 @@ class ApplicationServiceScheduler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self) -> None: logger.info("Starting appservice scheduler") @@ -153,7 +160,9 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + def __init__( + self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer" + ): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} @@ -165,6 +174,10 @@ def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one @@ -202,15 +215,84 @@ async def _send_request(self, service: ApplicationService) -> None: if not events and not ephemeral and not to_device_messages_to_send: return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Compute the one-time key counts and fallback key usage states + # for the users which are mentioned in this transaction, + # as well as the appservice's sender. + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + try: await self.txn_ctrl.send( - service, events, ephemeral, to_device_messages_to_send + service, + events, + ephemeral, + to_device_messages_to_send, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of the events, ephemeral messages and to-device messages, + - first computes a list of application services users that may have + interesting updates to the one-time key counts or fallback key usage. + - then computes one-time key counts and fallback key usages for those users. + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + + # Set of 'interesting' users who may have updates + users: Set[str] = set() + + # The sender is always included + users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + users.update(device_message["user_id"] for device_message in to_device_messages) + + # Compute and return the counts / fallback key usage states + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -238,6 +320,8 @@ async def send( events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -248,6 +332,10 @@ async def send( events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -255,6 +343,8 @@ async def send( events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7fad2e0422b8..439bfe15269c 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -166,6 +166,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -174,8 +184,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 387ac6d115e3..9a68da9c336a 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -20,7 +20,7 @@ import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 040ff57d8999..61efa01d7869 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -43,20 +43,12 @@ def read_config(self, config: JsonDict, **kwargs): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) - # MSC3283 (set displayname, avatar_url and change 3pid capabilities) - self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) - # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # The portion of MSC3202 which is related to device masquerading. - self.msc3202_device_masquerading_enabled: bool = experimental.get( - "msc3202_device_masquerading", False - ) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. @@ -64,5 +56,23 @@ def read_config(self, config: JsonDict, **kwargs): "msc2409_to_device_messages_enabled", False ) + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) + + # experimental support for faster joins over federation (msc2775, msc3706) + # requires a target server with msc3706_enabled enabled. + self.faster_joins_enabled: bool = experimental.get("faster_joins", False) + + # MSC3720 (Account status endpoint) + self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b7145a44ae1d..cbbe22196528 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -33,7 +33,6 @@ globalLogBeginner, ) -from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter @@ -138,6 +137,12 @@ removed in Synapse 1.3.0. You should instead set up a separate log configuration file. """ +STRUCTURED_ERROR = """\ +Support for the structured configuration option was removed in Synapse 1.54.0. +You should instead use the standard logging configuration. See +https://matrix-org.github.io/synapse/v1.54/structured_logging.html +""" + class LoggingConfig(Config): section = "logging" @@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None: if not log_config: logging.warning("Loaded a blank logging config?") - # If the old structured logging configuration is being used, convert it to - # the new style configuration. + # If the old structured logging configuration is being used, raise an error. if "structured" in log_config and log_config.get("structured"): - log_config = setup_structured_logging(log_config) + raise ConfigError(STRUCTURED_ERROR) logging.config.dictConfig(log_config) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 1cc26e757812..f62292ecf6cb 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,7 +15,7 @@ import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index e783b1131501..f7e4f9ef22b1 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -20,11 +20,11 @@ from synapse.config._util import validate_config from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri +from ..util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider" diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 33104af73472..bdb1aac3a218 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.config._base import Config -from synapse.python_dependencies import check_requirements +from synapse.util.check_dependencies import check_requirements class RedisConfig(Config): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1980351e7711..0a0d901bfb65 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,8 +20,8 @@ import attr from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module from ._base import Config, ConfigError diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ec9d9f65e7b5..43c456d5c675 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -17,8 +17,8 @@ from typing import Any, List, Set from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc962454684..49cd0a4f192d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -146,7 +146,7 @@ def generate_ip_set( "fec0::/10", ] -DEFAULT_ROOM_VERSION = "6" +DEFAULT_ROOM_VERSION = "9" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. " diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 21b9a88353a7..7aff618ea660 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -14,7 +14,7 @@ from typing import Set -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72d4a69aac35..93d56c077a57 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config async def process_v2_response( diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eca00bc97506..621a3efcccec 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -374,9 +374,9 @@ def _is_membership_change_allowed( return # Require the user to be in the room for membership changes other than join/knock. - if Membership.JOIN != membership and ( - RoomVersion.msc2403_knocking and Membership.KNOCK != membership - ): + # Note that the room version check for knocking is done implicitly by `caller_knocked` + # and the ability to set a membership of `knock` in the first place. + if Membership.JOIN != membership and Membership.KNOCK != membership: # If the user has been invited or has knocked, they are allowed to change their # membership event to leave if ( diff --git a/synapse/events/builder.py b/synapse/events/builder.py index eb39e0ae3291..1ea1bb7d3790 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -189,7 +189,7 @@ def __init__(self, hs: "HomeServer"): self.hostname = hs.hostname self.signing_key = hs.signing_key - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5833fee25fce..46042b2bf7af 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -101,6 +101,9 @@ class EventContext: As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. + + partial_state: if True, we may be storing this event with a temporary, + incomplete state. """ rejected: Union[bool, str] = False @@ -113,12 +116,15 @@ class EventContext: _current_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None + partial_state: bool = False + @staticmethod def with_state( state_group: Optional[int], state_group_before_event: Optional[int], current_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]], + partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": @@ -129,6 +135,7 @@ def with_state( state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, + partial_state=partial_state, ) @staticmethod @@ -170,6 +177,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, + "partial_state": self.partial_state, } @staticmethod @@ -196,6 +204,7 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], + partial_state=input.get("partial_state", False), ) app_service_id = input["app_service_id"] diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1bb8ca7145fd..ede72ee87631 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -17,6 +17,7 @@ from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import maybe_awaitable @@ -37,6 +38,8 @@ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] +ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -143,7 +146,7 @@ class ThirdPartyEventRules: def __init__(self, hs: "HomeServer"): self.third_party_rules = None - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -154,6 +157,10 @@ def __init__(self, hs: "HomeServer"): CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] + self._on_user_deactivation_status_changed_callbacks: List[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -166,6 +173,10 @@ def register_third_party_rules_callbacks( CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_user_deactivation_status_changed: Optional[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -187,6 +198,14 @@ def register_third_party_rules_callbacks( if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if on_profile_update is not None: + self._on_profile_update_callbacks.append(on_profile_update) + + if on_user_deactivation_status_changed is not None: + self._on_user_deactivation_status_changed_callbacks.append( + on_user_deactivation_status_changed, + ) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -334,9 +353,6 @@ async def on_new_event(self, event_id: str) -> None: Args: event_id: The ID of the event. - - Raises: - ModuleFailureError if a callback raised any exception. """ # Bail out early without hitting the store if we don't have any callbacks if len(self._on_new_event_callbacks) == 0: @@ -370,3 +386,41 @@ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: state_events[key] = room_state_events[event_id] return state_events + + async def on_profile_update( + self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool + ) -> None: + """Called after the global profile of a user has been updated. Does not include + per-room profile changes. + + Args: + user_id: The user whose profile was changed. + new_profile: The updated profile for the user. + by_admin: Whether the profile update was performed by a server admin. + deactivation: Whether this change was made while deactivating the user. + """ + for callback in self._on_profile_update_callbacks: + try: + await callback(user_id, new_profile, by_admin, deactivation) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_user_deactivation_status_changed( + self, user_id: str, deactivated: bool, by_admin: bool + ) -> None: + """Called after a user has been deactivated or reactivated. + + Args: + user_id: The deactivated user. + deactivated: Whether the user is now deactivated. + by_admin: Whether the deactivation was performed by a server admin. + """ + for callback in self._on_user_deactivation_status_changed_callbacks: + try: + await callback(user_id, deactivated, by_admin) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 243696b35724..9386fa29ddd3 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -425,6 +425,33 @@ def serialize_event( return serialized_event + def _apply_edit( + self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase + ) -> None: + """Replace the content, preserving existing relations of the serialized event. + + Args: + orig_event: The original event. + serialized_event: The original event, serialized. This is modified. + edit: The event which edits the above. + """ + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = orig_event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + def _inject_bundled_aggregations( self, event: EventBase, @@ -450,26 +477,11 @@ def _inject_bundled_aggregations( serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references if aggregations.replace: - # If there is an edit replace the content, preserving existing - # relations. + # If there is an edit, apply it to the event. edit = aggregations.replace + self._apply_edit(event, serialized_event, edit) - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - + # Include information about it in the relations dict. serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, @@ -478,13 +490,22 @@ def _inject_bundled_aggregations( # If this event is the start of a thread, include a summary of the replies. if aggregations.thread: + thread = aggregations.thread + + # Don't bundle aggregations as this could recurse forever. + serialized_latest_event = self.serialize_event( + thread.latest_event, time_now, bundle_aggregations=None + ) + # Manually apply an edit, if one exists. + if thread.latest_edit: + self._apply_edit( + thread.latest_event, serialized_latest_event, thread.latest_edit + ) + serialized_aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": self.serialize_event( - aggregations.thread.latest_event, time_now, bundle_aggregations=None - ), - "count": aggregations.thread.count, - "current_user_participated": aggregations.thread.current_user_participated, + "latest_event": serialized_latest_event, + "count": thread.count, + "current_user_participated": thread.current_user_participated, } # Include the bundled aggregations in the event. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 896168c05c0a..41ac49fdc8bf 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._clock = hs.get_clock() async def _check_sigs_and_hash( @@ -47,6 +47,11 @@ async def _check_sigs_and_hash( ) -> EventBase: """Checks that event is correctly signed by the sending server. + Also checks the content hash, and redacts the event if there is a mismatch. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + Args: room_version: The room version of the PDU pdu: the event to be checked @@ -55,7 +60,10 @@ async def _check_sigs_and_hash( * the original event if the checks pass * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed.""" + + Raises: + SynapseError if the signature check failed. + """ try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) except SynapseError as e: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 74f17aa4daa3..64e595e748f4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1,4 +1,4 @@ -# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2015-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -56,7 +56,7 @@ from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -89,6 +89,12 @@ class SendJoinResult: state: List[EventBase] auth_chain: List[EventBase] + # True if 'state' elides non-critical membership events + partial_state: bool + + # if 'partial_state' is set, a list of the servers in the room (otherwise empty) + servers_in_room: List[str] + class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): @@ -413,26 +419,90 @@ async def get_room_state_ids( return state_event_ids, auth_event_ids + async def get_room_state( + self, + destination: str, + room_id: str, + event_id: str, + room_version: RoomVersion, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Calls the /state endpoint to fetch the state at a particular point + in the room. + + Any invalid events (those with incorrect or unverifiable signatures or hashes) + are filtered out from the response, and any duplicate events are removed. + + (Size limits and other event-format checks are *not* performed.) + + Note that the result is not ordered, so callers must be careful to process + the events in an order that handles dependencies. + + Returns: + a tuple of (state events, auth events) + """ + result = await self.transport_layer.get_room_state( + room_version, + destination, + room_id, + event_id, + ) + state_events = result.state + auth_events = result.auth_events + + # we may as well filter out any duplicates from the response, to save + # processing them multiple times. (In particular, events may be present in + # `auth_events` as well as `state`, which is redundant). + # + # We don't rely on the sort order of the events, so we can just stick them + # in a dict. + state_event_map = {event.event_id: event for event in state_events} + auth_event_map = { + event.event_id: event + for event in auth_events + if event.event_id not in state_event_map + } + + logger.info( + "Processing from /state: %d state events, %d auth events", + len(state_event_map), + len(auth_event_map), + ) + + valid_auth_events = await self._check_sigs_and_hash_and_fetch( + destination, auth_event_map.values(), room_version + ) + + valid_state_events = await self._check_sigs_and_hash_and_fetch( + destination, state_event_map.values(), room_version + ) + + return valid_state_events, valid_auth_events + async def _check_sigs_and_hash_and_fetch( self, origin: str, pdus: Collection[EventBase], room_version: RoomVersion, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashes of each - one. If a PDU fails its signature check then we check if we have it in - the database and if not then request if from the originating server of - that PDU. + """Checks the signatures and hashes of a list of events. + + If a PDU fails its signature check then we check if we have it in + the database, and if not then request it from the sender's server (if that + is different from `origin`). If that still fails, the event is omitted from + the returned list. If a PDU fails its content hash check then it is redacted. - The given list of PDUs are not modified, instead the function returns + Also runs each event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + + The given list of PDUs are not modified; instead the function returns a new list. Args: - origin - pdu - room_version + origin: The server that sent us these events + pdus: The events to be checked + room_version: the version of the room these events are in Returns: A list of PDUs that have valid signatures and hashes. @@ -463,11 +533,16 @@ async def _check_sigs_and_hash_and_fetch_one( origin: str, room_version: RoomVersion, ) -> Optional[EventBase]: - """Takes a PDU and checks its signatures and hashes. If the PDU fails - its signature check then we check if we have it in the database and if - not then request if from the originating server of that PDU. + """Takes a PDU and checks its signatures and hashes. + + If the PDU fails its signature check then we check if we have it in the + database; if not, we then request it from sender's server (if that is not the + same as `origin`). If that still fails, we return None. + + If the PDU fails its content hash check, it is redacted. - If then PDU fails its content hash check then it is redacted. + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. Args: origin @@ -540,11 +615,15 @@ def _is_unknown_endpoint( synapse_error = e.to_synapse_error() # There is no good way to detect an "unknown" endpoint. # - # Dendrite returns a 404 (with no body); synapse returns a 400 + # Dendrite returns a 404 (with a body of "404 page not found"); + # Conduit returns a 404 (with no body); and Synapse returns a 400 # with M_UNRECOGNISED. - return e.code == 404 or ( - e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED - ) + # + # This needs to be rather specific as some endpoints truly do return 404 + # errors. + return ( + e.code == 404 and (not e.response or e.response == b"404 page not found") + ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED) async def _try_destination_list( self, @@ -864,23 +943,32 @@ async def _execute(pdu: EventBase) -> None: for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - # double-check that the same create event has ended up in the auth chain + # double-check that the auth chain doesn't include a different create event auth_chain_create_events = [ e.event_id for e in signed_auth if (e.type, e.state_key) == (EventTypes.Create, "") ] - if auth_chain_create_events != [create_event.event_id]: + if auth_chain_create_events and auth_chain_create_events != [ + create_event.event_id + ]: raise InvalidResponseError( "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) ) + if response.partial_state and not response.servers_in_room: + raise InvalidResponseError( + "partial_state was set, but no servers were listed in the room" + ) + return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, + partial_state=response.partial_state, + servers_in_room=response.servers_in_room or [], ) # MSC3083 defines additional error codes for room joins. @@ -918,7 +1006,7 @@ async def _do_send_join( ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -987,7 +1075,7 @@ async def _do_send_invite( except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint if the room uses old-style event IDs. - # Otherwise consider it a legitmate error and raise. + # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() if self._is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.V1: @@ -1048,7 +1136,7 @@ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1274,61 +1362,6 @@ async def get_room_complexity( # server doesn't give it to us. return None - async def get_space_summary( - self, - destinations: Iterable[str], - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> "FederationSpaceSummaryResult": - """ - Call other servers to get a summary of the given space - - - Args: - destinations: The remote servers. We will try them in turn, omitting any - that have been blacklisted. - - room_id: ID of the space to be queried - - suggested_only: If true, ask the remote server to only return children - with the "suggested" flag set - - max_rooms_per_space: A limit on the number of children to return for each - space - - exclude_rooms: A list of room IDs to tell the remote server to skip - - Returns: - a parsed FederationSpaceSummaryResult - - Raises: - SynapseError if we were unable to get a valid summary from any of the - remote servers - """ - - async def send_request(destination: str) -> FederationSpaceSummaryResult: - res = await self.transport_layer.get_space_summary( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - exclude_rooms=exclude_rooms, - ) - - try: - return FederationSpaceSummaryResult.from_json_dict(res) - except ValueError as e: - raise InvalidResponseError(str(e)) - - return await self._try_destination_list( - "fetch space summary", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - async def get_room_hierarchy( self, destinations: Iterable[str], @@ -1374,8 +1407,8 @@ async def send_request( ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the unstable endpoint. Otherwise consider it a - # legitmate error and raise. + # fallback to the unstable endpoint. Otherwise, consider it a + # legitimate error and raise. if not self._is_unknown_endpoint(e): raise @@ -1400,10 +1433,8 @@ async def send_request( if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") try: - [ - FederationSpaceSummaryEventResult.from_json_dict(e) - for e in children_state - ] + for child_state in children_state: + _validate_hierarchy_event(child_state) except ValueError as e: raise InvalidResponseError(str(e)) @@ -1425,62 +1456,12 @@ async def send_request( return room, children_state, children, inaccessible_children - try: - result = await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - except SynapseError as e: - # If an unexpected error occurred, re-raise it. - if e.code != 502: - raise - - logger.debug( - "Couldn't fetch room hierarchy, falling back to the spaces API" - ) - - # Fallback to the old federation API and translate the results if - # no servers implement the new API. - # - # The algorithm below is a bit inefficient as it only attempts to - # parse information for the requested room, but the legacy API may - # return additional layers. - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) - - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - result = (requested_room, children_events, children, ()) + result = await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) # Cache the result to avoid fetching data over federation every time. self._get_room_hierarchy_cache[(room_id, suggested_only)] = result @@ -1526,6 +1507,64 @@ async def timestamp_to_event( except ValueError as e: raise InvalidResponseError(str(e)) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> Tuple[JsonDict, List[str]]: + """Retrieves account statuses for a given list of users on a given remote + homeserver. + + If the request fails for any reason, all user IDs for this destination are marked + as failed. + + Args: + destination: the destination to contact + user_ids: the user ID(s) for which to request account status(es) + + Returns: + The account statuses, as well as the list of user IDs for which it was not + possible to retrieve a status. + """ + try: + res = await self.transport_layer.get_account_status(destination, user_ids) + except Exception: + # If the query failed for any reason, mark all the users as failed. + return {}, user_ids + + statuses = res.get("account_statuses", {}) + failures = res.get("failures", []) + + if not isinstance(statuses, dict) or not isinstance(failures, list): + # Make sure we're not feeding back malformed data back to the caller. + logger.warning( + "Destination %s responded with malformed data to account_status query", + destination, + ) + return {}, user_ids + + for user_id in user_ids: + # Any account whose status is missing is a user we failed to receive the + # status of. + if user_id not in statuses and user_id not in failures: + failures.append(user_id) + + # Filter out any user ID that doesn't belong to the remote server that sent its + # status (or failure). + def filter_user_id(user_id: str) -> bool: + try: + return UserID.from_string(user_id).domain == destination + except SynapseError: + # If the user ID doesn't parse, ignore it. + return False + + filtered_statuses = dict( + # item is a (key, value) tuple, so item[0] is the user ID. + filter(lambda item: filter_user_id(item[0]), statuses.items()) + ) + + filtered_failures = list(filter(filter_user_id, failures)) + + return filtered_statuses, filtered_failures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: @@ -1564,89 +1603,34 @@ def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": return cls(event_id, origin_server_ts, d) -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryEventResult: - """Represents a single event in the result of a successful get_space_summary call. - - It's essentially just a serialised event object, but we do a bit of parsing and - validation in `from_json_dict` and store some of the validated properties in - object attributes. - """ - - event_type: str - room_id: str - state_key: str - via: Sequence[str] - - # the raw data, including the above keys - data: JsonDict - - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": - """Parse an event within the result of a /spaces/ request - - Args: - d: json object to be parsed - - Raises: - ValueError if d is not a valid event - """ - - event_type = d.get("type") - if not isinstance(event_type, str): - raise ValueError("Invalid event: 'event_type' must be a str") - - room_id = d.get("room_id") - if not isinstance(room_id, str): - raise ValueError("Invalid event: 'room_id' must be a str") - - state_key = d.get("state_key") - if not isinstance(state_key, str): - raise ValueError("Invalid event: 'state_key' must be a str") - - content = d.get("content") - if not isinstance(content, dict): - raise ValueError("Invalid event: 'content' must be a dict") +def _validate_hierarchy_event(d: JsonDict) -> None: + """Validate an event within the result of a /hierarchy request - via = content.get("via") - if not isinstance(via, Sequence): - raise ValueError("Invalid event: 'via' must be a list") - if any(not isinstance(v, str) for v in via): - raise ValueError("Invalid event: 'via' must be a list of strings") + Args: + d: json object to be parsed - return cls(event_type, room_id, state_key, via, d) + Raises: + ValueError if d is not a valid event + """ + event_type = d.get("type") + if not isinstance(event_type, str): + raise ValueError("Invalid event: 'event_type' must be a str") -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryResult: - """Represents the data returned by a successful get_space_summary call.""" + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") - rooms: List[JsonDict] - events: Sequence[FederationSpaceSummaryEventResult] + state_key = d.get("state_key") + if not isinstance(state_key, str): + raise ValueError("Invalid event: 'state_key' must be a str") - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": - """Parse the result of a /spaces/ request - - Args: - d: json object to be parsed - - Raises: - ValueError if d is not a valid /spaces/ response - """ - rooms = d.get("rooms") - if not isinstance(rooms, List): - raise ValueError("'rooms' must be a list") - if any(not isinstance(r, dict) for r in rooms): - raise ValueError("Invalid room in 'rooms' list") - - events = d.get("events") - if not isinstance(events, Sequence): - raise ValueError("'events' must be a list") - if any(not isinstance(e, dict) for e in events): - raise ValueError("Invalid event in 'events' list") - parsed_events = [ - FederationSpaceSummaryEventResult.from_json_dict(e) for e in events - ] + content = d.get("content") + if not isinstance(content, dict): + raise ValueError("Invalid event: 'content' must be a dict") - return cls(rooms, parsed_events) + via = content.get("via") + if not isinstance(via, Sequence): + raise ValueError("Invalid event: 'via' must be a list") + if any(not isinstance(v, str) for v in via): + raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 720d7bd74d07..6106a486d10a 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -228,7 +228,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8152e80b88d2..c8768f22bc6b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -76,7 +76,7 @@ def __init__( ): self._server_name = hs.hostname self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config @@ -381,7 +381,8 @@ async def _catch_up_transmission_loop(self) -> None: ) ) - if self._last_successful_stream_ordering is None: + _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering + if _tmp_last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -391,11 +392,12 @@ async def _catch_up_transmission_loop(self) -> None: self._catching_up = False return + last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering + # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, - self._last_successful_stream_ordering, + self._destination, last_successful_stream_ordering ) if not event_ids: @@ -403,7 +405,7 @@ async def _catch_up_transmission_loop(self) -> None: # of a race condition, so we check that no new events have been # skipped due to us being in catch-up mode - if self._catchup_last_skipped > self._last_successful_stream_ordering: + if self._catchup_last_skipped > last_successful_stream_ordering: # another event has been skipped because we were in catch-up mode continue @@ -470,7 +472,7 @@ async def _catch_up_transmission_loop(self) -> None: # offline if ( p.internal_metadata.stream_ordering - < self._last_successful_stream_ordering + < last_successful_stream_ordering ): continue @@ -513,12 +515,11 @@ async def _catch_up_transmission_loop(self) -> None: # from the *original* PDU, rather than the PDU(s) we actually # send. This is because we use it to mark our position in the # queue of missed PDUs to process. - self._last_successful_stream_ordering = ( - pdu.internal_metadata.stream_ordering - ) + last_successful_stream_ordering = pdu.internal_metadata.stream_ordering + self._last_successful_stream_ordering = last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering + self._destination, last_successful_stream_ordering ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 742ee572558d..0c1cad86ab65 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -53,7 +53,7 @@ class TransactionManager: def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8782586cd6b4..de6e5f44fe85 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,4 +1,4 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -60,17 +60,17 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname self.client = hs.get_federation_http_client() + self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: - """Requests all state for a given room from the given server at the - given event. Returns the state's event_id's + """Requests the IDs of all state for a given room at the given event. Args: destination: The host name of the remote homeserver we want to get the state from. - context: The name of the context we want the state of + room_id: the room we want the state of event_id: The event we want the context at. Returns: @@ -86,6 +86,29 @@ async def get_room_state_ids( try_trailing_slash_on_400=True, ) + async def get_room_state( + self, room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> "StateRequestResponse": + """Requests the full state for a given room at the given event. + + Args: + room_version: the version of the room (required to build the event objects) + destination: The host name of the remote homeserver we want + to get the state from. + room_id: the room we want the state of + event_id: The event we want the context at. + + Returns: + Results in a dict received from the remote homeserver. + """ + path = _create_v1_path("/state/%s", room_id) + return await self.client.get_json( + destination, + path=path, + args={"event_id": event_id}, + parser=_StateParser(room_version), + ) + async def get_event( self, destination: str, event_id: str, timeout: Optional[int] = None ) -> JsonDict: @@ -235,8 +258,9 @@ async def make_query( args: dict, retry_on_dns_fail: bool, ignore_backoff: bool = False, + prefix: str = FEDERATION_V1_PREFIX, ) -> JsonDict: - path = _create_v1_path("/query/%s", query_type) + path = _create_path(prefix, "/query/%s", query_type) return await self.client.get_json( destination=destination, @@ -336,10 +360,15 @@ async def send_join_v2( content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + query_params: Dict[str, str] = {} + if self._faster_joins_enabled: + # lazy-load state on join + query_params["org.matrix.msc3706.partial_state"] = "true" return await self.client.put_json( destination=destination, path=path, + args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, @@ -1150,39 +1179,6 @@ async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict: return await self.client.get_json(destination=destination, path=path) - async def get_space_summary( - self, - destination: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> JsonDict: - """ - Args: - destination: The remote server - room_id: The room ID to ask about. - suggested_only: if True, only suggested rooms will be returned - max_rooms_per_space: an optional limit to the number of children to be - returned per space - exclude_rooms: a list of any rooms we can skip - """ - # TODO When switching to the stable endpoint, use GET instead of POST. - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id - ) - - params = { - "suggested_only": suggested_only, - "exclude_rooms": exclude_rooms, - } - if max_rooms_per_space is not None: - params["max_rooms_per_space"] = max_rooms_per_space - - return await self.client.post_json( - destination=destination, path=path, data=params - ) - async def get_room_hierarchy( self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: @@ -1219,6 +1215,22 @@ async def get_room_hierarchy_unstable( args={"suggested_only": "true" if suggested_only else "false"}, ) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> JsonDict: + """ + Args: + destination: The remote server. + user_ids: The user ID(s) for which to request account status(es). + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status" + ) + + return await self.client.post_json( + destination=destination, path=path, data={"user_ids": user_ids} + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ @@ -1271,6 +1283,20 @@ class SendJoinResponse: # "event" is not included in the response. event: Optional[EventBase] = None + # The room state is incomplete + partial_state: bool = False + + # List of servers in the room + servers_in_room: Optional[List[str]] = None + + +@attr.s(slots=True, auto_attribs=True) +class StateRequestResponse: + """The parsed response of a `/state` request.""" + + auth_events: List[EventBase] + state: List[EventBase] + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: @@ -1297,6 +1323,32 @@ def _event_list_parser( events.append(event) +@ijson.coroutine +def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the partial_state field in send_join responses + """ + while True: + val = yield + if not isinstance(val, bool): + raise TypeError("partial_state must be a boolean") + response.partial_state = val + + +@ijson.coroutine +def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the servers_in_room field in send_join responses + """ + while True: + val = yield + if not isinstance(val, list) or any(not isinstance(x, str) for x in val): + raise TypeError("servers_in_room must be a list of strings") + response.servers_in_room = val + + class SendJoinParser(ByteParser[SendJoinResponse]): """A parser for the response to `/send_join` requests. @@ -1308,44 +1360,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], [], {}) + self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version + self._coros = [] # The V1 API has the shape of `[200, {...}]`, which we handle by # prefixing with `item.*`. prefix = "item." if v1_api else "" - self._coro_state = ijson.items_coro( - _event_list_parser(room_version, self._response.state), - prefix + "state.item", - use_float=True, - ) - self._coro_auth = ijson.items_coro( - _event_list_parser(room_version, self._response.auth_events), - prefix + "auth_chain.item", - use_float=True, - ) - # TODO Remove the unstable prefix when servers have updated. - # - # By re-using the same event dictionary this will cause the parsing of - # org.matrix.msc3083.v2.event and event to stomp over each other. - # Generally this should be fine. - self._coro_unstable_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "org.matrix.msc3083.v2.event", - use_float=True, - ) - self._coro_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "event", - use_float=True, - ) + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + prefix + "state.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + prefix + "auth_chain.item", + use_float=True, + ), + # TODO Remove the unstable prefix when servers have updated. + # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "org.matrix.msc3083.v2.event", + use_float=True, + ), + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ), + ] + + if not v1_api: + self._coros.append( + ijson.items_coro( + _partial_state_parser(self._response), + "org.matrix.msc3706.partial_state", + use_float="True", + ) + ) + + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "org.matrix.msc3706.servers_in_room", + use_float="True", + ) + ) def write(self, data: bytes) -> int: - self._coro_state.send(data) - self._coro_auth.send(data) - self._coro_unstable_event.send(data) - self._coro_event.send(data) + for c in self._coros: + c.send(data) return len(data) @@ -1355,3 +1425,37 @@ def finish(self) -> SendJoinResponse: self._response.event_dict, self._room_version ) return self._response + + +class _StateParser(ByteParser[StateRequestResponse]): + """A parser for the response to `/state` requests. + + Args: + room_version: The version of the room. + """ + + CONTENT_TYPE = "application/json" + + def __init__(self, room_version: RoomVersion): + self._response = StateRequestResponse([], []) + self._room_version = room_version + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + "pdus.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + "auth_chain.item", + use_float=True, + ), + ] + + def write(self, data: bytes) -> int: + for c in self._coros: + c.send(data) + return len(data) + + def finish(self) -> StateRequestResponse: + return self._response diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index db4fe2c79803..67a634790712 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -24,6 +24,7 @@ ) from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, + FederationAccountStatusServlet, FederationTimestampLookupServlet, ) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES @@ -336,6 +337,13 @@ def register_servlets( ): continue + # Only allow the `/account_status` servlet if msc3720 is enabled + if ( + servletclass == FederationAccountStatusServlet + and not hs.config.experimental.msc3720_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dff2b68359b2..87e99c7ddf5b 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -55,7 +55,7 @@ def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index e85a8eda5b87..aed3d5069ca1 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -110,7 +110,7 @@ async def on_PUT( if issue_8631_logger.isEnabledFor(logging.DEBUG): DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] device_list_updates = [ - edu.content + edu.get("content", {}) for edu in transaction_data.get("edus", []) if edu.get("edu_type") in DEVICE_UPDATE_EDUS ] @@ -624,81 +624,6 @@ async def on_GET( ) -class FederationSpaceSummaryServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/spaces/(?P[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - - max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - exclude_rooms = content.get("exclude_rooms", []) - if not isinstance(exclude_rooms, list) or any( - not isinstance(x, str) for x in exclude_rooms - ): - raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - class FederationRoomHierarchyServlet(BaseFederationServlet): PATH = "/hierarchy/(?P[^/]*)" @@ -746,7 +671,7 @@ def __init__( server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() + self._store = self.hs.get_datastores().main async def on_GET( self, @@ -766,6 +691,40 @@ async def on_GET( return 200, complexity +class FederationAccountStatusServlet(BaseFederationServerServlet): + PATH = "/query/account_status" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720" + + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._account_handler = hs.get_account_handler() + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + if "user_ids" not in content: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + content["user_ids"], + allow_remote=False, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, @@ -792,9 +751,9 @@ async def on_GET( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, - FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, + FederationAccountStatusServlet, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a87896e5386c..ed26d6a6ce72 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -140,7 +140,7 @@ class GroupAttestionRenewer: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 449bbc70042d..4c3a5a6e24d1 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -45,7 +45,7 @@ class GroupsServerWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py new file mode 100644 index 000000000000..d5badf635bf7 --- /dev/null +++ b/synapse/handlers/account.py @@ -0,0 +1,144 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, List, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class AccountHandler: + def __init__(self, hs: "HomeServer"): + self._main_store = hs.get_datastores().main + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + + async def get_account_statuses( + self, + user_ids: List[str], + allow_remote: bool, + ) -> Tuple[JsonDict, List[str]]: + """Get account statuses for a list of user IDs. + + If one or more account(s) belong to remote homeservers, retrieve their status(es) + over federation if allowed. + + Args: + user_ids: The list of accounts to retrieve the status of. + allow_remote: Whether to try to retrieve the status of remote accounts, if + any. + + Returns: + The account statuses as well as the list of users whose statuses could not be + retrieved. + + Raises: + SynapseError if a required parameter is missing or malformed, or if one of + the accounts isn't local to this homeserver and allow_remote is False. + """ + statuses = {} + failures = [] + remote_users: List[UserID] = [] + + for raw_user_id in user_ids: + try: + user_id = UserID.from_string(raw_user_id) + except SynapseError: + raise SynapseError( + 400, + f"Not a valid Matrix user ID: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + if self._is_mine(user_id): + status = await self._get_local_account_status(user_id) + statuses[user_id.to_string()] = status + else: + if not allow_remote: + raise SynapseError( + 400, + f"Not a local user: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + remote_users.append(user_id) + + if allow_remote and len(remote_users) > 0: + remote_statuses, remote_failures = await self._get_remote_account_statuses( + remote_users, + ) + + statuses.update(remote_statuses) + failures += remote_failures + + return statuses, failures + + async def _get_local_account_status(self, user_id: UserID) -> JsonDict: + """Retrieve the status of a local account. + + Args: + user_id: The account to retrieve the status of. + + Returns: + The account's status. + """ + status = {"exists": False} + + userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) + + if userinfo is not None: + status = { + "exists": True, + "deactivated": userinfo.is_deactivated, + } + + return status + + async def _get_remote_account_statuses( + self, remote_users: List[UserID] + ) -> Tuple[JsonDict, List[str]]: + """Send out federation requests to retrieve the statuses of remote accounts. + + Args: + remote_users: The accounts to retrieve the statuses of. + + Returns: + The statuses of the accounts, and a list of accounts for which no status + could be retrieved. + """ + # Group remote users by destination, so we only send one request per remote + # homeserver. + by_destination: Dict[str, List[str]] = {} + for user in remote_users: + if user.domain not in by_destination: + by_destination[user.domain] = [] + + by_destination[user.domain].append(user.to_string()) + + # Retrieve the statuses and failures for remote accounts. + final_statuses: JsonDict = {} + final_failures: List[str] = [] + for destination, users in by_destination.items(): + statuses, failures = await self._federation_client.get_account_status( + destination, + users, + ) + + final_statuses.update(statuses) + final_failures += failures + + return final_statuses, final_failures diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index bad48713bcb1..177b4f899155 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -30,7 +30,7 @@ class AccountDataHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() @@ -166,7 +166,7 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 87e415df75e8..9d0975f636ab 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -43,7 +43,7 @@ class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 00ab5e79bf2e..96376963f239 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -29,7 +29,7 @@ class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a42c3558e42b..e6461cc3c980 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,7 @@ class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7e47..3e29c96a49e5 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -194,7 +194,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} @@ -1183,7 +1183,7 @@ async def validate_login( # No password providers were able to handle this 3pid # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( medium, address ) if not user_id: @@ -2064,6 +2064,10 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ def __init__(self) -> None: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ def register_password_auth_provider_callbacks( get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ async def get_username_for_registration( return None + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5d8f6c50a949..7163af8004e3 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -61,7 +61,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7a13d76a687f..76ae768e6ef5 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -38,6 +38,7 @@ def __init__(self, hs: "HomeServer"): self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname + self._third_party_rules = hs.get_third_party_event_rules() # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False @@ -135,9 +136,13 @@ async def deactivate_account( if erase_data: user = UserID.from_string(user_id) # Remove avatar URL from this user - await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + await self._profile_handler.set_avatar_url( + user, requester, "", by_admin, deactivation=True + ) # Remove displayname from this user - await self._profile_handler.set_displayname(user, requester, "", by_admin) + await self._profile_handler.set_displayname( + user, requester, "", by_admin, deactivation=True + ) logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) @@ -160,6 +165,13 @@ async def deactivate_account( # Remove account data (including ignored users and push rules). await self.store.purge_account_data_for_user(user_id) + # Let modules know the user has been deactivated. + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, + True, + by_admin, + ) + return identity_server_supports_unbinding async def _reject_pending_invites_for_user(self, user_id: str) -> None: @@ -264,6 +276,10 @@ async def activate_account(self, user_id: str) -> None: # Mark the user as active. await self.store.set_user_deactivated_status(user_id, False) + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, False, True + ) + # Add the user to the directory, if necessary. Note that # this must be done after the user is re-activated, because # deactivated users are excluded from the user directory. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 36c05f8363b9..934b5bd7349b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -63,7 +63,7 @@ class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state @@ -628,7 +628,7 @@ class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b582266af908..4cb725d027c7 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -43,7 +43,7 @@ def __init__(self, hs: "HomeServer"): Args: hs: server """ - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 082f521791cd..b7064c6624b7 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -44,7 +44,7 @@ def __init__(self, hs: "HomeServer"): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d4dfddf63fb4..d96456cd406c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -47,7 +47,7 @@ class E2eKeysHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine @@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.e2e_keys_handler = e2e_keys_handler diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 12614b2c5d5a..52e44a2d426f 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -45,7 +45,7 @@ class E2eRoomKeysHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # Used to lock whenever a client is uploading key data. This prevents collisions # between clients trying to upload the details of a new session, given all diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 365063ebdf77..d441ebb0ab3d 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -43,7 +43,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname async def check_auth_rules_from_context( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index bac5de052609..97e75e60c32e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -33,7 +33,7 @@ class EventStreamHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -134,7 +134,7 @@ async def get_stream( class EventHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c0f642005ff4..eb03a5accbac 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,8 +49,8 @@ make_deferred_yieldable, nested_logging_context, preserve_fn, - run_in_background, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, @@ -107,7 +107,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.federation_client = hs.get_federation_client() @@ -516,11 +516,20 @@ async def do_invite_join( await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, - auth_events=auth_chain, + state_events=state, ) + if ret.partial_state: + await self.store.store_partial_state_room(room_id, ret.servers_in_room) + max_stream_id = await self._federation_event_handler.process_remote_join( - origin, room_id, auth_chain, state, event, room_version_obj + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, ) # We wait here until this instance has seen the events come down @@ -559,7 +568,9 @@ async def do_invite_join( # lots of requests for missing prev_events which we do actually # have. Hence we fire off the background task, but don't wait for it. - run_in_background(self._handle_queued_pdus, room_queue) + run_as_background_process( + "handle_queued_pdus", self._handle_queued_pdus, room_queue + ) async def do_knock( self, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9edc7369d6fe..4bd87709f373 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -95,7 +95,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._storage = hs.get_storage() self._state_store = self._storage.state @@ -397,6 +397,7 @@ async def process_remote_join( state: List[EventBase], event: EventBase, room_version: RoomVersion, + partial_state: bool, ) -> int: """Persists the events returned by a send_join @@ -412,6 +413,7 @@ async def process_remote_join( event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + partial_state: True if the state omits non-critical membership events Returns: The stream ID after which all events have been persisted. @@ -419,10 +421,8 @@ async def process_remote_join( Raises: SynapseError if the response is in some way invalid. """ - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} - create_event = None - for e in auth_events: + for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break @@ -439,11 +439,6 @@ async def process_remote_join( if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - # filter out any events we have already seen - seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) - for s in seen_remotes: - event_map.pop(s, None) - # persist the auth chain and state events. # # any invalid events here will be marked as rejected, and we'll carry on. @@ -455,13 +450,19 @@ async def process_remote_join( # signatures right now doesn't mean that we will *never* be able to, so it # is premature to reject them. # - await self._auth_and_persist_outliers(room_id, event_map.values()) + await self._auth_and_persist_outliers( + room_id, itertools.chain(auth_events, state) + ) # and now persist the join event itself. - logger.info("Peristing join-via-remote %s", event) + logger.info( + "Peristing join-via-remote %s (partial_state: %s)", event, partial_state + ) with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( - event, old_state=state + event, + old_state=state, + partial_state=partial_state, ) context = await self._check_event_auth(origin, event, context) @@ -703,6 +704,8 @@ async def _process_pulled_event( try: state = await self._resolve_state_at_missing_prevs(origin, event) + # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does + # not return partial state await self._process_received_pdu( origin, event, state=state, backfilled=backfilled ) @@ -1245,6 +1248,16 @@ async def _auth_and_persist_outliers( """ event_map = {event.event_id: event for event in events} + # filter out any events we have already seen. This might happen because + # the events were eagerly pushed to us (eg, during a room join), or because + # another thread has raced against us since we decided to request the event. + # + # This is just an optimisation, so it doesn't need to be watertight - the event + # persister does another round of deduplication. + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + # XXX: it might be possible to kick this process off in parallel with fetching # the events. while event_map: @@ -1717,31 +1730,22 @@ async def _get_remote_auth_chain_for_event( event_id: the event for which we are lacking auth events """ try: - remote_event_map = { - e.event_id: e - for e in await self._federation_client.get_event_auth( - destination, room_id, event_id - ) - } + remote_events = await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) return - logger.info("/event_auth returned %i events", len(remote_event_map)) + logger.info("/event_auth returned %i events", len(remote_events)) # `event` may be returned, but we should not yet process it. - remote_event_map.pop(event_id, None) - - # nor should we reprocess any events we have already seen. - seen_remotes = await self._store.have_seen_events( - room_id, remote_event_map.keys() - ) - for s in seen_remotes: - remote_event_map.pop(s, None) + remote_auth_events = (e for e in remote_events if e.event_id != event_id) - await self._auth_and_persist_outliers(room_id, remote_event_map.values()) + await self._auth_and_persist_outliers(room_id, remote_auth_events) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] @@ -1795,6 +1799,7 @@ async def _update_context_for_auth_events( prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, + partial_state=context.partial_state, ) async def _run_push_actions_and_persist_event( diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9e270d461b2c..e7a399787beb 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -63,7 +63,7 @@ async def f( class GroupsLocalWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.groups_server_handler = hs.get_groups_server_handler() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c83eaea359e7..57c9fdfe629a 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -49,7 +49,7 @@ class IdentityHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 346a06ff49b7..344f20f37cce 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,7 +46,7 @@ class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 66b21bf2b37d..68f3927a8e1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,8 +55,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError +from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -75,7 +75,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @@ -397,7 +397,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -550,10 +550,11 @@ async def create_event( if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: # this can happen if support is withdrawn for a room version raise UnsupportedRoomVersionError(room_version_id) + room_version_obj = maybe_room_version_obj else: try: room_version_obj = await self.store.get_room_version( @@ -991,6 +992,8 @@ async def create_new_client_event( and full_state_ids_at_event and builder.internal_metadata.is_historical() ): + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. old_state = await self.store.get_events_as_list(full_state_ids_at_event) context = await self.state.compute_event_context(event, old_state=old_state) else: @@ -1145,12 +1148,13 @@ async def handle_new_client_event( room_version_id = event.content.get( "room_version", RoomVersions.V1.identifier ) - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: raise UnsupportedRoomVersionError( "Attempt to create a room with unsupported room version %s" % (room_version_id,) ) + room_version_obj = maybe_room_version_obj else: room_version_obj = await self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 8f71d975e9a4..593a2aac6691 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -273,7 +273,7 @@ def __init__( token_generator: "OidcSessionTokenGenerator", provider: OidcProviderConfig, ): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._token_generator = token_generator diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 06d491eda716..dba368764118 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -131,7 +131,7 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.clock = hs.get_clock() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 067c43ae4714..c155098beeb8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -204,25 +204,27 @@ async def current_state_for_users( Returns: dict: `user_id` -> `UserPresenceState` """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } + states = {} + missing = [] + for user_id in user_ids: + state = self.user_to_current_state.get(user_id, None) + if state: + states[user_id] = state + else: + missing.append(user_id) - missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) + for user_id in missing: + # if user has no state in database, create the state + if not res.get(user_id, None): + new_state = UserPresenceState.default(user_id) + states[user_id] = new_state + self.user_to_current_state[user_id] = new_state return states @@ -1539,7 +1541,7 @@ def __init__(self, hs: "HomeServer"): self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36e3ad2ba971..6554c0d3c209 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -54,7 +54,7 @@ class ProfileHandler: PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -71,6 +71,8 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name + self._third_party_rules = hs.get_third_party_event_rules() + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS @@ -171,6 +173,7 @@ async def set_displayname( requester: Requester, new_displayname: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set the displayname of a user @@ -179,6 +182,7 @@ async def set_displayname( requester: The user attempting to make this change. new_displayname: The displayname to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -227,6 +231,10 @@ async def set_displayname( target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: @@ -261,6 +269,7 @@ async def set_avatar_url( requester: Requester, new_avatar_url: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set a new avatar URL for a user. @@ -269,6 +278,7 @@ async def set_avatar_url( requester: The user attempting to make this change. new_avatar_url: The avatar URL to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -315,6 +325,10 @@ async def set_avatar_url( target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) @cached() diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 29562d9e3b23..9da143754f3c 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -27,7 +27,7 @@ class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 704ba7fd6b58..5a7d0fdb70dc 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler: def __init__(self, hs: "HomeServer"): self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -164,7 +164,7 @@ async def received_client_receipt( class ReceiptEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config @staticmethod diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a719d5eef351..05bb1e0225c3 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -86,7 +86,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -320,12 +320,12 @@ async def register_user( if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self.store.generate_user_id() - user = UserID(localpart, self.hs.hostname) + generated_localpart = await self.store.generate_user_id() + user = UserID(generated_localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) if generate_display_name: - default_display_name = localpart + default_display_name = generated_localpart try: await self.register_with_store( user_id=user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a4689c172d47..7b1224a3609e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -117,7 +117,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -1130,7 +1130,7 @@ class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state @@ -1261,7 +1261,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1401,7 +1401,7 @@ async def get_event_for_timestamp( class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, @@ -1491,7 +1491,7 @@ def __init__(self, hs: "HomeServer"): self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def shutdown_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 9be4d82e8c03..f460ad671633 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -16,7 +16,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 1a33211a1fe1..f3577b5d5aec 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -49,7 +49,7 @@ class RoomListHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index bf1a47efb020..a582837cf0aa 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -82,6 +82,7 @@ def __init__(self, hs: "HomeServer"): self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer: Linearizer = Linearizer(name="member") + self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -500,25 +501,32 @@ async def update_membership( key = (room_id,) - with (await self.member_linearizer.queue(key)): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - ) + as_id = object() + if requester.app_service: + as_id = requester.app_service.id + + # We first linearise by the application service (to try to limit concurrent joins + # by application services), and then by room ID. + with (await self.member_as_limiter.queue(as_id)): + with (await self.member_linearizer.queue(key)): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + historical=historical, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + ) return result diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4844b69a0345..55c2cbdba8bc 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -15,7 +15,6 @@ import itertools import logging import re -from collections import deque from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple import attr @@ -90,7 +89,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -107,153 +106,6 @@ def __init__(self, hs: "HomeServer"): "get_room_hierarchy", ) - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! - if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) # historical - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - async def get_room_hierarchy( self, requester: Requester, @@ -398,8 +250,6 @@ async def _get_room_hierarchy( None, room_id, suggested_only, - # Do not limit the maximum children. - max_children=None, ) # Otherwise, attempt to use information for federation. @@ -488,74 +338,6 @@ async def _get_room_hierarchy( return result - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - # Set a limit on the number of rooms to return. - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - async def get_federation_hierarchy( self, origin: str, @@ -579,7 +361,7 @@ async def get_federation_hierarchy( The JSON hierarchy dictionary. """ root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None + None, origin, requested_room_id, suggested_only ) if root_room_entry is None: # Room is inaccessible to the requesting server. @@ -600,7 +382,7 @@ async def get_federation_hierarchy( continue room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 + None, origin, room_id, suggested_only, include_children=False ) # If the room is accessible, include it in the results. # @@ -626,7 +408,7 @@ async def _summarize_local_room( origin: Optional[str], room_id: str, suggested_only: bool, - max_children: Optional[int], + include_children: bool = True, ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -641,9 +423,8 @@ async def _summarize_local_room( room_id: The room ID to summarize. suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. A value of None - means no limit. + include_children: + Whether to include the events of any children. Returns: A room entry if the room should be returned. None, otherwise. @@ -653,9 +434,8 @@ async def _summarize_local_room( room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + # If the room is not a space return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or not include_children: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -665,14 +445,6 @@ async def _summarize_local_room( # we only care about suggested children child_events = filter(_is_suggested_child_event, child_events) - # TODO max_children is legacy code for the /spaces endpoint. - if max_children is not None: - child_iter: Iterable[EventBase] = itertools.islice( - child_events, max_children - ) - else: - child_iter = child_events - stripped_events: List[JsonDict] = [ { "type": e.type, @@ -682,80 +454,10 @@ async def _summarize_local_room( "sender": e.sender, "origin_server_ts": e.origin_server_ts, } - for e in child_iter + for e in child_events ] return _RoomEntry(room_id, room_entry, stripped_events) - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - async def _summarize_remote_room_hierarchy( self, room: "_RoomQueueEntry", suggested_only: bool ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: @@ -958,9 +660,8 @@ async def _is_remote_room_accessible( ): return True - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + # Check if the user is a member of any of the allowed rooms from the response. + allowed_rooms = room.get("allowed_room_ids") if allowed_rooms and isinstance(allowed_rooms, list): if await self._event_auth_handler.is_user_in_rooms( allowed_rooms, requester @@ -1028,8 +729,6 @@ async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDic ) if allowed_rooms: entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} @@ -1094,7 +793,7 @@ async def get_room_summary( room_id, # Suggested-only doesn't matter since no children are requested. suggested_only=False, - max_children=0, + include_children=False, ) if not room_entry: diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 727d75a50c6c..9602f0d0bb48 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -52,7 +52,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 41cb80907893..aa16e417eb30 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,8 +14,9 @@ import itertools import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership @@ -32,9 +33,23 @@ logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _SearchResult: + # The count of results. + count: int + # A mapping of event ID to the rank of that event. + rank_map: Dict[str, int] + # A list of the resulting events. + allowed_events: List[EventBase] + # A map of room ID to results. + room_groups: Dict[str, JsonDict] + # A set of event IDs to highlight. + highlights: Set[str] + + class SearchHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.clock = hs.get_clock() self.hs = hs @@ -100,7 +115,7 @@ async def search( """Performs a full text search for a user. Args: - user + user: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -156,6 +171,8 @@ async def search( # Include context around each event? event_context = room_cat.get("event_context", None) + before_limit = after_limit = None + include_profile = False # Group results together? May allow clients to paginate within a # group @@ -182,6 +199,73 @@ async def search( % (set(group_keys) - {"room_id", "sender"},), ) + return await self._search( + user, + batch_group, + batch_group_key, + batch_token, + search_term, + keys, + filter_dict, + order_by, + include_state, + group_keys, + event_context, + before_limit, + after_limit, + include_profile, + ) + + async def _search( + self, + user: UserID, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + search_term: str, + keys: List[str], + filter_dict: JsonDict, + order_by: str, + include_state: bool, + group_keys: List[str], + event_context: Optional[bool], + before_limit: Optional[int], + after_limit: Optional[int], + include_profile: bool, + ) -> JsonDict: + """Performs a full text search for a user. + + Args: + user: The user performing the search. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + filter_dict: The JSON to build a filter out of. + order_by: How to order the results. Valid values ore "rank" and "recent". + include_state: True if the state of the room at each result should + be included. + group_keys: A list of ways to group the results. Valid values are + "room_id" and "sender". + event_context: True to include contextual events around results. + before_limit: + The number of events before a result to include as context. + + Only used if event_context is True. + after_limit: + The number of events after a result to include as context. + + Only used if event_context is True. + include_profile: True if historical profile information should be + included in the event context. + + Only used if event_context is True. + + Returns: + dict to be returned to the client with results of search + """ search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too @@ -216,278 +300,399 @@ async def search( } } - rank_map = {} # event_id -> rank of event - allowed_events = [] - # Holds result of grouping by room, if applicable - room_groups: Dict[str, JsonDict] = {} - # Holds result of grouping by sender, if applicable - sender_group: Dict[str, JsonDict] = {} + sender_group: Optional[Dict[str, JsonDict]] - # Holds the next_batch for the entire result set if one of those exists - global_next_batch = None - - highlights = set() + if order_by == "rank": + search_result, sender_group = await self._search_by_rank( + user, room_ids, search_term, keys, search_filter + ) + # Unused return values for rank search. + global_next_batch = None + elif order_by == "recent": + search_result, global_next_batch = await self._search_by_recent( + user, + room_ids, + search_term, + keys, + search_filter, + batch_group, + batch_group_key, + batch_token, + ) + # Unused return values for recent search. + sender_group = None + else: + # We should never get here due to the guard earlier. + raise NotImplementedError() - count = None + logger.info("Found %d events to return", len(search_result.allowed_events)) - if order_by == "rank": - search_result = await self.store.search_msgs(room_ids, search_term, keys) + # If client has asked for "context" for each event (i.e. some surrounding + # events and state), fetch that + if event_context is not None: + # Note that before and after limit must be set in this case. + assert before_limit is not None + assert after_limit is not None + + contexts = await self._calculate_event_contexts( + user, + search_result.allowed_events, + before_limit, + after_limit, + include_profile, + ) + else: + contexts = {} - count = search_result["count"] + # TODO: Add a limit - if search_result["highlights"]: - highlights.update(search_result["highlights"]) + state_results = {} + if include_state: + for room_id in {e.room_id for e in search_result.allowed_events}: + state = await self.state_handler.get_current_state(room_id) + state_results[room_id] = list(state.values()) - results = search_result["results"] + aggregations = None + if self._msc3666_enabled: + aggregations = await self.store.get_bundled_aggregations( + # Generate an iterable of EventBase for all the events that will be + # returned, including contextual events. + itertools.chain( + # The events_before and events_after for each context. + itertools.chain.from_iterable( + itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type] + for context in contexts.values() + ), + # The returned events. + search_result.allowed_events, + ), + user.to_string(), + ) - rank_map.update({r["event"].event_id: r["rank"] for r in results}) + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise, the 'age' will be wrong. - filtered_events = await search_filter.filter([r["event"] for r in results]) + time_now = self.clock.time_msec() - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + for context in contexts.values(): + context["events_before"] = self._event_serializer.serialize_events( + context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + ) + context["events_after"] = self._event_serializer.serialize_events( + context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] ) - events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit] + results = [ + { + "rank": search_result.rank_map[e.event_id], + "result": self._event_serializer.serialize_event( + e, time_now, bundle_aggregations=aggregations + ), + "context": contexts.get(e.event_id, {}), + } + for e in search_result.allowed_events + ] - for e in allowed_events: - rm = room_groups.setdefault( - e.room_id, {"results": [], "order": rank_map[e.event_id]} - ) - rm["results"].append(e.event_id) + rooms_cat_res: JsonDict = { + "results": results, + "count": search_result.count, + "highlights": list(search_result.highlights), + } - s = sender_group.setdefault( - e.sender, {"results": [], "order": rank_map[e.event_id]} - ) - s["results"].append(e.event_id) + if state_results: + rooms_cat_res["state"] = { + room_id: self._event_serializer.serialize_events(state_events, time_now) + for room_id, state_events in state_results.items() + } - elif order_by == "recent": - room_events: List[EventBase] = [] - i = 0 - - pagination_token = batch_token - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit and i < 5: - i += 1 - search_result = await self.store.search_rooms( - room_ids, - search_term, - keys, - search_filter.limit * 2, - pagination_token=pagination_token, - ) + if search_result.room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})[ + "room_id" + ] = search_result.room_groups - if search_result["highlights"]: - highlights.update(search_result["highlights"]) + if sender_group and "sender" in group_keys: + rooms_cat_res.setdefault("groups", {})["sender"] = sender_group - count = search_result["count"] + if global_next_batch: + rooms_cat_res["next_batch"] = global_next_batch - results = search_result["results"] + return {"search_categories": {"room_events": rooms_cat_res}} - results_map = {r["event"].event_id: r for r in results} + async def _search_by_rank( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + ) -> Tuple[_SearchResult, Dict[str, JsonDict]]: + """ + Performs a full text search for a user ordering by rank. - rank_map.update({r["event"].event_id: r["rank"] for r in results}) + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. - filtered_events = await search_filter.filter( - [r["event"] for r in results] - ) + Returns: + A tuple of: + The search results. + A map of sender ID to results. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + # Holds result of grouping by sender, if applicable + sender_group: Dict[str, JsonDict] = {} - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events - ) + search_result = await self.store.search_msgs(room_ids, search_term, keys) - room_events.extend(events) - room_events = room_events[: search_filter.limit] + if search_result["highlights"]: + highlights = search_result["highlights"] + else: + highlights = set() - if len(results) < search_filter.limit * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - for event in room_events: - group = room_groups.setdefault(event.room_id, {"results": []}) - group["results"].append(event.event_id) - - if room_events and len(room_events) >= search_filter.limit: - last_event_id = room_events[-1].event_id - pagination_token = results_map[last_event_id]["pagination_token"] - - # We want to respect the given batch group and group keys so - # that if people blindly use the top level `next_batch` token - # it returns more from the same group (if applicable) rather - # than reverting to searching all results again. - if batch_group and batch_group_key: - global_next_batch = encode_base64( - ( - "%s\n%s\n%s" - % (batch_group, batch_group_key, pagination_token) - ).encode("ascii") - ) - else: - global_next_batch = encode_base64( - ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") - ) + results = search_result["results"] - for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64( - ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( - "ascii" - ) - ) + # event_id -> rank of event + rank_map = {r["event"].event_id: r["rank"] for r in results} - allowed_events.extend(room_events) + filtered_events = await search_filter.filter([r["event"] for r in results]) - else: - # We should never get here due to the guard earlier. - raise NotImplementedError() + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) - logger.info("Found %d events to return", len(allowed_events)) + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[: search_filter.limit] - # If client has asked for "context" for each event (i.e. some surrounding - # events and state), fetch that - if event_context is not None: - now_token = self.hs.get_event_sources().get_current_token() + for e in allowed_events: + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) + rm["results"].append(e.event_id) - contexts = {} - for event in allowed_events: - res = await self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit - ) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) + s["results"].append(e.event_id) + + return ( + _SearchResult( + search_result["count"], + rank_map, + allowed_events, + room_groups, + highlights, + ), + sender_group, + ) - logger.info( - "Context for search returned %d and %d events", - len(res.events_before), - len(res.events_after), - ) + async def _search_by_recent( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + ) -> Tuple[_SearchResult, Optional[str]]: + """ + Performs a full text search for a user ordering by recent. - events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before - ) + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. - events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after - ) + Returns: + A tuple of: + The search results. + Optionally, a pagination token. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} - context = { - "events_before": events_before, - "events_after": events_after, - "start": await now_token.copy_and_replace( - "room_key", res.start - ).to_string(self.store), - "end": await now_token.copy_and_replace( - "room_key", res.end - ).to_string(self.store), - } + # Holds the next_batch for the entire result set if one of those exists + global_next_batch = None - if include_profile: - senders = { - ev.sender - for ev in itertools.chain(events_before, [event], events_after) - } + highlights = set() - if events_after: - last_event_id = events_after[-1].event_id - else: - last_event_id = event.event_id + room_events: List[EventBase] = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit and i < 5: + i += 1 + search_result = await self.store.search_rooms( + room_ids, + search_term, + keys, + search_filter.limit * 2, + pagination_token=pagination_token, + ) - state_filter = StateFilter.from_types( - [(EventTypes.Member, sender) for sender in senders] - ) + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + count = search_result["count"] + + results = search_result["results"] + + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = await search_filter.filter([r["event"] for r in results]) - state = await self.state_store.get_state_for_event( - last_event_id, state_filter + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[: search_filter.limit] + + if len(results) < search_filter.limit * 2: + break + else: + pagination_token = results[-1]["pagination_token"] + + for event in room_events: + group = room_groups.setdefault(event.room_id, {"results": []}) + group["results"].append(event.event_id) + + if room_events and len(room_events) >= search_filter.limit: + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] + + # We want to respect the given batch group and group keys so + # that if people blindly use the top level `next_batch` token + # it returns more from the same group (if applicable) rather + # than reverting to searching all results again. + if batch_group and batch_group_key: + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) + else: + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) + + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" ) + ) - context["profile_info"] = { - s.state_key: { - "displayname": s.content.get("displayname", None), - "avatar_url": s.content.get("avatar_url", None), - } - for s in state.values() - if s.type == EventTypes.Member and s.state_key in senders - } + return ( + _SearchResult(count, rank_map, room_events, room_groups, highlights), + global_next_batch, + ) - contexts[event.event_id] = context - else: - contexts = {} + async def _calculate_event_contexts( + self, + user: UserID, + allowed_events: List[EventBase], + before_limit: int, + after_limit: int, + include_profile: bool, + ) -> Dict[str, JsonDict]: + """ + Calculates the contextual events for any search results. - # TODO: Add a limit + Args: + user: The user performing the search. + allowed_events: The search results. + before_limit: + The number of events before a result to include as context. + after_limit: + The number of events after a result to include as context. + include_profile: True if historical profile information should be + included in the event context. - time_now = self.clock.time_msec() + Returns: + A map of event ID to contextual information. + """ + now_token = self.hs.get_event_sources().get_current_token() - aggregations = None - if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( - # Generate an iterable of EventBase for all the events that will be - # returned, including contextual events. - itertools.chain( - # The events_before and events_after for each context. - itertools.chain.from_iterable( - itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type] - for context in contexts.values() - ), - # The returned events. - allowed_events, - ), - user.to_string(), + contexts = {} + for event in allowed_events: + res = await self.store.get_events_around( + event.room_id, event.event_id, before_limit, after_limit ) - for context in contexts.values(): - context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + logger.info( + "Context for search returned %d and %d events", + len(res.events_before), + len(res.events_after), ) - context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - state_results = {} - if include_state: - for room_id in {e.room_id for e in allowed_events}: - state = await self.state_handler.get_current_state(room_id) - state_results[room_id] = list(state.values()) + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after + ) - # We're now about to serialize the events. We should not make any - # blocking calls after this. Otherwise the 'age' will be wrong + context: JsonDict = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace("room_key", res.end).to_string( + self.store + ), + } - results = [] - for e in allowed_events: - results.append( - { - "rank": rank_map[e.event_id], - "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations - ), - "context": contexts.get(e.event_id, {}), + if include_profile: + senders = { + ev.sender + for ev in itertools.chain(events_before, [event], events_after) } - ) - rooms_cat_res = { - "results": results, - "count": count, - "highlights": list(highlights), - } + if events_after: + last_event_id = events_after[-1].event_id + else: + last_event_id = event.event_id - if state_results: - s = {} - for room_id, state_events in state_results.items(): - s[room_id] = self._event_serializer.serialize_events( - state_events, time_now + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] ) - rooms_cat_res["state"] = s - - if room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + state = await self.state_store.get_state_for_event( + last_event_id, state_filter + ) - if sender_group and "sender" in group_keys: - rooms_cat_res.setdefault("groups", {})["sender"] = sender_group + context["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } - if global_next_batch: - rooms_cat_res["next_batch"] = global_next_batch + contexts[event.event_id] = context - return {"search_categories": {"room_events": rooms_cat_res}} + return contexts diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 706ad72761a3..73861bbd4085 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0bb8b0929e7e..ff5b5169cac1 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -180,7 +180,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index d30ba2b7248c..2d197282edda 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -30,7 +30,7 @@ class MatchChange(Enum): class StateDeltasHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _get_key_change( self, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 29e41a4c796c..436cd971cedd 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -39,7 +39,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index aa9a76f8a968..0aa3052fd6ac 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -266,7 +266,7 @@ def __bool__(self) -> bool: class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() @@ -697,6 +697,15 @@ async def get_state_at( else: # no events in this room - so presumably no state state = {} + + # (erikj) This should be rarely hit, but we've had some reports that + # we get more state down gappy syncs than we should, so let's add + # some logging. + logger.info( + "Failed to find any events in room %s at %s", + room_id, + stream_position.room_key, + ) return state async def compute_summary( @@ -1289,23 +1298,54 @@ async def _generate_sync_entry_for_device_list( # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id - ) - - # Always tell the user about their own devices. We check as the user - # ID is almost certainly already included (unless they're not in any - # rooms) and taking a copy of the set is relatively expensive. - if user_id not in users_who_share_room: - users_who_share_room = set(users_who_share_room) - users_who_share_room.add(user_id) + users_that_have_changed = set() - tracked_users = users_who_share_room + joined_rooms = sync_result_builder.joined_room_ids - # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, tracked_users + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + changed_users = self.store.get_cached_device_list_changes( + since_token.device_list_key ) + if changed_users is not None: + result = await self.store.get_rooms_for_users_with_stream_ordering( + changed_users + ) + + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + e.room_id in joined_rooms for e in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_who_share_room = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) + + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room + users_that_have_changed = ( + await self.store.get_users_whose_devices_changed( + since_token.device_list_key, tracked_users + ) + ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: @@ -1329,7 +1369,14 @@ async def _generate_sync_entry_for_device_list( newly_left_users.update(left_users) # Remove any users that we still share a room with. - newly_left_users -= users_who_share_room + left_users_rooms = ( + await self.store.get_rooms_for_users_with_stream_ordering( + newly_left_users + ) + ) + for user_id, entries in left_users_rooms.items(): + if any(e.room_id in joined_rooms for e in entries): + newly_left_users.discard(user_id) return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e4bed1c93758..843c68eb0fdf 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -57,7 +57,7 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -446,7 +446,7 @@ def process_replication_rows( class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self._main_store = hs.get_datastore() + self._main_store = hs.get_datastores().main self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 184730ebe8a4..014754a6308a 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -139,7 +139,7 @@ async def check_auth(self, authdict: dict, clientip: str) -> Any: class _BaseThreepidAuthChecker: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: @@ -255,7 +255,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self._enabled = bool(hs.config.registration.registration_requires_token) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def is_enabled(self) -> bool: return self._enabled diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1565e034cb25..d27ed2be6a99 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c5f8fcbb2a7d..40bf1e06d602 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -351,7 +351,7 @@ def __init__(self, hs: "HomeServer", tls_client_options_factory): ) self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 @@ -958,6 +958,7 @@ async def post_json( ) return body + @overload async def get_json( self, destination: str, @@ -967,7 +968,38 @@ async def get_json( timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + max_response_size: Optional[int] = ..., + ) -> T: + ... + + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + max_response_size: Optional[int] = None, + ): """GETs some json from the given host homeserver and path Args: @@ -992,6 +1024,13 @@ async def get_json( try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + max_response_size: The maximum size to read from the response. If None, + uses the default. + Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1026,8 +1065,17 @@ async def get_json( else: _sec_timeout = self.default_timeout + if parser is None: + parser = JsonParser() + body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, + max_response_size=max_response_size, ) return body diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py deleted file mode 100644 index b9933a152847..000000000000 --- a/synapse/logging/_structured.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os.path -from typing import Any, Dict, Generator, Optional, Tuple - -from constantly import NamedConstant, Names - -from synapse.config._base import ConfigError - - -class DrainType(Names): - CONSOLE = NamedConstant() - CONSOLE_JSON = NamedConstant() - CONSOLE_JSON_TERSE = NamedConstant() - FILE = NamedConstant() - FILE_JSON = NamedConstant() - NETWORK_JSON_TERSE = NamedConstant() - - -DEFAULT_LOGGERS = {"synapse": {"level": "info"}} - - -def parse_drain_configs( - drains: dict, -) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """ - Parse the drain configurations. - - Args: - drains (dict): A list of drain configurations. - - Yields: - dict instances representing a logging handler. - - Raises: - ConfigError: If any of the drain configuration items are invalid. - """ - - for name, config in drains.items(): - if "type" not in config: - raise ConfigError("Logging drains require a 'type' key.") - - try: - logging_type = DrainType.lookupByName(config["type"].upper()) - except ValueError: - raise ConfigError( - "%s is not a known logging drain type." % (config["type"],) - ) - - # Either use the default formatter or the tersejson one. - if logging_type in ( - DrainType.CONSOLE_JSON, - DrainType.FILE_JSON, - ): - formatter: Optional[str] = "json" - elif logging_type in ( - DrainType.CONSOLE_JSON_TERSE, - DrainType.NETWORK_JSON_TERSE, - ): - formatter = "tersejson" - else: - # A formatter of None implies using the default formatter. - formatter = None - - if logging_type in [ - DrainType.CONSOLE, - DrainType.CONSOLE_JSON, - DrainType.CONSOLE_JSON_TERSE, - ]: - location = config.get("location") - if location is None or location not in ["stdout", "stderr"]: - raise ConfigError( - ( - "The %s drain needs the 'location' key set to " - "either 'stdout' or 'stderr'." - ) - % (logging_type,) - ) - - yield name, { - "class": "logging.StreamHandler", - "formatter": formatter, - "stream": "ext://sys." + location, - } - - elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: - if "location" not in config: - raise ConfigError( - "The %s drain needs the 'location' key set." % (logging_type,) - ) - - location = config.get("location") - if os.path.abspath(location) != location: - raise ConfigError( - "File paths need to be absolute, '%s' is a relative path" - % (location,) - ) - - yield name, { - "class": "logging.FileHandler", - "formatter": formatter, - "filename": location, - } - - elif logging_type in [DrainType.NETWORK_JSON_TERSE]: - host = config.get("host") - port = config.get("port") - maximum_buffer = config.get("maximum_buffer", 1000) - - yield name, { - "class": "synapse.logging.RemoteHandler", - "formatter": formatter, - "host": host, - "port": port, - "maximum_buffer": maximum_buffer, - } - - else: - raise ConfigError( - "The %s drain type is currently not implemented." - % (config["type"].upper(),) - ) - - -def setup_structured_logging( - log_config: dict, -) -> dict: - """ - Convert a legacy structured logging configuration (from Synapse < v1.23.0) - to one compatible with the new standard library handlers. - """ - if "drains" not in log_config: - raise ConfigError("The logging configuration requires a list of drains.") - - new_config = { - "version": 1, - "formatters": { - "json": {"class": "synapse.logging.JsonFormatter"}, - "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, - }, - "handlers": {}, - "loggers": log_config.get("loggers", DEFAULT_LOGGERS), - "root": {"handlers": []}, - } - - for handler_name, handler in parse_drain_configs(log_config["drains"]): - new_config["handlers"][handler_name] = handler - - # Add each handler to the root logger. - new_config["root"]["handlers"].append(handler_name) - - return new_config diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d4fca369231a..c42eeedd87ae 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -59,6 +59,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, + ON_PROFILE_UPDATE_CALLBACK, + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) from synapse.handlers.account_validity import ( IS_USER_EXPIRED_CALLBACK, @@ -70,6 +72,7 @@ from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -144,6 +147,7 @@ "JsonDict", "EventBase", "StateMap", + "ProfileInfo", ] logger = logging.getLogger(__name__) @@ -171,7 +175,9 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() + self._store: Union[ + DataStore, "GenericWorkerSlavedStore" + ] = hs.get_datastores().main self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -277,6 +283,10 @@ def register_third_party_rules_callbacks( CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_user_deactivation_status_changed: Optional[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. @@ -288,6 +298,8 @@ def register_third_party_rules_callbacks( check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, + on_profile_update=on_profile_update, + on_user_deactivation_status_changed=on_user_deactivation_status_changed, ) def register_presence_router_callbacks( @@ -317,6 +329,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +343,7 @@ def register_password_auth_provider_callbacks( is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( @@ -648,7 +664,11 @@ def record_user_external_id( Added in Synapse v1.9.0. Args: - auth_provider: identifier for the remote auth provider + auth_provider: identifier for the remote auth provider, see `sso` and + `oidc_providers` in the homeserver configuration. + + Note that no error is raised if the provided value is not in the + homeserver configuration. external_id: id on that system user_id: complete mxid that it is mapped to """ @@ -917,7 +937,7 @@ async def update_room_membership( ) # Try to retrieve the resulting event. - event = await self._hs.get_datastore().get_event(event_id) + event = await self._hs.get_datastores().main.get_event(event_id) # update_membership is supposed to always return after the event has been # successfully persisted. @@ -1261,7 +1281,7 @@ class PublicRoomListManager: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. diff --git a/synapse/notifier.py b/synapse/notifier.py index e0fad2da66b7..16d15a1f3328 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -138,7 +138,7 @@ def notify( self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms - noify_deferred = self.notify_deferred + notify_deferred = self.notify_deferred log_kv( { @@ -153,7 +153,7 @@ def notify( with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - noify_deferred.callback(self.current_token) + notify_deferred.callback(self.current_token) def remove(self, notifier: "Notifier") -> None: """Remove this listener from all the indexes in the Notifier @@ -222,7 +222,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5176a1c1861d..a1b771109848 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -68,7 +68,7 @@ class ThrottleParams: class Pusher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() self.pusher_id = pusher_config.id diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 67fd34dd08a9..24a531a8c517 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -130,7 +130,9 @@ def make_base_prepend_rules( return rules -BASE_APPEND_CONTENT_RULES = [ +# We have to annotate these types, otherwise mypy infers them as +# `List[Dict[str, Sequence[Collection[str]]]]`. +BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/content/.m.rule.contains_user_name", "conditions": [ @@ -149,7 +151,7 @@ def make_base_prepend_rules( ] -BASE_PREPEND_OVERRIDE_RULES = [ +BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.master", "enabled": False, @@ -159,7 +161,7 @@ def make_base_prepend_rules( ] -BASE_APPEND_OVERRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.suppress_notices", "conditions": [ @@ -298,7 +300,7 @@ def make_base_prepend_rules( ] -BASE_APPEND_UNDERRIDE_RULES = [ +BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/underride/.m.rule.call", "conditions": [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index ec770c9c6b8f..726474205438 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -96,7 +96,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -371,7 +371,7 @@ def __init__( """ self.room_id = room_id self.is_mine_id = hs.is_mine_id - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_push_rule_cache_metrics = room_push_rule_cache_metrics # Used to ensure only one thing mutates the cache at a time. Keyed off diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 39bb2acae4b5..1710dd51b9d4 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,7 +66,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer super().__init__(hs, pusher_config) self.mailer = mailer - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None self.throttle_params: Dict[str, ThrottleParams] = {} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 29d2812528cc..eb8c89a1c1ef 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -109,6 +109,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] + self.badge_count_last_call: Optional[int] = None def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. @@ -132,11 +133,13 @@ async def _update_badge(self) -> None: # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) - await self._send_badge(badge) + if self.badge_count_last_call is None or self.badge_count_last_call != badge: + self.badge_count_last_call = badge + await self._send_badge(badge) def on_timer(self) -> None: self._start_processing() @@ -280,7 +283,7 @@ async def _process_one(self, push_action: HttpPushAction) -> bool: tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) @@ -322,7 +325,7 @@ async def _build_notification_dict( # This was checked in the __init__, but mypy doesn't seem to know that. assert self.data is not None if self.data.get("format") == "event_id_only": - d = { + d: Dict[str, Any] = { "notification": { "event_id": event.event_id, "room_id": event.room_id, @@ -402,6 +405,8 @@ async def dispatch_push( rejected = [] if "rejected" in resp: rejected = resp["rejected"] + else: + self.badge_count_last_call = badge return rejected async def _send_badge(self, badge: int) -> None: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3df8452eecf7..649a4f49d024 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -112,7 +112,7 @@ def __init__( self.template_text = template_text self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3fe512b62e15..452e47999e57 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,12 +15,12 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern from synapse.events import EventBase -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -238,7 +238,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, JsonDict], + d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -251,7 +251,7 @@ def _flatten_dict( value = str(value) if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, dict): + elif isinstance(value, Mapping): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7912311d2401..d0cc657b4464 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,7 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 86162e0f2c65..b40a7bbb76ca 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -17,14 +17,7 @@ import itertools import logging -from typing import List, Set - -from pkg_resources import ( - DistributionNotFound, - Requirement, - VersionConflict, - get_provider, -) +from typing import Set logger = logging.getLogger(__name__) @@ -87,8 +80,11 @@ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.1", + # ijson 3.1.4 fixes a bug with "." in property names + "ijson>=3.1.4", "matrix-common~=1.1.0", + # We need packaging.requirements.Requirement, added in 16.1. + "packaging>=16.1", ] CONDITIONAL_REQUIREMENTS = { @@ -143,102 +139,6 @@ def list_requirements(): return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) -class DependencyException(Exception): - @property - def message(self): - return "\n".join( - [ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ] - ) - - @property - def dependencies(self): - for i in self.args[0]: - yield '"' + i + '"' - - -def check_requirements(for_feature=None): - deps_needed = [] - errors = [] - - if for_feature: - reqs = CONDITIONAL_REQUIREMENTS[for_feature] - else: - reqs = REQUIREMENTS - - for dependency in reqs: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - deps_needed.append(dependency) - if for_feature: - errors.append( - "Needed %s for the '%s' feature but it was not installed" - % (dependency, for_feature) - ) - else: - errors.append("Needed %s but it was not installed" % (dependency,)) - - if not for_feature: - # Check the optional dependencies are up to date. We allow them to not be - # installed. - OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) - - for dependency in OPTS: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed optional %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - # If it's not found, we don't care - pass - - if deps_needed: - for err in errors: - logging.error(err) - - raise DependencyException(deps_needed) - - -def _check_requirement(dependency_string): - """Parses a dependency string, and checks if the specified requirement is installed - - Raises: - VersionConflict if the requirement is installed, but with the the wrong version - DistributionNotFound if nothing is found to provide the requirement - """ - req = Requirement.parse(dependency_string) - - # first check if the markers specify that this requirement needs installing - if req.marker is not None and not req.marker.evaluate(): - # not required for this environment - return - - get_provider(req) - - if __name__ == "__main__": import sys diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 388b42323cf3..f1abb986534b 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -293,7 +293,9 @@ async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: raise e.to_synapse_error() except Exception as e: _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e + raise SynapseError( + 502, f"Failed to talk to {instance_name} process" + ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() return result diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f2f40129fe68..3d63645726b9 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -63,7 +63,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d529c8a19fa2..3e7300b4a148 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] @@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0145858e4719..663bff573848 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -50,7 +50,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -119,7 +119,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -325,7 +325,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.distributor = hs.get_distributor() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index c7f751b70d0d..6c8f8388fd60 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod @@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 33e98daf8aba..ce781768364b 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -69,7 +69,7 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d59ce7ccf967..1b8479b0b4ec 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -111,7 +111,7 @@ class ReplicationDataHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -340,7 +340,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 17e157239336..0d2013a3cfc5 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -95,7 +95,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ecd6190f5b9b..494e42a2be8f 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -72,7 +72,7 @@ class ReplicationStreamer: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 914b9eae8492..23d631a76944 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -239,7 +239,7 @@ class BackfillStreamRow: ROW_TYPE = BackfillStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._current_token, @@ -267,7 +267,7 @@ class PresenceStreamRow: ROW_TYPE = PresenceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main if hs.get_instance_name() in hs.config.worker.writers.presence: # on the presence writer, query the presence handler @@ -355,7 +355,7 @@ class ReceiptsStreamRow: ROW_TYPE = ReceiptsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), @@ -374,7 +374,7 @@ class PushRulesStreamRow: ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -401,7 +401,7 @@ class PushersStreamRow: ROW_TYPE = PushersStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -434,7 +434,7 @@ class CachesStreamRow: ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), store.get_cache_stream_token_for_writer, @@ -455,7 +455,7 @@ class DeviceListsStreamRow: ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), @@ -474,7 +474,7 @@ class ToDeviceStreamRow: ROW_TYPE = ToDeviceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), @@ -495,7 +495,7 @@ class TagAccountDataStreamRow: ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), @@ -516,7 +516,7 @@ class AccountDataStreamRow: ROW_TYPE = AccountDataStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), @@ -585,7 +585,7 @@ class GroupsStreamRow: ROW_TYPE = GroupsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), @@ -604,7 +604,7 @@ class UserSignatureStreamRow: ROW_TYPE = UserSignatureStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 50c4a5ba03a1..26f4fa7cfd16 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -124,7 +124,7 @@ class EventsStream(Stream): NAME = "events" def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._store._stream_id_gen.get_current_token_for_writer, diff --git a/synapse/res/providers.json b/synapse/res/providers.json index f1838f955901..7b9958e45464 100644 --- a/synapse/res/providers.json +++ b/synapse/res/providers.json @@ -5,8 +5,6 @@ "endpoints": [ { "schemes": [ - "https://twitter.com/*/status/*", - "https://*.twitter.com/*/status/*", "https://twitter.com/*/moments/*", "https://*.twitter.com/*/moments/*" ], @@ -14,4 +12,4 @@ } ] } -] \ No newline at end of file +] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba0d989d81c3..6de302f81352 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index e9bce22a347b..93a78db81186 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9905ff560cb..cef46ba0ddf2 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -44,7 +44,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_POST( diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 38477f8eadeb..6d634eef7081 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index d162e0081ea1..023ed92144b4 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) @@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._authenticator = Authenticator(hs) async def on_POST( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 299f5c9eb0f2..8ca57bdb2830 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet): PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repository = hs.get_media_repository() async def on_GET( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 04948b640834..af606e925216 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token self.allowed_chars = string.ascii_letters + string.digits + "._~-" @@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Retrieve a registration token.""" diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 4d31864a2ba3..5b7053966662 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -73,7 +73,7 @@ class RoomRestV2Servlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() async def on_DELETE( @@ -196,7 +196,7 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -286,7 +286,7 @@ class RoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -396,7 +396,7 @@ class RoomMembersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -422,7 +422,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -539,7 +539,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -684,7 +684,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_DELETE( self, request: SynapseRequest, room_identifier: str @@ -795,7 +795,7 @@ class BlockRoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 7a6546372eef..3b142b840206 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6c392bed6ac9..2e9b6ec08e65 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -157,7 +157,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() @@ -589,7 +589,7 @@ def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, target_user_id: str @@ -675,7 +675,7 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,7 +718,7 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -776,7 +776,7 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -836,7 +836,7 @@ class UserMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str @@ -865,7 +865,7 @@ class PushersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -906,7 +906,7 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.is_mine_id = hs.is_mine_id @@ -975,7 +975,7 @@ class ShadowBanRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1027,7 +1027,7 @@ class RateLimitRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1130,7 +1130,7 @@ class AccountDataRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id async def on_GET( diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index cfa2aee76d49..5587cae98a61 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -114,7 +114,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -168,7 +168,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @@ -347,7 +347,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -533,7 +533,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html @@ -600,7 +600,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: @@ -634,7 +634,7 @@ def __init__(self, hs: "HomeServer"): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -768,7 +768,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are @@ -883,7 +883,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: response = { "user_id": requester.user.to_string(), # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + # Entered spec in Matrix 1.2 "org.matrix.msc3069.is_guest": bool(requester.is_guest), + "is_guest": bool(requester.is_guest), } # Appservices and similar accounts do not have device IDs @@ -894,6 +896,33 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return 200, response +class AccountStatusRestServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3720/account_status$", unstable=True, releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._account_handler = hs.get_account_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self._auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + if "user_ids" not in body: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + body["user_ids"], + allow_remote=True, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) @@ -908,3 +937,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3720_enabled: + AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 58b8adbd3245..bfe985939beb 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( @@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 9c15a04338b1..e0b2b80e5b9b 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -62,7 +62,7 @@ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, ) @@ -74,7 +74,7 @@ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -118,7 +118,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None: # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, @@ -143,7 +143,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None: self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), error=e.msg, ) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 6682da077a3e..4237071c61bd 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -72,22 +72,10 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - # Must be removed in later versions. - # Is only included for migration. - # Also the parts in `synapse/config/experimental.py`. - if self.config.experimental.msc3283_enabled: - response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.registration.enable_set_displayname + if self.config.experimental.msc3720_enabled: + response["capabilities"]["org.matrix.msc3720.account_status"] = { + "enabled": True, } - response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.registration.enable_set_avatar_url - } - response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.registration.enable_3pid_changes - } - - if self.config.experimental.msc3440_enabled: - response["capabilities"]["io.element.thread"] = {"enabled": True} return HTTPStatus.OK, response diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ee247e3d1e0d..e181a0dde2f8 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 672c821061ff..916f5230f1b1 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -39,7 +39,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index a7e9aa3e9bbf..7e1149c7f433 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -705,7 +705,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id @_validate_group_id @@ -854,7 +854,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @_validate_group_id async def on_PUT( @@ -879,7 +879,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_GET( @@ -901,7 +901,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 49b1037b28d4..cfadcb8e50f5 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -33,7 +33,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 730c18f08f78..ce806e3c11be 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -198,7 +198,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c4a3..c9d44c5964ce 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,13 +104,13 @@ def __init__(self, hs: "HomeServer"): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 8e427a96a320..20377a9ac628 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index add56d699884..820682ec4265 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8fe75bd750bd..a93f6fd5e0a8 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 98604a93887f..d6487c31dd48 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -46,7 +46,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( + user.to_string() + ) filtered_pushers = [p.as_dict() for p in pushers] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 884196f88381..843f628c68c0 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -123,7 +123,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: request, "email", email ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -203,7 +203,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: request, "msisdn", msisdn ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "msisdn", msisdn ) @@ -258,7 +258,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( @@ -386,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -416,7 +416,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.identity_handler = hs.get_identity_handler() @@ -695,11 +695,18 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: session_id ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 2cab83c4e6e7..487ea38b55cc 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() @@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() async def on_GET( @@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index d4a4adb50c59..6e962a45321e 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -38,7 +38,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 32589407f0a4..e1098b3073df 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -477,7 +477,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -553,7 +553,7 @@ def __init__(self, hs: "HomeServer"): self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -621,7 +621,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -642,7 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -1027,7 +1027,7 @@ class JoinedRoomsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -1116,7 +1116,7 @@ class TimestampLookupRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() async def on_GET( @@ -1141,73 +1141,6 @@ async def on_GET( } -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1301,7 +1234,6 @@ def register_servlets( RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomSpaceSummaryRestServlet(hs).register(http_server) RoomHierarchyRestServlet(hs).register(http_server) if hs.config.experimental.msc3266_enabled: RoomSummaryRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 9afa0f738cc5..006a434740a7 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 09a46737de06..e669fa78902d 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_directory_active = hs.config.server.update_user_directory async def on_GET( diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f9615da52583..f3018ff69077 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -103,7 +103,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c88cb9367c5f..ca638755c7ee 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -39,7 +39,7 @@ class TagListServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str, room_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2290c57c126e..2e5d0e4e2258 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -73,6 +73,8 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]: "r0.5.0", "r0.6.0", "r0.6.1", + "v1.1", + "v1.2", ], # as per MSC1497: "unstable_features": { @@ -97,6 +99,8 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]: "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 "org.matrix.msc3030": self.config.experimental.msc3030_enabled, + # Adds support for thread relations, per MSC3440. + "org.matrix.msc3440": self.config.experimental.msc3440_enabled, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 3d2afacc50d9..25f9ea285bca 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -78,7 +78,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() # this is required by the request_handler wrapper diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3923ba8439b0..3525d6ae54c6 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -94,7 +94,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.fetcher = ServerKeyFetcher(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 71b9a34b145d..6c414402bd46 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -75,7 +75,7 @@ def __init__(self, hs: "HomeServer"): self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ab2e960776d4..97ad4769acc3 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -133,7 +133,7 @@ def __init__( self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, @@ -402,7 +402,15 @@ async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResu url, output_stream=output_stream, max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, + headers={ + b"Accept-Language": self.url_preview_accept_language, + # Use a custom user agent for the preview because some sites will only return + # Open Graph metadata to crawler user agents. Omit the Synapse version + # string to avoid leaking information. + b"User-Agent": [ + "Synapse (bot; +https://github.com/matrix-org/synapse)" + ], + }, is_allowed_content_type=_is_previewable, ) except SynapseError: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ed91ef5a4283..53b156524375 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -50,7 +50,7 @@ def __init__( ): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index fde28d08cb8e..e73e431dc9f1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -37,7 +37,7 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): self.media_repo = media_repo self.filepaths = media_repo.filepaths - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 28a67f04e33f..6ac9dbc7c9be 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -44,7 +44,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._local_threepid_handling_disabled_due_to_email_config = ( hs.config.email.local_threepid_handling_disabled_due_to_email_config diff --git a/synapse/server.py b/synapse/server.py index 564afdcb9697..b5e2a319bcef 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -17,7 +17,7 @@ # homeservers; either as a full homeserver as a real application, or a small # partial one for unit test mocking. -# Imports required for the default HomeServer() implementation + import abc import functools import logging @@ -62,6 +62,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler @@ -133,7 +134,7 @@ WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, DataStore, Storage +from synapse.storage import Databases, Storage from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -224,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta): # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be - # instantiated during setup() for future return by get_datastore() + # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() tls_server_context_factory: Optional[IOpenSSLContextFactory] @@ -354,12 +355,6 @@ def is_mine_id(self, string: str) -> bool: def get_clock(self) -> Clock: return Clock(self._reactor) - def get_datastore(self) -> DataStore: - if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") - - return self.datastores.main - def get_datastores(self) -> Databases: if not self.datastores: raise Exception("HomeServer.setup must be called before getting datastores") @@ -373,7 +368,7 @@ def get_distributor(self) -> Distributor: @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( - store=self.get_datastore(), + store=self.get_datastores().main, clock=self.get_clock(), rate_hz=self.config.ratelimiting.rc_registration.per_second, burst_count=self.config.ratelimiting.rc_registration.burst_count, @@ -807,6 +802,10 @@ def get_event_auth_handler(self) -> EventAuthHandler: def get_external_cache(self) -> ExternalCache: return ExternalCache(self) + @cache_in_self + def get_account_handler(self) -> AccountHandler: + return AccountHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "RedisProtocol": """ @@ -842,7 +841,7 @@ def should_send_federation(self) -> bool: @cache_in_self def get_request_ratelimiter(self) -> RequestRatelimiter: return RequestRatelimiter( - self.get_datastore(), + self.get_datastores().main, self.get_clock(), self.config.ratelimiting.rc_message, self.config.ratelimiting.rc_admin_redaction, diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e09a25591feb..698ca742ed66 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -32,7 +32,7 @@ class ConsentServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._users_in_progress: Set[str] = set() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8522930b5048..015dd08f05e4 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,7 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 0cf60236f8b4..7b4814e04994 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -29,7 +29,7 @@ class ServerNoticesManager: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 67e8bc6ec288..6babd5963cc1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -126,7 +126,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -258,7 +258,10 @@ async def get_hosts_in_room_at_events( return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ async def compute_event_context( calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ async def compute_event_context( else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) + prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ async def compute_event_context( prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ async def compute_event_context( prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index cfe887b7f73d..ce3d1d4e942e 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]): """ databases: List[DatabasePool] - main: DataStoreT + main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 431c94225f1d..84da0becfc78 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,15 +20,19 @@ ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines._base import IsolationLevel +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,7 +61,7 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): def __init__( self, database: DatabasePool, @@ -125,6 +129,18 @@ def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: return service return None + @cached(iterable=True, cache_context=True) + async def get_app_service_users_in_room( + self, + room_id: str, + app_service: "ApplicationService", + cache_context: _CacheContext, + ) -> List[str]: + users_in_room = await self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -200,6 +216,8 @@ async def create_appservice_txn( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -210,6 +228,10 @@ async def create_appservice_txn( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -245,6 +267,8 @@ def _create_appservice_txn(txn): events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -338,12 +362,17 @@ def _get_oldest_unsent_txn(txn): events = await self.get_events_as_list(event_ids) + # TODO: to-device messages, one-time key counts and unused fallback keys + # are not yet populated for catch-up transactions. + # We likely want to populate those for reliability. return AppServiceTransaction( service=service, id=entry["txn_id"], events=events, ephemeral=[], to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8d845fe9516f..3b3a089b7627 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -670,6 +670,16 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict] device["device_id"]: db_to_json(device["content"]) for device in devices } + def get_cached_device_list_changes( + self, + from_key: int, + ) -> Optional[Set[str]]: + """Get set of users whose devices have changed since `from_key`, or None + if that information is not in our cache. + """ + + return self._device_list_stream_cache.get_all_entities_changed(from_key) + async def get_users_whose_devices_changed( self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1f8447b5076f..9b293475c871 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -29,6 +29,10 @@ from canonicaljson import encode_canonical_json from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -439,6 +443,114 @@ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + Empty algorithm -> count dicts are created if needed to represent a + lack of unused one-time keys. + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result: TransactionOneTimeKeyCounts = {} + + for user_id, device_id, algorithm, count in txn: + # We deliberately construct empty dictionaries for + # users and devices without any unused one-time keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + Empty lists are created for devices if there are no unused fallback + keys. This matches the response structure of MSC3202. + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result: TransactionUnusedFallbackKeys = {} + + for user_id, device_id, algorithm in txn: + # We deliberately construct empty dictionaries and lists for + # users and devices without any unused fallback keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5246fccad57a..ca2a9ba9d116 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -130,7 +130,7 @@ async def _persist_events_and_state_updates( *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], - new_forward_extremeties: Dict[str, List[str]], + new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -143,7 +143,7 @@ async def _persist_events_and_state_updates( the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state - new_forward_extremities: Map from room_id to list of event IDs + new_forward_extremities: Map from room_id to set of event IDs that are the new forward extremities of the room. use_negative_stream_ordering: Whether to start stream_ordering on the negative side and decrement. This should be set as True @@ -193,7 +193,7 @@ async def _persist_events_and_state_updates( events_and_contexts=events_and_contexts, inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, ) persist_event_counter.inc(len(events_and_contexts)) @@ -220,7 +220,7 @@ async def _persist_events_and_state_updates( for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremeties.items(): + for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -334,8 +334,8 @@ def _persist_events_txn( events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremeties: Optional[Dict[str, List[str]]] = None, - ): + new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + ) -> None: """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -353,13 +353,13 @@ def _persist_events_txn( from the database. This is useful when retrying due to IntegrityError. state_delta_for_room: The current-state delta for each room. - new_forward_extremetie: The new forward extremities for each room. + new_forward_extremities: The new forward extremities for each room. For each room, a list of the event ids which are the forward extremities. """ state_delta_for_room = state_delta_for_room or {} - new_forward_extremeties = new_forward_extremeties or {} + new_forward_extremities = new_forward_extremities or {} all_events_and_contexts = events_and_contexts @@ -372,7 +372,7 @@ def _persist_events_txn( self._update_forward_extremities_txn( txn, - new_forward_extremities=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, max_stream_order=max_stream_order, ) @@ -975,6 +975,17 @@ def _update_current_state_txn( to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invalidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + if delta_state.no_longer_in_room: # Server is no longer in the room so we delete the room from # current_state_events, being careful we've already updated the @@ -993,6 +1004,11 @@ def _update_current_state_txn( """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # We also want to invalidate the membership caches for users + # that were in the room. + users_in_room = self.store.get_users_in_room_txn(txn, room_id) + members_changed.update(users_in_room) + self.db_pool.simple_delete_txn( txn, table="current_state_events", @@ -1102,17 +1118,6 @@ def _update_current_state_txn( # Invalidate the various caches - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invalidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } - for member in members_changed: txn.call_after( self.store.get_rooms_for_user_with_stream_ordering.invalidate, @@ -1153,7 +1158,10 @@ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): ) def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order + self, + txn: LoggingTransaction, + new_forward_extremities: Dict[str, Set[str]], + max_stream_order: int, ): for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( @@ -1468,10 +1476,10 @@ def _store_rejected_events_txn(self, txn, events_and_contexts): def _update_metadata_tables_txn( self, - txn, + txn: LoggingTransaction, *, - events_and_contexts, - all_events_and_contexts, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1948,20 +1956,20 @@ def _handle_redaction(self, txn, redacted_event_id): txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: + def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) @@ -2137,6 +2145,14 @@ def _store_event_state_mappings_txn( state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): + # double-check that we don't have any events that claim to be outliers + # *and* have partial state (which is meaningless: we should have no + # state at all for an outlier) + if context.partial_state: + raise ValueError( + "Outlier event %s claims to have partial state", event.event_id + ) + continue # if the event was rejected, just give it the same state as its @@ -2147,6 +2163,23 @@ def _store_event_state_mappings_txn( state_groups[event.event_id] = context.state_group + # if we have partial state for these events, record the fact. (This happens + # here rather than in _store_event_txn because it also needs to happen when + # we de-outlier an event.) + self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8d4287045a8b..26784f755e40 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -408,7 +408,7 @@ async def get_events( include the previous states content in the unsigned field. allow_rejected: If True, return rejected events. Otherwise, - omits rejeted events from the response. + omits rejected events from the response. Returns: A mapping from event_id to event. @@ -1854,7 +1854,7 @@ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: forward_edge_query = """ SELECT 1 FROM event_edges /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE event_edges.room_id = ? AND event_edges.prev_event_id = ? @@ -1953,3 +1953,31 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, ) + + @cachedList("is_partial_state_event", list_name="event_ids") + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + """Checks which of the given events have partial state""" + result = await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ) + # convert the result to a dict, to make @cachedList work + partial = {r["event_id"] for r in result} + return {e_id: e_id in partial for e_id in event_ids} + + @cached() + async def is_partial_state_event(self, event_id: str) -> bool: + """Checks if the given event has partial state""" + result = await self.db_pool.simple_select_one_onecol( + table="partial_state_events", + keyvalues={"event_id": event_id}, + retcol="1", + allow_none=True, + desc="is_partial_state_event", + ) + return result is not None diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 8f09dd8e8751..e9a0cdc6be94 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -112,7 +112,7 @@ async def get_registered_reserved_users(self) -> List[str]: for tp in self.hs.config.server.mau_limits_reserved_threepids[ : self.hs.config.server.max_mau_value ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( tp["medium"], canonicalise_email(tp["address"]) ) if user_id: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4f05811a77eb..d3c4611686f8 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -35,7 +43,7 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) # Used by `PresenceStore._get_active_presence()` @@ -54,11 +62,14 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._presence_id_gen: AbstractStreamIdGenerator + self._can_persist_presence = ( - hs.get_instance_name() in hs.config.worker.writers.presence + self._instance_name in hs.config.worker.writers.presence ) if isinstance(database.engine, PostgresEngine): @@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]: return stream_orderings[-1], self._presence_id_gen.get_current_token() - def _update_presence_txn(self, txn, stream_orderings, presence_states): + def _update_presence_txn( + self, txn: LoggingTransaction, stream_orderings, presence_states + ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id @@ -183,19 +196,23 @@ async def get_all_presence_updates( if last_id == current_id: return [], current_id, False - def get_all_presence_updates_txn(txn): + def get_all_presence_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active + status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] + updates = cast( + List[Tuple[int, list]], + [(row[0], row[1:]) for row in txn], + ) upper_bound = current_id limited = False @@ -210,7 +227,7 @@ def get_all_presence_updates_txn(txn): ) @cached() - def _get_presence_for_user(self, user_id): + def _get_presence_for_user(self, user_id: str) -> None: raise NotImplementedError() @cachedList( @@ -218,7 +235,9 @@ def _get_presence_for_user(self, user_id): list_name="user_ids", num_args=1, ) - async def get_presence_for_users(self, user_ids): + async def get_presence_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token( True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn(txn): + def _should_user_receive_full_presence_with_token_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? @@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn): _should_user_receive_full_presence_with_token_txn, ) - async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: """Adds to the list of users who should receive a full snapshot of presence upon their next sync. @@ -353,10 +374,10 @@ async def get_presence_for_all_users( return users_to_state - def get_current_presence_token(self): + def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() - def _get_active_presence(self, db_conn: Connection): + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ @@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection): return [UserPresenceState(**row) for row in rows] - def take_presence_startup_info(self): + def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup - self._presence_on_startup = None + self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index e87a8fb85dce..2e3818e43244 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from typing import Any, List, Set, Tuple +from typing import Any, List, Set, Tuple, cast from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken @@ -55,7 +56,11 @@ async def purge_history( ) def _purge_history_txn( - self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool + self, + txn: LoggingTransaction, + room_id: str, + token: RoomStreamToken, + delete_local_events: bool, ) -> Set[int]: # Tables that should be pruned: # event_auth @@ -273,7 +278,7 @@ def _purge_history_txn( """, (room_id,), ) - (min_depth,) = txn.fetchone() + (min_depth,) = cast(Tuple[int], txn.fetchone()) logger.info("[purge] updating room_depth to %d", min_depth) @@ -318,7 +323,7 @@ async def purge_room(self, room_id: str) -> List[int]: "purge_room", self._purge_room_txn, room_id ) - def _purge_room_txn(self, txn, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: # First we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index aac94fa46444..dc6665237abf 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -622,10 +622,13 @@ async def record_user_external_id( ) -> None: """Record a mapping from an external user id to a mxid + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -648,6 +651,21 @@ def _record_user_external_id_txn( external_id: str, user_id: str, ) -> None: + """ + Record a mapping from an external user id to a mxid. + + Note that the auth provider IDs (and the external IDs) are not validated + against configured IdPs as Synapse does not know its relationship to + external systems. For example, it might be useful to pre-configure users + before enabling a new IdP or an IdP might be temporarily offline, but + still valid. + + Args: + txn: The database transaction. + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ self.db_pool.simple_insert_txn( txn, @@ -687,10 +705,13 @@ async def replace_user_external_id( """Replace mappings from external user ids to a mxid in a single transaction. All mappings are deleted and the new ones are created. + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: record_external_ids: List with tuple of auth_provider and external_id to record user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -1660,7 +1681,8 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: user_id=row[1], device_id=row[2], next_token_id=row[3], - has_next_refresh_token_been_refreshed=row[4], + # SQLite returns 0 or 1 for false/true, so convert to a bool. + has_next_refresh_token_been_refreshed=bool(row[4]), # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), expiry_ts=row[6], @@ -1676,12 +1698,15 @@ async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None Set the successor of a refresh token, removing the existing successor if any. + This also deletes the predecessor refresh and access tokens, + since they cannot be valid anymore. + Args: token_id: ID of the refresh token to update. next_token_id: ID of its successor. """ - def _replace_refresh_token_txn(txn) -> None: + def _replace_refresh_token_txn(txn: LoggingTransaction) -> None: # First check if there was an existing refresh token old_next_token_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1707,6 +1732,16 @@ def _replace_refresh_token_txn(txn) -> None: {"id": old_next_token_id}, ) + # Delete the previous refresh token, since we only want to keep the + # last 2 refresh tokens in the database. + # (The predecessor of the latest refresh token is still useful in + # case the refresh was interrupted and the client re-uses the old + # one.) + # This cascades to delete the associated access token. + self.db_pool.simple_delete_txn( + txn, "refresh_tokens", {"next_token_id": token_id} + ) + await self.db_pool.runInteraction( "replace_refresh_token", _replace_refresh_token_txn ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e2c27e594b24..36aa1092f602 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -53,8 +53,13 @@ @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: + # The latest event in the thread. latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. count: int + # True if the current user has sent an event to the thread. current_user_participated: bool @@ -461,8 +466,8 @@ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def _get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: - """Get the number of threaded replies and the latest reply (if any) for the given event. + ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: + """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. Args: event_ids: Summarize the thread related to this event ID. @@ -471,8 +476,10 @@ async def _get_thread_summaries( A map of the thread summary each event. A missing event implies there are no threaded replies. - Each summary includes the number of items in the thread and the most - recent response. + Each summary is a tuple of: + The number of events in the thread. + The most recent event in the thread. + The most recent edit to the most recent event in the thread, if applicable. """ def _get_thread_summaries_txn( @@ -482,7 +489,7 @@ def _get_thread_summaries_txn( # TODO Should this only allow m.room.message events. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, - # so ordering by topologica ordering + stream ordering desc will + # so ordering by topological ordering + stream ordering desc will # ensure we get the latest event in the thread. sql = """ SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child @@ -558,6 +565,9 @@ def _get_thread_summaries_txn( latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] + # Check to see if any of those events are edited. + latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + # Map to the event IDs to the thread summary. # # There might not be a summary due to there not being a thread or @@ -568,7 +578,8 @@ def _get_thread_summaries_txn( summary = None if latest_event: - summary = (counts[parent_event_id], latest_event) + latest_edit = latest_edits.get(latest_event_id) + summary = (counts[parent_event_id], latest_event, latest_edit) summaries[parent_event_id] = summary return summaries @@ -828,11 +839,12 @@ async def get_bundled_aggregations( ) for event_id, summary in summaries.items(): if summary: - thread_count, latest_thread_event = summary + thread_count, latest_thread_event, edit = summary results.setdefault( event_id, BundledAggregations() ).thread = _ThreadAggregation( latest_event=latest_thread_event, + latest_edit=edit, count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 95167116c953..94068940b90d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -20,6 +20,7 @@ TYPE_CHECKING, Any, Awaitable, + Collection, Dict, List, Optional, @@ -1498,7 +1499,7 @@ def __init__( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( - self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: """Ensure that the room is stored in the table @@ -1511,7 +1512,7 @@ async def upsert_room_on_join( has_auth_chain_index = await self.has_auth_chain_index(room_id) create_event = None - for e in auth_events: + for e in state_events: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break @@ -1543,6 +1544,42 @@ async def upsert_room_on_join( lock=False, ) + async def store_partial_state_room( + self, + room_id: str, + servers: Collection[str], + ) -> None: + """Mark the given room as containing events with partial state + + Args: + room_id: the ID of the room + servers: other servers known to be in the room + """ + await self.db_pool.runInteraction( + "store_partial_state_room", + self._store_partial_state_room_txn, + room_id, + servers, + ) + + @staticmethod + def _store_partial_state_room_txn( + txn: LoggingTransaction, room_id: str, servers: Collection[str] + ) -> None: + DatabasePool.simple_insert_txn( + txn, + table="partial_state_rooms", + values={ + "room_id": room_id, + }, + ) + DatabasePool.simple_insert_many_txn( + txn, + table="partial_state_rooms_servers", + keys=("room_id", "server_name"), + values=((room_id, s) for s in servers), + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4489732fdac4..e48ec5f495af 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -504,6 +504,68 @@ def _get_rooms_for_user_with_stream_ordering_txn( for room_id, instance, stream_id in txn ) + @cachedList( + cached_method_name="get_rooms_for_user_with_stream_ordering", + list_name="user_ids", + ) + async def get_rooms_for_users_with_stream_ordering( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + """A batched version of `get_rooms_for_user_with_stream_ordering`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + return await self.db_pool.runInteraction( + "get_rooms_for_users_with_stream_ordering", + self._get_rooms_for_users_with_stream_ordering_txn, + user_ids, + ) + + def _get_rooms_for_users_with_stream_ordering_txn( + self, txn, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + + clause, args = make_in_list_sql_clause( + self.database_engine, + "c.state_key", + user_ids, + ) + + if self._current_state_events_membership_up_to_date: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.membership = ? + AND {clause} + """ + else: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND m.membership = ? + AND {clause} + """ + + txn.execute(sql, [Membership.JOIN] + args) + + result = {user_id: set() for user_id in user_ids} + for user_id, room_id, instance, stream_id in txn: + result[user_id].add( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + ) + + return {user_id: frozenset(v) for user_id, v in result.items()} + async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2d085a5764a7..e23b1190726b 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -28,6 +28,7 @@ ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -114,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings" def __init__( self, @@ -146,6 +148,10 @@ def __init__( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings + ) + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -371,6 +377,27 @@ def reindex_search_txn(txn): return num_rows + async def _background_delete_non_strings( + self, progress: JsonDict, batch_size: int + ) -> int: + """Deletes rows with non-string `value`s from `event_search` if using sqlite. + + Prior to Synapse 1.44.0, malformed events received over federation could cause integers + to be inserted into the `event_search` table when using sqlite. + """ + + def delete_non_strings_txn(txn: LoggingTransaction) -> None: + txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'") + + await self.db_pool.runInteraction( + self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn + ) + + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_DELETE_NON_STRINGS + ) + return 1 + class SearchStore(SearchBackgroundUpdateStore): def __init__( @@ -381,17 +408,19 @@ def __init__( ): super().__init__(database, db_conn, hs) - async def search_msgs(self, room_ids, search_term, keys): + async def search_msgs( + self, room_ids: Collection[str], search_term: str, keys: Iterable[str] + ) -> JsonDict: """Performs a full text search over events with given keys. Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", "content.name", "content.topic" Returns: - list of dicts + Dictionary of results """ clauses = [] @@ -499,10 +528,10 @@ async def search_rooms( self, room_ids: Collection[str], search_term: str, - keys: List[str], + keys: Iterable[str], limit, pagination_token: Optional[str] = None, - ) -> List[dict]: + ) -> JsonDict: """Performs a full text search over events with given keys. Args: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 2fb3e65192f5..417aef1dbcf3 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -42,6 +42,16 @@ MAX_STATE_DELTA_HOPS = 100 +def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + return v + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" @@ -62,11 +72,8 @@ async def get_room_version(self, room_id: str) -> RoomVersion: Typically this happens if support for the room's version has been removed from Synapse. """ - return await self.db_pool.runInteraction( - "get_room_version_txn", - self.get_room_version_txn, - room_id, - ) + room_version_id = await self.get_room_version_id(room_id) + return _retrieve_and_check_room_version(room_id, room_version_id) def get_room_version_txn( self, txn: LoggingTransaction, room_id: str @@ -82,15 +89,7 @@ def get_room_version_txn( removed from Synapse. """ room_version_id = self.get_room_version_id_txn(txn, room_id) - v = KNOWN_ROOM_VERSIONS.get(room_version_id) - - if not v: - raise UnsupportedRoomVersionError( - "Room %s uses a room version %s which is no longer supported" - % (room_id, room_version_id) - ) - - return v + return _retrieve_and_check_room_version(room_id, room_version_id) @cached(max_entries=10000) async def get_room_version_id(self, room_id: str) -> str: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f7c778bdf22b..e7fddd24262a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -58,7 +58,7 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -234,10 +234,10 @@ def _get_next_batch( processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = await self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] if is_in_room: - users_with_profile = await self.get_users_in_room_with_profiles(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] # Throw away users excluded from the directory. users_with_profile = { user_id: profile @@ -368,7 +368,7 @@ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]: for user_id in users_to_work_on: if await self.should_include_local_user_in_dir(user_id): - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -397,7 +397,7 @@ async def should_include_local_user_in_dir(self, user: str) -> bool: # technically it could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice sender can be # contacted. - if self.get_app_service_by_user_id(user) is not None: + if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] return False # We're opting to exclude appservice users (anyone matching the user @@ -405,17 +405,17 @@ async def should_include_local_user_in_dir(self, user: str) -> bool: # they could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice users can be # contacted. - if self.get_if_app_services_interested_in_user(user): + if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] # TODO we might want to make this configurable for each app service return False # Support users are for diagnostics and should not appear in the user directory. - if await self.is_support_user(user): + if await self.is_support_user(user): # type: ignore[attr-defined] return False # Deactivated users aren't contactable, so should not appear in the user directory. try: - if await self.get_user_deactivated_status(user): + if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] return False except StoreError: # No such user in the users table. No need to do this when calling @@ -433,20 +433,20 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> boo (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = await self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] if hist_vis_ev: if ( hist_vis_ev.content.get("history_visibility") diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 428d66a617b1..7d543fdbe08a 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -427,21 +427,21 @@ async def _persist_event_batch( # NB: Assumes that we are only persisting events for one room # at a time. - # map room_id->list[event_ids] giving the new forward + # map room_id->set[event_ids] giving the new forward # extremities in each room - new_forward_extremeties = {} + new_forward_extremities: Dict[str, Set[str]] = {} # map room_id->(type,state_key)->event_id tracking the full # state in each room after adding these events. # This is simply used to prefill the get_current_state_ids # cache - current_state_for_room = {} + current_state_for_room: Dict[str, StateMap[str]] = {} # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each # room - state_delta_for_room = {} + state_delta_for_room: Dict[str, DeltaState] = {} # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their @@ -460,14 +460,13 @@ async def _persist_event_batch( ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = ( + latest_event_ids = set( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) - latest_event_ids = set(latest_event_ids) if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -478,7 +477,7 @@ async def _persist_event_batch( # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids len_1 = ( len(latest_event_ids) == 1 @@ -533,7 +532,7 @@ async def _persist_event_batch( # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -567,7 +566,7 @@ async def _persist_event_batch( ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) - latest_event_ids = [] + latest_event_ids = set() current_state = {} delta.no_longer_in_room = True @@ -582,7 +581,7 @@ async def _persist_event_batch( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, ) @@ -596,7 +595,7 @@ async def _calculate_new_extremities( room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: Collection[str], - ): + ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -906,9 +905,9 @@ async def _prune_extremities( # Ideally we'd figure out a way of still being able to drop old # dummy events that reference local events, but this is good enough # as a first cut. - events_to_check = [event] + events_to_check: Collection[EventBase] = [event] while events_to_check: - new_events = set() + new_events: Set[str] = set() for event_to_check in events_to_check: if self.is_mine_id(event_to_check.sender): if event_to_check.type != EventTypes.Dummy: diff --git a/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql new file mode 100644 index 000000000000..09305638ea54 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql @@ -0,0 +1,28 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- next_token_id is a foreign key reference, so previously required a table scan +-- when a row in the referenced table was deleted. +-- As it was self-referential and cascaded deletes, this led to O(t*n) time to +-- delete a row, where t: number of rows in the table and n: number of rows in +-- the ancestral 'chain' of access tokens. +-- +-- This index is partial since we only require it for rows which reference +-- another. +-- Performance was tested to be the same regardless of whether the index was +-- full or partial, but a partial index can be smaller. +CREATE INDEX refresh_tokens_next_token_id + ON refresh_tokens(next_token_id) + WHERE next_token_id IS NOT NULL; diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql new file mode 100644 index 000000000000..815c0cc390b6 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql @@ -0,0 +1,41 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- rooms which we have done a partial-state-style join to +CREATE TABLE IF NOT EXISTS partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); + +-- a list of remote servers we believe are in the room +CREATE TABLE IF NOT EXISTS partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); + +-- a list of events with partial state. We can't store this in the `events` table +-- itself, because `events` is meant to be append-only. +CREATE TABLE IF NOT EXISTS partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); + +CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx + ON partial_state_events (room_id); + + diff --git a/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite new file mode 100644 index 000000000000..140df652642d --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Delete rows with non-string `value`s from `event_search` if using sqlite. +-- +-- Prior to Synapse 1.44.0, malformed events received over federation could +-- cause integers to be inserted into the `event_search` table. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6805, 'event_search_sqlite_delete_non_strings', '{}'); diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py new file mode 100644 index 000000000000..a2ec4fc26edb --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -0,0 +1,72 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the partial_state_events tables to enforce uniqueness + +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # complain if the room_id in partial_state_events doesn't match + # that in `events`. We already have a fk constraint which ensures that the event + # exists in `events`, so all we have to do is raise if there is a row with a + # matching stream_ordering but not a matching room_id. + if isinstance(database_engine, Sqlite3Engine): + cur.execute( + """ + CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + cur.execute( + """ + CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events + FOR EACH ROW + EXECUTE PROCEDURE check_partial_state_events() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 913448f0f95e..e79ecf64a0ec 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -204,13 +204,16 @@ def return_expanded(self) -> "StateFilter": if get_all_members: # We want to return everything. return StateFilter.all() - else: + elif EventTypes.Member in self.types: # We want to return all non-members, but only particular # memberships return StateFilter( types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. @@ -528,6 +531,9 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter": _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 21591d0bfd2f..fb8fe1729516 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -37,16 +37,18 @@ class _EventSourcesInner: account_data: AccountDataEventSource def get_sources(self) -> Iterator[Tuple[str, EventSource]]: - for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined] + for attribute in attr.fields(_EventSourcesInner): yield attribute.name, getattr(self, attribute.name) class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined] + # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since + # all the attributes of `_EventSourcesInner` are annotated. + *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() diff --git a/synapse/types.py b/synapse/types.py index f89fb216a656..53be3583a013 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,7 @@ if TYPE_CHECKING: from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore + from synapse.storage.databases.main import DataStore, PurgeEventsStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -485,7 +485,7 @@ def __attrs_post_init__(self) -> None: ) @classmethod - async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -502,7 +502,7 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": instance_id = int(key) pos = int(value) - instance_name = await store.get_name_from_instance_id(instance_id) + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] instance_map[instance_name] = pos return cls( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 511f52534b3c..58b4220ff355 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -81,7 +81,9 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]: def unwrapFirstError(failure: Failure) -> Failure: - # defer.gatherResults and DeferredLists wrap failures. + # Deprecated: you probably just want to catch defer.FirstError and reraise + # the subFailure's value, which will do a better job of preserving stacktraces. + # (actually, you probably want to use yieldable_gather_results anyway) failure.trap(defer.FirstError) return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 3f7299aff7eb..60c03a66fd16 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -29,6 +29,7 @@ Hashable, Iterable, Iterator, + List, Optional, Set, Tuple, @@ -51,7 +52,7 @@ make_deferred_yieldable, run_in_background, ) -from synapse.util import Clock, unwrapFirstError +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -193,9 +194,9 @@ def __repr__(self) -> str: T = TypeVar("T") -def concurrently_execute( +async def concurrently_execute( func: Callable[[T], Any], args: Iterable[T], limit: int -) -> defer.Deferred: +) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -221,20 +222,14 @@ async def _concurrently_execute_inner(value: T) -> None: # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - return make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(_concurrently_execute_inner, value) - for value in itertools.islice(it, limit) - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) + await yieldable_gather_results( + _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) + ) -def yieldable_gather_results( - func: Callable, iter: Iterable, *args: Any, **kwargs: Any -) -> defer.Deferred: +async def yieldable_gather_results( + func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any +) -> List[T]: """Executes the function with each argument concurrently. Args: @@ -245,15 +240,30 @@ def yieldable_gather_results( **kwargs: Keyword arguments to be passed to each call to func Returns - Deferred[list]: Resolved when all functions have been invoked, or errors if - one of the function calls fails. + A list containing the results of the function """ - return make_deferred_yieldable( - defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], - consumeErrors=True, + try: + return await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) ) - ).addErrback(unwrapFirstError) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None T1 = TypeVar("T1") @@ -545,7 +555,10 @@ def _ctx_manager() -> Iterator[None]: finally: with PreserveLoggingContext(): new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: + # `self.key_to_current_writer[key]` may be missing if there was another + # writer waiting for us and it completed entirely within the + # `new_defer.callback()` call above. + if self.key_to_current_writer.get(key) == new_defer: self.key_to_current_writer.pop(key) return _ctx_manager() @@ -655,3 +668,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: return value return DoneAwaitable(value) + + +def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`, + but will not propagate cancellation through to the original. When cancelled, + the new `Deferred` will fail with a `CancelledError` and will not follow the + Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap + the new `Deferred`. + """ + new_deferred: defer.Deferred[T] = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 15debd6c460f..1cbc180eda72 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -56,6 +56,7 @@ class EvictionReason(Enum): size = auto() time = auto() + invalidation = auto() @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index df4fb156c2d0..1cdead02f14b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes - the tuple of missing keys to the wrapped function. + the set of missing keys to the wrapped function. - Once wrapped, the function returns a Deferred which resolves to the list - of results. + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from + input key to output value. """ def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ def __init__( def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -444,39 +445,38 @@ def arg_to_cache_key(arg: Hashable) -> Hashable: deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - cache.set(key, deferred, callback=invalidate_callback) + cached_defers.append( + cache.set(key, deferred, callback=invalidate_callback) + ) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a - # a dict. We can now resolve the observable deferreds in - # the cache and update our own result map. - for e in missing: + # the wrapped function has completed. It returns a dict. + # We can now update our own result map, and then resolve the + # observable deferreds in the cache. + for e, d1 in deferreds_map.items(): val = res.get(e, None) - deferreds_map[e].callback(val) + # make sure we update the results map before running the + # deferreds, because as soon as we run the last deferred, the + # gatherResults() below will complete and return the result + # dict to our caller. results[e] = val + d1.callback(val) - def errback(f: Failure) -> Failure: - # the wrapped function has failed. Invalidate any cache - # entries we're supposed to be populating, and fail - # their deferreds. - for e in missing: - key = arg_to_cache_key(e) - cache.invalidate(key) - deferreds_map[e].errback(f) - - # return the failure, to propagate to our caller. - return f + def errback_all(f: Failure) -> None: + # the wrapped function has failed. Propagate the failure into + # the cache, which will invalidate the entry, and cause the + # relevant cached_deferreds to fail, which will propagate the + # failure to our caller. + for d1 in deferreds_map.values(): + d1.errback(f) args_to_call = dict(arg_dict) - # copy the missing set before sending it to the callee, to guard against - # modification. - args_to_call[self.list_name] = tuple(missing) - - cached_defers.append( - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback) - ) + args_to_call[self.list_name] = missing + + # dispatch the call, and attach the two handlers + defer.maybeDeferred( + preserve_fn(self.orig), **args_to_call + ).addCallbacks(complete_all, errback_all) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 67ee4c693bc5..c6a5d0dfc0a9 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -133,6 +133,11 @@ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: raise KeyError(key) return default + if self.iterable: + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + else: + self.metrics.inc_evictions(EvictionReason.invalidation) + return value.value def __contains__(self, key: KT) -> bool: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 7548b38548bd..45ff0de638a4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -560,8 +560,10 @@ def cache_pop(key: KT, default: T) -> Union[T, VT]: def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: - delete_node(node) + evicted_len = delete_node(node) cache.pop(node.key, None) + if metrics: + metrics.inc_evictions(EvictionReason.invalidation, evicted_len) return node.value else: return default diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py new file mode 100644 index 000000000000..12cd8049392f --- /dev/null +++ b/synapse/util/check_dependencies.py @@ -0,0 +1,175 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module exposes a single function which checks synapse's dependencies are present +and correctly versioned. It makes use of `importlib.metadata` to do so. The details +are a bit murky: there's no easy way to get a map from "extras" to the packages they +require. But this is probably just symptomatic of Python's package management. +""" + +import logging +from typing import Iterable, NamedTuple, Optional + +from packaging.requirements import Requirement + +DISTRIBUTION_NAME = "matrix-synapse" + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + +__all__ = ["check_requirements"] + + +class DependencyException(Exception): + @property + def message(self) -> str: + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) + + @property + def dependencies(self) -> Iterable[str]: + for i in self.args[0]: + yield '"' + i + '"' + + +DEV_EXTRAS = {"lint", "mypy", "test", "dev"} +RUNTIME_EXTRAS = ( + set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS +) +VERSION = metadata.version(DISTRIBUTION_NAME) + + +def _is_dev_dependency(req: Requirement) -> bool: + return req.marker is not None and any( + req.marker.evaluate({"extra": e}) for e in DEV_EXTRAS + ) + + +class Dependency(NamedTuple): + requirement: Requirement + must_be_installed: bool + + +def _generic_dependencies() -> Iterable[Dependency]: + """Yield pairs (requirement, must_be_installed).""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue + + # https://packaging.pypa.io/en/latest/markers.html#usage notes that + # > Evaluating an extra marker with no environment is an error + # so we pass in a dummy empty extra value here. + must_be_installed = req.marker is None or req.marker.evaluate({"extra": ""}) + yield Dependency(req, must_be_installed) + + +def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: + """Yield additional dependencies needed for a given `extra`.""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + if _is_dev_dependency(req): + continue + # Exclude mandatory deps by only selecting deps needed with this extra. + if ( + req.marker is not None + and req.marker.evaluate({"extra": extra}) + and not req.marker.evaluate({"extra": ""}) + ): + yield Dependency(req, True) + + +def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return ( + f"Synapse {VERSION} needs {requirement.name} for {extra}, " + f"but it is not installed" + ) + else: + return f"Synapse {VERSION} needs {requirement.name}, but it is not installed" + + +def _incorrect_version( + requirement: Requirement, got: str, extra: Optional[str] = None +) -> str: + if extra: + return ( + f"Synapse {VERSION} needs {requirement} for {extra}, " + f"but got {requirement.name}=={got}" + ) + else: + return ( + f"Synapse {VERSION} needs {requirement}, but got {requirement.name}=={got}" + ) + + +def check_requirements(extra: Optional[str] = None) -> None: + """Check Synapse's dependencies are present and correctly versioned. + + If provided, `extra` must be the name of an pacakging extra (e.g. "saml2" in + `pip install matrix-synapse[saml2]`). + + If `extra` is None, this function checks that + - all mandatory dependencies are installed and correctly versioned, and + - each optional dependency that's installed is correctly versioned. + + If `extra` is not None, this function checks that + - the dependencies needed for that extra are installed and correctly versioned. + + :raises DependencyException: if a dependency is missing or incorrectly versioned. + :raises ValueError: if this extra does not exist. + """ + # First work out which dependencies are required, and which are optional. + if extra is None: + dependencies = _generic_dependencies() + elif extra in RUNTIME_EXTRAS: + dependencies = _dependencies_for_extra(extra) + else: + raise ValueError(f"Synapse {VERSION} does not provide the feature '{extra}'") + + deps_unfulfilled = [] + errors = [] + + for (requirement, must_be_installed) in dependencies: + try: + dist: metadata.Distribution = metadata.distribution(requirement.name) + except metadata.PackageNotFoundError: + if must_be_installed: + deps_unfulfilled.append(requirement.name) + errors.append(_not_installed(requirement, extra)) + else: + # We specify prereleases=True to allow prereleases such as RCs. + if not requirement.specifier.contains(dist.version, prereleases=True): + deps_unfulfilled.append(requirement.name) + errors.append(_incorrect_version(requirement, dist.version, extra)) + + if deps_unfulfilled: + for err in errors: + logging.error(err) + + raise DependencyException(deps_unfulfilled) diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index de04f34e4e5c..031880ec39e1 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -20,7 +20,7 @@ import signal import sys from types import FrameType, TracebackType -from typing import NoReturn, Type +from typing import NoReturn, Optional, Type def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: @@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # also catch any other uncaught exceptions before we get that far.) def excepthook( - type_: Type[BaseException], value: BaseException, traceback: TracebackType + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) @@ -123,7 +125,7 @@ def excepthook( sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: + def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon." % signum) sys.exit(0) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 1f18654d47e7..6d4b0b7c5adb 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, Generator, List, TypeVar +from typing import Any, Callable, Generator, List, TypeVar, cast from twisted.internet import defer from twisted.internet.defer import Deferred @@ -174,7 +174,9 @@ def check_yield_points_inner( ) ) changes.append(err) - return getattr(e, "value", None) + # The `StopIteration` or `_DefGen_Return` contains the return value from the + # generator. + return cast(T, e.value) frame = gen.gi_frame diff --git a/synctl b/synctl index 0e54f4847bbd..1ab36949c741 100755 --- a/synctl +++ b/synctl @@ -37,6 +37,13 @@ YELLOW = "\x1b[1;33m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +SYNCTL_CACHE_FACTOR_WARNING = """\ +Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do +one of the following: + - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' + - or set 'caches.global_factor' in the homeserver config. +--------------------------------------------------------------------------------""" + def pid_running(pid): try: @@ -228,6 +235,7 @@ def main(): start_stop_synapse = True if cache_factor: + write(SYNCTL_CACHE_FACTOR_WARNING) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4b53b6d40b31..3e057899236d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -16,6 +16,8 @@ import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -26,8 +28,10 @@ ResourceLimitError, ) from synapse.appservice import ApplicationService +from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -36,10 +40,10 @@ class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store = Mock() - hs.get_datastore = Mock(return_value=self.store) + hs.datastores.main = self.store hs.get_auth_handler().store = self.store self.auth = Auth(hs) @@ -67,7 +71,7 @@ def test_get_user_by_req_user_valid_token(self): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): self.store.get_user_by_access_token = simple_async_mock(None) @@ -105,7 +109,7 @@ def test_get_user_by_req_appservice_valid_token(self): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -124,7 +128,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): from netaddr import IPSet @@ -191,7 +195,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self): request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -238,10 +242,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self): request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) - self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): @@ -271,8 +275,8 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): request.requestHeaders.getRawHeaders = mock_getRawHeaders() failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) - self.assertEquals(failure.value.code, 400) - self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + self.assertEqual(failure.value.code, 400) + self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): self.store.get_user_by_access_token = simple_async_mock( @@ -305,7 +309,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(self.store.insert_client_ip.call_count, 2) + self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( @@ -365,9 +369,9 @@ def test_blocking_mau(self): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) @@ -469,9 +473,9 @@ def test_hs_disabled(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no @@ -484,9 +488,9 @@ def test_hs_disabled_no_server_notices_user(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index b7fc33dc94d1..8c3354ce3c79 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,6 +18,7 @@ from unittest.mock import patch import jsonschema +from frozendict import frozendict from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError @@ -40,7 +41,7 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): invalid_filters = [ @@ -327,6 +328,15 @@ def test_filter_labels(self): self.assertFalse(Filter(self.hs, definition)._check(event)) + # check it works with frozendicts too + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content=frozendict({EventContentFields.LABELS: ["#fun"]}), + ) + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( @@ -364,7 +374,7 @@ def test_filter_presence_match(self): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -388,7 +398,7 @@ def test_filter_presence_no_match(self): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -407,7 +417,7 @@ def test_filter_room_state_match(self): ) results = self.get_success(user_filter.filter_room_state(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -428,7 +438,7 @@ def test_filter_room_state_no_match(self): ) results = self.get_success(user_filter.filter_room_state(events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_rooms(self): definition = { @@ -444,7 +454,7 @@ def test_filter_rooms(self): filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) - self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): @@ -486,7 +496,7 @@ async def events_have_relations(*args, **kwargs): Filter(self.hs, definition)._check_event_relations(events) ) ) - self.assertEquals(filtered_events, events[1:]) + self.assertEqual(filtered_events, events[1:]) def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -497,8 +507,8 @@ def test_add_filter(self): ) ) - self.assertEquals(filter_id, 0) - self.assertEquals( + self.assertEqual(filter_id, 0) + self.assertEqual( user_filter_json, ( self.get_success( @@ -524,6 +534,6 @@ def test_get_filter(self): ) ) - self.assertEquals(filter.get_filter_json(), user_filter_json) + self.assertEqual(filter.get_filter_json(), user_filter_json) - self.assertRegexpMatches(repr(filter), r"") + self.assertRegex(repr(filter), r"") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dcf0110c16e0..483d5463add2 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -8,25 +8,25 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( @@ -39,25 +39,25 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( @@ -70,29 +70,29 @@ def test_allowed_appservice_via_can_requester_do_action(self): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) def test_allowed_via_ratelimit(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # Shouldn't raise @@ -116,7 +116,7 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -162,7 +162,7 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -190,7 +190,7 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): def test_pruning(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -208,7 +208,7 @@ def test_db_user_override(self): """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = "@user:test" requester = create_requester(user_id) @@ -233,7 +233,7 @@ def test_db_user_override(self): def test_multiple_actions(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -246,7 +246,7 @@ def test_multiple_actions(self): limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that, after doing these 3 actions, we can't do any more action without # waiting. @@ -254,7 +254,7 @@ def test_multiple_actions(self): limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting we can do only 1 action. allowed, time_allowed = self.get_success_or_raise( @@ -269,7 +269,7 @@ def test_multiple_actions(self): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10) @@ -277,7 +277,7 @@ def test_multiple_actions(self): self.assertFalse(allowed) # The time allowed doesn't change despite allowed being False because, while we # don't allow 2 actions, we could still do 1. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting a bit more we can do 2 actions. allowed, time_allowed = self.get_success_or_raise( @@ -286,4 +286,4 @@ def test_multiple_actions(self): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 19eb4c79d043..df731eb599cc 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -32,7 +32,7 @@ def test_r30_minimum_usage(self): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -40,7 +40,7 @@ def test_r30_minimum_usage(self): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -51,21 +51,21 @@ def test_r30_minimum_usage(self): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 29 days. The user has now not posted for 29 days. self.reactor.advance(29 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 30 days. self.reactor.advance(ONE_DAY_IN_SECONDS) # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_minimum_usage_using_default_config(self): @@ -84,7 +84,7 @@ def test_r30_minimum_usage_using_default_config(self): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -92,7 +92,7 @@ def test_r30_minimum_usage_using_default_config(self): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -103,14 +103,14 @@ def test_r30_minimum_usage_using_default_config(self): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 27 days. The user has now not posted for 27 days. self.reactor.advance(27 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 28 days. @@ -119,7 +119,7 @@ def test_r30_minimum_usage_using_default_config(self): # The user is now no longer counted in R30. # (This is because the user_ips table has been pruned, which by default # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_user_must_be_retained_for_at_least_a_month(self): @@ -135,7 +135,7 @@ def test_r30_user_must_be_retained_for_at_least_a_month(self): self.helper.send(room_id, "message", tok=access_token) # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) for _ in range(30): @@ -144,14 +144,16 @@ def test_r30_user_must_be_retained_for_at_least_a_month(self): self.helper.send(room_id, "I'm still here", tok=access_token) # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success( + self.hs.get_datastores().main.count_r30_users() + ) self.assertEqual(r30_results, {"all": 0}) self.reactor.advance(ONE_DAY_IN_SECONDS) self.helper.send(room_id, "Still here!", tok=access_token) # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) @@ -196,7 +198,7 @@ def test_r30v2_minimum_usage(self): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the R30 results do not count that user. r30_results = self.get_success(store.count_r30v2_users()) @@ -275,7 +277,7 @@ def test_r30v2_user_must_be_retained_for_at_least_a_month(self): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the user does not contribute to R30 yet. r30_results = self.get_success(store.count_r30v2_users()) @@ -347,7 +349,7 @@ def test_r30v2_returning_dormant_users_not_counted(self): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user does not contribute to R30v2, even though it's been # more than 30 days since registration. diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8fb6687f89eb..1cbb059357fa 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,8 +68,10 @@ def test_single_service_up_txn_sent(self): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made + self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed def test_single_service_down(self): @@ -92,9 +94,11 @@ def test_single_service_down(self): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(0, txn.send.call_count) # txn not sent though - self.assertEquals(0, txn.complete.call_count) # or completed + self.assertEqual(0, txn.send.call_count) # txn not sent though + self.assertEqual(0, txn.complete.call_count) # or completed def test_single_service_up_txn_not_sent(self): # Test: The AS is up and the txn is not sent. A Recoverer is made and @@ -114,12 +118,17 @@ def test_single_service_up_txn_not_sent(self): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made - self.assertEquals(1, self.recoverer.recover.call_count) # and invoked - self.assertEquals(1, len(self.txnctrl.recoverers)) # and stored - self.assertEquals(0, txn.complete.call_count) # txn not completed + self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made + self.assertEqual(1, self.recoverer.recover.call_count) # and invoked + self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored + self.assertEqual(0, txn.complete.call_count) # txn not completed self.store.set_appservice_state.assert_called_once_with( service, ApplicationServiceState.DOWN # service marked as down ) @@ -152,17 +161,17 @@ def take_txn(*args, **kwargs): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(True) txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(1, txn.complete.call_count) # 2 because it needs to get None to know there are no more txns - self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(2, self.store.get_oldest_unsent_txn.call_count) self.callback.assert_called_once_with(self.recoverer) - self.assertEquals(self.recoverer.service, self.service) + self.assertEqual(self.recoverer.service, self.service) def test_recover_retry_txn(self): txn = Mock() @@ -178,26 +187,26 @@ def take_txn(*args, **kwargs): self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(False) txn.complete = simple_async_mock(None) self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(4) - self.assertEquals(2, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(2, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(8) - self.assertEquals(3, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(3, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) - self.assertEquals(1, txn.send.call_count) # new mock reset call count - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) # new mock reset call count + self.assertEqual(1, txn.complete.call_count) self.callback.assert_called_once_with(self.recoverer) @@ -216,7 +225,7 @@ def test_send_single_event_no_queue(self): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -231,12 +240,14 @@ def test_send_single_event_with_queue(self): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): # Tests that each service has its own queue, and that they don't block @@ -261,16 +272,16 @@ def do_send(*args, **kwargs): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): srv_1_defer = defer.Deferred() @@ -288,28 +299,38 @@ def do_send(*args, **kwargs): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -324,15 +345,15 @@ def test_send_single_ephemeral_with_queue(self): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() @@ -343,7 +364,9 @@ def test_send_large_txns_ephemeral(self): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index a72a0103d3b8..694020fbef85 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -63,14 +63,14 @@ def test_sign_minimal(self): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", @@ -97,14 +97,14 @@ def test_sign_message(self): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a176..d00ef24ca806 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -76,7 +76,7 @@ class FakeRequest: @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): - self.assertEquals(getattr(current_context(), "request", None), expected) + self.assertEqual(getattr(current_context(), "request", None), expected) return val def test_verify_json_objects_for_server_awaits_previous_requests(self): @@ -96,7 +96,7 @@ def test_verify_json_objects_for_server_awaits_previous_requests(self): async def first_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_11") + # self.assertEqual(current_context().request.id, "context_11") self.assertEqual(server_name, "server10") self.assertEqual(key_ids, [get_key_id(key1)]) self.assertEqual(minimum_valid_until_ts, 0) @@ -137,7 +137,7 @@ async def first_lookup(): async def second_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_12") + # self.assertEqual(current_context().request.id, "context_12") return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() @@ -179,7 +179,7 @@ def test_verify_json_for_server(self): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -272,7 +272,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -448,7 +448,7 @@ async def get_json(destination, path, **kwargs): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -564,7 +564,7 @@ def test_get_keys_from_perspectives(self): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -683,7 +683,7 @@ def test_get_perspectives_own_key(self): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index ca27388ae8a3..defbc68c18cd 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.user_id = self.register_user("u1", "pass") diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 1dea09e4800d..45e3395b3361 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -395,7 +395,7 @@ def serialize(self, ev, fields): return serialize_event(ev, 1479807801915, only_event_fields=fields) def test_event_fields_works_with_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] ), @@ -403,7 +403,7 @@ def test_event_fields_works_with_keys(self): ) def test_event_fields_works_with_nested_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -416,7 +416,7 @@ def test_event_fields_works_with_nested_keys(self): ) def test_event_fields_works_with_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -429,7 +429,7 @@ def test_event_fields_works_with_dot_keys(self): ) def test_event_fields_works_with_nested_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -445,7 +445,7 @@ def test_event_fields_works_with_nested_dot_keys(self): ) def test_event_fields_nops_with_unknown_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -458,7 +458,7 @@ def test_event_fields_nops_with_unknown_keys(self): ) def test_event_fields_nops_with_non_dict_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -471,7 +471,7 @@ def test_event_fields_nops_with_non_dict_keys(self): ) def test_event_fields_nops_with_array_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -484,7 +484,7 @@ def test_event_fields_nops_with_array_keys(self): ) def test_event_fields_all_fields_if_empty(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( type="foo", diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index e40ef9587417..9f1115dd23b8 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -50,19 +50,19 @@ def test_complexity_simple(self): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) # Artificially raise the complexity - store = self.hs.get_datastore() + store = self.hs.get_datastores().main store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) @@ -149,7 +149,7 @@ def test_join_too_large_once_joined(self): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = ( + self.hs.get_datastores().main.get_current_state_event_counts = ( lambda x: make_awaitable(600) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f0aa8ed9db4d..2873b4d43012 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -64,7 +64,7 @@ def get_destination_room(self, room: str, destination: str = "host2") -> dict: Dictionary of { event_id: str, stream_ordering: int } """ event_id, stream_ordering = self.get_success( - self.hs.get_datastore().db_pool.execute( + self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", None, """ @@ -125,7 +125,7 @@ def test_catch_up_last_successful_stream_ordering_tracking(self): self.pump() lsso_1 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -141,7 +141,7 @@ def test_catch_up_last_successful_stream_ordering_tracking(self): event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] lsso_2 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -216,7 +216,9 @@ def test_catch_up_from_blank_state(self): # let's also clear any backoffs self.get_success( - self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + self.hs.get_datastores().main.set_destination_retry_timings( + "host2", None, 0, 0 + ) ) # bring the remote online and clear the received pdu list @@ -296,13 +298,13 @@ def test_catch_up_loop(self): # destination_rooms should already be populated, but let us pretend that we already # sent (successfully) up to and including event id 2 - event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2)) # also fetch event 5 so we know its last_successful_stream_ordering later - event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering ) ) @@ -359,7 +361,7 @@ def test_catch_up_on_synapse_startup(self): # ASSERT: # - All servers are up to date so none should have outstanding catch-up outstanding_when_successful = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(outstanding_when_successful, []) @@ -370,7 +372,7 @@ def test_catch_up_on_synapse_startup(self): # - Mark zzzerver as being backed-off from now = self.clock.time_msec() self.get_success( - self.hs.get_datastore().set_destination_retry_timings( + self.hs.get_datastores().main.set_destination_retry_timings( "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day ) ) @@ -382,14 +384,14 @@ def test_catch_up_on_synapse_startup(self): # - all remotes are outstanding # - they are returned in batches of 25, in order outstanding_1 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(len(outstanding_1), 25) self.assertEqual(outstanding_1, server_names[0:25]) outstanding_2 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations( + self.hs.get_datastores().main.get_catch_up_outstanding_destinations( outstanding_1[-1] ) ) @@ -457,7 +459,7 @@ def test_not_latest_event(self): ) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering ) ) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py new file mode 100644 index 000000000000..ec8864dafe37 --- /dev/null +++ b/tests/federation/test_federation_client.py @@ -0,0 +1,149 @@ +# Copyright 2022 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +import twisted.web.client +from twisted.internet import defer +from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import FederatingHomeserverTestCase + + +class FederationClientTest(FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + super().prepare(reactor, clock, homeserver) + + # mock out the Agent used by the federation client, which is easier than + # catching the HTTPS connection and do the TLS stuff. + self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) + homeserver.get_federation_http_client().agent = self._mock_agent + + def test_get_room_state(self): + creator = f"@creator:{self.OTHER_SERVER_NAME}" + test_room_id = "!room_id" + + # mock up some events to use in the response. + # In real life, these would have things in `prev_events` and `auth_events`, but that's + # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. + create_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": {"creator": creator}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 500, + } + ) + member_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 600, + } + ) + pl_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.power_levels", + "sender": creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.return_value = defer.succeed( + _mock_response( + { + "pdus": [ + create_event_dict, + member_event_dict, + pl_event_dict, + ], + "auth_chain": [ + create_event_dict, + member_event_dict, + ], + } + ) + ) + + # now fire off the request + state_resp, auth_resp = self.get_success( + self.hs.get_federation_client().get_room_state( + "yet_another_server", + test_room_id, + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + # ... and that the response is correct. + + # the auth_resp should be empty because all the events are also in state + self.assertEqual(auth_resp, []) + + # all of the events should be returned in state_resp, though not necessarily + # in the same order. We just check the type on the assumption that if the type + # is right, so is the rest of the event. + self.assertCountEqual( + [e.type for e in state_resp], + ["m.room.create", "m.room.member", "m.room.power_levels"], + ) + + +def _mock_response(resp: JsonDict): + body = json.dumps(resp).encode("utf-8") + + def deliver_body(p: Protocol): + p.dataReceived(body) + p.connectionLost(Failure(twisted.web.client.ResponseDone())) + + response = mock.Mock( + code=200, + phrase=b"OK", + headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), + length=len(body), + deliverBody=deliver_body, + ) + mock.seal(response) + return response diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b2376e2db925..60e0c31f4384 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -176,7 +176,7 @@ def prepare(self, reactor, clock, hs): def get_users_who_share_room_with_user(user_id): return defer.succeed({"@user2:host2"}) - hs.get_datastore().get_users_who_share_room_with_user = ( + hs.get_datastores().main.get_users_who_share_room_with_user = ( get_users_who_share_room_with_user ) @@ -395,7 +395,7 @@ def test_prune_outbound_device_pokes1(self): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server @@ -445,7 +445,7 @@ def test_prune_outbound_device_pokes2(self): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index d084919ef7d7..30e7e5093a17 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -59,7 +59,7 @@ def test_bad_request(self, query_content): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -125,7 +125,7 @@ def test_without_event_id(self): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual( channel.json_body["room_version"], @@ -157,7 +157,7 @@ def test_needs_to_be_in_room(self): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -189,7 +189,7 @@ def _make_join(self, user_id) -> JsonDict: f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body def test_send_join(self): @@ -209,7 +209,7 @@ def test_send_join(self): f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # we should get complete room state back returned_state = [ @@ -266,7 +266,7 @@ def test_send_join_partial_state(self): f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # expect a reduced room state returned_state = [ diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index a7031a55f28c..c2320ce133aa 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -62,3 +62,35 @@ def test_two_writes(self) -> None: self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) + self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertEqual(parsed_response.servers_in_room, None, parsed_response) + + def test_partial_state(self) -> None: + """Check that the partial_state flag is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = { + "org.matrix.msc3706.partial_state": True, + } + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertTrue(parsed_response.partial_state) + + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 686f42ab48ac..648a01618e8b 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -169,7 +169,7 @@ def check_knock_room_state_against_room_state( self.assertIn(event_type, expected_room_state) # Check the state content matches - self.assertEquals( + self.assertEqual( expected_room_state[event_type]["content"], event["content"] ) @@ -198,7 +198,7 @@ class FederationKnockingTestCase( ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, # therefore disable event signature checks. @@ -256,7 +256,7 @@ def test_room_state_returned_when_knocking(self): RoomVersions.V7.identifier, ), ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -266,11 +266,11 @@ def test_room_state_returned_when_knocking(self): knock_event = channel.json_body["event"] # Check that the event has things we expect in it - self.assertEquals(knock_event["room_id"], room_id) - self.assertEquals(knock_event["sender"], fake_knocking_user_id) - self.assertEquals(knock_event["state_key"], fake_knocking_user_id) - self.assertEquals(knock_event["type"], EventTypes.Member) - self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK) + self.assertEqual(knock_event["room_id"], room_id) + self.assertEqual(knock_event["sender"], fake_knocking_user_id) + self.assertEqual(knock_event["state_key"], fake_knocking_user_id) + self.assertEqual(knock_event["type"], EventTypes.Member) + self.assertEqual(knock_event["content"]["membership"], Membership.KNOCK) # Turn the event json dict into a proper event. # We won't sign it properly, but that's OK as we stub out event auth in `prepare` @@ -294,7 +294,7 @@ def test_room_state_returned_when_knocking(self): % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index eb62addda8c6..5f001c33b05e 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -13,7 +13,7 @@ # limitations under the License. from tests import unittest -from tests.unittest import override_config +from tests.unittest import DEBUG, override_config class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @@ -26,7 +26,7 @@ def test_blocked_public_room_list_over_federation(self): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(403, channel.code) + self.assertEqual(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): @@ -37,4 +37,22 @@ def test_open_public_room_list_over_federation(self): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) + + @DEBUG + def test_edu_debugging_doesnt_explode(self): + """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. + + Remove this when we strip out issue_8631_logger. + """ + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn_id_1234/", + content={ + "edus": [ + {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + ], + "pdus": [], + }, + ) + self.assertEqual(200, channel.code) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index fe57ff267119..072e6bbcdd6e 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -38,7 +46,7 @@ def setUp(self): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore.return_value = self.mock_store + hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( @@ -139,8 +147,8 @@ def test_query_room_alias_exists(self): self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str ) - self.assertEquals(result.room_id, room_id) - self.assertEquals(result.servers, servers) + self.assertEqual(result.room_id, room_id) + self.assertEqual(result.servers, servers) def test_get_3pe_protocols_no_appservices(self): self.mock_store.get_app_services.return_value = [] @@ -148,7 +156,7 @@ def test_get_3pe_protocols_no_appservices(self): defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_no_protocols(self): service = self._mkservice(False, []) @@ -157,7 +165,7 @@ def test_get_3pe_protocols_no_protocols(self): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_protocol_no_response(self): service = self._mkservice(False, ["my-protocol"]) @@ -169,7 +177,7 @@ def test_get_3pe_protocols_protocol_no_response(self): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_select_one_protocol(self): service = self._mkservice(False, ["my-protocol"]) @@ -183,7 +191,7 @@ def test_get_3pe_protocols_select_one_protocol(self): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -199,7 +207,7 @@ def test_get_3pe_protocols_one_protocol(self): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -214,7 +222,7 @@ def test_get_3pe_protocols_multiple_protocol(self): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_called() - self.assertEquals( + self.assertEqual( response, { "my-protocol": {"x-protocol-data": 42, "instances": []}, @@ -246,7 +254,7 @@ async def get_3pe_protocol(service, unusedProtocol): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) # It's expected that the second service's data doesn't appear in the response - self.assertEquals( + self.assertEqual( response, { "my-protocol": { @@ -355,7 +363,9 @@ def prepare(self, reactor, clock, hs): # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) # A user on the homeserver. self.local_user_device_id = "local_device" @@ -426,7 +436,14 @@ def test_application_services_receive_local_to_device(self): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -494,7 +511,7 @@ def test_application_services_receive_bursts_of_to_device(self): # Create a fake device per message. We can't send to-device messages to # a device that doesn't exist. self.get_success( - self.hs.get_datastore().db_pool.simple_insert_many( + self.hs.get_datastores().main.db_pool.simple_insert_many( desc="test_application_services_receive_burst_of_to_device", table="devices", keys=("user_id", "device_id"), @@ -510,7 +527,7 @@ def test_application_services_receive_bursts_of_to_device(self): # Seed the device_inbox table with our fake messages self.get_success( - self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {}) ) # Now have local_user send a final to-device message to exclusive_as_user. All unsent @@ -538,7 +555,7 @@ def test_application_services_receive_bursts_of_to_device(self): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -580,3 +597,174 @@ def _register_application_service( self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastores().main.add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastores().main + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastores().main.store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args, _last_kwargs = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615c62..0c6e55e72592 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -129,7 +129,7 @@ def test_mau_limits_disabled(self): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) @@ -140,7 +140,7 @@ def test_mau_limits_exceeded_large(self): ResourceLimitError, ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -156,7 +156,7 @@ def test_mau_limits_parity(self): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) @@ -175,7 +175,7 @@ def test_mau_limits_parity(self): ) # If in monthly active cohort - self.hs.get_datastore().user_last_seen_monthly_active = Mock( + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( @@ -192,7 +192,7 @@ def test_mau_limits_parity(self): def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception @@ -202,7 +202,7 @@ def test_mau_limits_not_exceeded(self): ) ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) self.get_success( diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff894343..a2672288465a 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -77,7 +77,7 @@ def test_map_cas_user_to_user(self): def test_map_cas_user_to_existing_user(self): """Existing users can log in with CAS account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 01096a1581e3..ddda36c5a93b 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 43031e07ea77..683677fd0770 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def prepare(self, reactor, clock, hs): @@ -263,7 +263,7 @@ def make_homeserver(self, reactor, clock): self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_dehydrate_and_rehydrate_device(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 0ea4e753e2b2..6e403a87c5d0 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,7 +46,7 @@ def register_query_handler(query_type, handler): self.handler = hs.get_directory_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.my_room = RoomAlias.from_string("#my-room:test") self.your_room = RoomAlias.from_string("#your-room:test") @@ -63,7 +63,7 @@ def test_get_local_association(self): result = self.get_success(self.handler.get_association(self.my_room)) - self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) + self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): self.mock_federation.make_query.return_value = make_awaitable( @@ -72,7 +72,7 @@ def test_get_remote_association(self): result = self.get_success(self.handler.get_association(self.remote_room)) - self.assertEquals( + self.assertEqual( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result ) self.mock_federation.make_query.assert_called_with( @@ -94,7 +94,7 @@ def test_incoming_fed_query(self): self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + self.assertEqual({"room_id": "!8765asdf:test", "servers": ["test"]}, response) class TestCreateAlias(unittest.HomeserverTestCase): @@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -224,7 +224,7 @@ def test_delete_alias_creator(self): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -243,7 +243,7 @@ def test_delete_alias_admin(self): create_requester(self.admin_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -269,7 +269,7 @@ def test_delete_alias_sufficient_power(self): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -411,7 +411,7 @@ def test_denied(self): b"directory/room/%23test%3Atest", {"room_id": room_id}, ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) @@ -421,7 +421,7 @@ def test_allowed(self): b"directory/room/%23unofficial_test%3Atest", {"room_id": room_id}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_denied_during_creation(self): """A room alias that is not allowed should be rejected during creation.""" @@ -443,8 +443,8 @@ def test_allowed_during_creation(self): "GET", b"directory/room/%23unofficial_test%3Atest", ) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(channel.json_body["room_id"], room_id) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(channel.json_body["room_id"], room_id) class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): @@ -572,7 +572,7 @@ def prepare(self, reactor, clock, hs): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() @@ -585,7 +585,7 @@ def test_disabling_room_list(self): # Room list is enabled so we should get some results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) self.room_list_handler.enable_room_list_search = False @@ -593,7 +593,7 @@ def test_disabling_room_list(self): # Room list disabled so we should get no results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms @@ -601,4 +601,4 @@ def test_disabling_room_list(self): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 734ed84d78d0..9338ab92e98e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): self.handler = hs.get_e2e_keys_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 496b58172643..e8b4e39d1a32 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 5816295d8b97..f4f7ab48458e 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,7 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token( + self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a552d8182e09..e8418b6638e4 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -856,7 +856,7 @@ def test_map_userinfo_to_user(self): auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user3 = UserID.from_string("@test_user_3:test") self.get_success( store.register_user(user_id=user3.to_string(), password_hash=None) @@ -872,7 +872,7 @@ def test_map_userinfo_to_user(self): @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) @@ -996,7 +996,7 @@ def test_map_userinfo_to_user_retries(self): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a65e3..49d832de814d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,7 @@ def parse_config(self): def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -122,7 +122,7 @@ def __init__(self, config, api: ModuleApi): auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ def test_username(self): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,30 +782,14 @@ def test_username_uia(self): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() - - # Initiate the UIA flow. - username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) - # Check that the callback hasn't been called yet. - m.assert_not_called() + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -817,6 +806,56 @@ def test_3pid_allowed(self): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,23 +916,47 @@ def _test_3pid_allowed(self, username: str, registration: bool): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 671dc7d083c8..6ddec9ecf1fe 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_offline_to_online(self): wheel_timer = Mock() @@ -61,11 +61,11 @@ def test_offline_to_online(self): self.assertTrue(persist_and_notify) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -104,11 +104,11 @@ def test_online_to_online(self): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -149,11 +149,11 @@ def test_online_to_online_last_active_noop(self): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -191,11 +191,11 @@ def test_online_to_online_last_active(self): self.assertTrue(persist_and_notify) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 2) + self.assertEqual(wheel_timer.insert.call_count, 2) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -227,10 +227,10 @@ def test_remote_ping_timer(self): self.assertFalse(persist_and_notify) self.assertFalse(federation_ping) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -259,10 +259,10 @@ def test_online_to_offline(self): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 0) + self.assertEqual(wheel_timer.insert.call_count, 0) def test_online_to_idle(self): wheel_timer = Mock() @@ -281,12 +281,12 @@ def test_online_to_idle(self): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -357,8 +357,8 @@ def test_idle_timer(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) + self.assertEqual(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -380,8 +380,8 @@ def test_busy_no_idle(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.BUSY) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.BUSY) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" @@ -399,8 +399,8 @@ def test_sync_timeout(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" @@ -420,8 +420,8 @@ def test_sync_online(self): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.ONLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.ONLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" @@ -440,7 +440,7 @@ def test_federation_ping(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -477,8 +477,8 @@ def test_federation_timeout(self): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" @@ -497,7 +497,7 @@ def test_last_active(self): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) class PresenceHandlerTestCase(unittest.HomeserverTestCase): @@ -891,7 +891,7 @@ def prepare(self, reactor, clock, hs): # self.event_builder_for_2 = EventBuilderFactory(hs) # self.event_builder_for_2.hostname = "test2" - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60235e569987..972cbac6e496 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,7 +48,7 @@ def register_query_handler(query_type, handler): return hs def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") @@ -65,7 +65,7 @@ def test_get_my_name(self): displayname = self.get_success(self.handler.get_displayname(self.frank)) - self.assertEquals("Frank", displayname) + self.assertEqual("Frank", displayname) def test_set_my_name(self): self.get_success( @@ -74,7 +74,7 @@ def test_set_my_name(self): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -90,7 +90,7 @@ def test_set_my_name(self): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -118,7 +118,7 @@ def test_set_my_name_if_disabled(self): self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -150,7 +150,7 @@ def test_get_other_name(self): displayname = self.get_success(self.handler.get_displayname(self.alice)) - self.assertEquals(displayname, "Alice") + self.assertEqual(displayname, "Alice") self.mock_federation.make_query.assert_called_with( destination="remote", query_type="profile", @@ -172,7 +172,7 @@ def test_incoming_fed_query(self): ) ) - self.assertEquals({"displayname": "Caroline"}, response) + self.assertEqual({"displayname": "Caroline"}, response) def test_get_my_avatar(self): self.get_success( @@ -182,7 +182,7 @@ def test_get_my_avatar(self): ) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) - self.assertEquals("http://my.server/me.png", avatar_url) + self.assertEqual("http://my.server/me.png", avatar_url) def test_set_my_avatar(self): self.get_success( @@ -193,7 +193,7 @@ def test_set_my_avatar(self): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) @@ -207,7 +207,7 @@ def test_set_my_avatar(self): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -235,7 +235,7 @@ def test_set_my_avatar_if_disabled(self): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -325,7 +325,7 @@ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5de89c873b94..5081b97573a0 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -314,4 +314,4 @@ def _test_filters_hidden( ): """Tests that the _filter_out_hidden returns the expected output""" filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") - self.assertEquals(filtered_events, expected_output) + self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index cd6f2c77aea4..45fd30cf4369 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -154,7 +154,7 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.lots_of_users = 100 self.small_number_of_users = 1 @@ -167,12 +167,12 @@ def test_user_is_created_and_logged_in_if_doesnt_exist(self): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, frank.localpart, "Frankie") ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) def test_if_user_exists(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( store.register_user(user_id=frank.to_string(), password_hash=None) @@ -183,7 +183,7 @@ def test_if_user_exists(self): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, local_part, None) ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) @@ -760,7 +760,7 @@ async def lookup_room_alias(*args, **kwargs): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) def test_auto_create_auto_join_remote_room(self): diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 51b22d299812..cff07a8973b3 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -157,35 +157,6 @@ def _add_child( state_key=room_id, ) - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - def _assert_hierarchy( self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: @@ -251,11 +222,9 @@ def _poke_fed_invite(self, room_id: str, from_user: str) -> None: def test_simple_space(self): """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -271,12 +240,6 @@ def test_large_space(self): self._add_child(self.space, room, self.token) rooms.append(room) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The spaces result should have the space and the first 50 rooms in it, - # along with the links from space -> room for those 50 rooms. - expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]] - self._assert_rooms(result, expected) - # The result should have the space and the rooms in it, along with the links # from space -> room. expected = [(self.space, rooms)] + [(room, []) for room in rooms] @@ -300,10 +263,7 @@ def test_visibility(self): token2 = self.login("user2", "pass") # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -316,7 +276,6 @@ def test_visibility(self): body={"join_rule": JoinRules.INVITE}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -329,9 +288,6 @@ def test_visibility(self): body={"history_visibility": HistoryVisibility.WORLD_READABLE}, tok=self.token, ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -344,7 +300,6 @@ def test_visibility(self): body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -353,19 +308,12 @@ def test_visibility(self): # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) self.get_failure( self.handler.get_room_hierarchy( create_requester(user2), "#not-a-space:" + self.hs.hostname @@ -496,7 +444,6 @@ def test_filtering(self): # Join the space. self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [ ( self.space, @@ -520,7 +467,6 @@ def test_filtering(self): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) @@ -554,8 +500,6 @@ def test_complex_space(self): self._add_child(subspace, self.room, token=self.token) self._add_child(subspace, room2, self.token) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should include each room a single time and each link. expected = [ (self.space, [self.room, room2, subspace]), @@ -563,7 +507,6 @@ def test_complex_space(self): (subspace, [subroom, self.room, room2]), (subroom, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -715,7 +658,7 @@ def test_max_depth(self): def test_unknown_room_version(self): """ - If an room with an unknown room version is encountered it should not cause + If a room with an unknown room version is encountered it should not cause the entire summary to skip. """ # Poke the database and update the room version to an unknown one. @@ -727,11 +670,12 @@ def test_unknown_room_version(self): desc="updated-room-version", ) ) + # Invalidate method so that it returns the currently updated version + # instead of the cached version. + self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -775,41 +719,18 @@ def test_fed_complex(self): "world_readable": True, } - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), (subspace, [subroom]), (subroom, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -881,7 +802,7 @@ def test_fed_filtering(self): "room_id": restricted_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], + "allowed_room_ids": [], }, ), ( @@ -890,7 +811,7 @@ def test_fed_filtering(self): "room_id": restricted_accessible_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], + "allowed_room_ids": [self.room], }, ), ( @@ -929,30 +850,12 @@ def test_fed_filtering(self): ], ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), @@ -976,7 +879,6 @@ async def summarize_remote_room_hierarchy(_self, room, suggested_only): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -1010,31 +912,17 @@ def test_fed_invited(self): }, ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, fed_room]), (self.room, ()), (fed_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e3c2..23941abed852 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -142,7 +142,7 @@ def test_map_saml_response_to_user(self): @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): """Existing users can log in with SAML account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) @@ -217,7 +217,7 @@ def test_map_saml_response_to_user_retries(self): sso_handler.render_error = Mock(return_value=None) # register a user to occupy the first-choice MXID - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 56207f4db6f6..ecd78fa369a8 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() def _add_background_updates(self): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 07a760e91aed..3aedc0767b0a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -69,7 +69,7 @@ def test_wait_for_sync_for_user_auth_blocking(self): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False @@ -80,7 +80,7 @@ def test_wait_for_sync_for_user_auth_blocking(self): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_unknown_room_version(self): """ @@ -122,7 +122,7 @@ def test_unknown_room_version(self): b"{}", tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # The rooms should appear in the sync response. result = self.get_success( @@ -248,7 +248,7 @@ def test_ban_wins_race_with_join(self): # the prev_events used when creating the join event, such that the ban does not # precede the join. mocked_get_prev_events = patch.object( - self.hs.get_datastore(), + self.hs.get_datastores().main, "get_prev_events_for_room", new_callable=MagicMock, return_value=make_awaitable([last_room_creation_event_id]), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 000f9b9fde2b..f91a80b9fa57 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -91,7 +91,7 @@ def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.typing - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( return_value=defer.succeed(None) ) @@ -156,7 +156,7 @@ async def get_users_in_room(room_id): def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -169,13 +169,13 @@ def test_started_typing_local(self): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -220,7 +220,7 @@ def test_started_typing_remote_send(self): def test_started_typing_remote_recv(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -239,13 +239,13 @@ def test_started_typing_remote_recv(self): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -259,7 +259,7 @@ def test_started_typing_remote_recv(self): def test_started_typing_remote_recv_not_in_room(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -278,7 +278,7 @@ def test_started_typing_remote_recv_not_in_room(self): self.on_new_event.assert_not_called() - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -288,8 +288,8 @@ def test_started_typing_remote_recv_not_in_room(self): is_guest=False, ) ) - self.assertEquals(events[0], []) - self.assertEquals(events[1], 0) + self.assertEqual(events[0], []) + self.assertEqual(events[1], 0) @override_config({"send_federation": True}) def test_stopped_typing(self): @@ -302,7 +302,7 @@ def test_stopped_typing(self): self.handler._member_typing_until[member] = 1002000 self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.stopped_typing( @@ -332,13 +332,13 @@ def test_stopped_typing(self): try_trailing_slash_on_400=True, ) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -346,7 +346,7 @@ def test_stopped_typing(self): def test_typing_timeout(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -360,7 +360,7 @@ def test_typing_timeout(self): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -370,7 +370,7 @@ def test_typing_timeout(self): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -385,7 +385,7 @@ def test_typing_timeout(self): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -395,7 +395,7 @@ def test_typing_timeout(self): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -414,7 +414,7 @@ def test_typing_timeout(self): self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -424,7 +424,7 @@ def test_typing_timeout(self): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 482c90ef68c5..92012cd6f7e6 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -77,7 +77,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() @@ -1042,7 +1042,7 @@ def test_disabling_room_list(self) -> None: b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing @@ -1053,5 +1053,5 @@ def test_disabling_room_list(self) -> None: b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index c49be33b9f7c..77ce8432ac53 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -65,9 +65,9 @@ def do_lookup(): servers = self.successResultOf(test_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): @@ -88,8 +88,8 @@ def test_from_cache_expired_and_dns_fail(self): dns_client_mock.lookupService.assert_called_once_with(service_name) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_from_cache(self): @@ -114,8 +114,8 @@ def test_from_cache(self): self.assertFalse(dns_client_mock.lookupService.called) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_empty_cache(self): @@ -144,8 +144,8 @@ def test_name_error(self): servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) + self.assertEqual(len(servers), 0) + self.assertEqual(len(cache), 0) def test_disabled_service(self): """ @@ -201,6 +201,6 @@ def test_non_srv_answer(self): servers = self.successResultOf(resolve_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, b"host") + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, b"host") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index d16cd141a750..c3f20f9692ac 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b64584..7a3b0d675592 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -102,13 +102,13 @@ def prepare(self, reactor, clock, hs): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token) + self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. self.get_success( - hs.get_datastore().user_add_threepid( + hs.get_datastores().main.user_add_threepid( self.user_id, "email", "a@example.com", 0, 0 ) ) @@ -128,7 +128,7 @@ def prepare(self, reactor, clock, hs): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -375,7 +375,7 @@ def test_no_email_sent_after_removed(self): # check that the pusher for that email address has been deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -388,14 +388,14 @@ def test_remove_unlinked_pushers_background_job(self): # This resembles the old behaviour, which the background update below is intended # to clean up. self.get_success( - self.hs.get_datastore().user_delete_threepid( + self.hs.get_datastores().main.user_delete_threepid( self.user_id, "email", "a@example.com" ) ) # Run the "remove_deleted_email_pushers" background job self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="background_updates", values={ "update_name": "remove_deleted_email_pushers", @@ -406,14 +406,14 @@ def test_remove_unlinked_pushers_background_job(self): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.hs.get_datastore().db_pool.updates._all_done = False + self.hs.get_datastores().main.db_pool.updates._all_done = False # Now let's actually drive the updates to completion self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -428,7 +428,7 @@ def _check_for_mail(self) -> Tuple[Sequence, Dict]: """ # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -439,7 +439,7 @@ def _check_for_mail(self) -> Tuple[Sequence, Dict]: # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -458,7 +458,7 @@ def _check_for_mail(self) -> Tuple[Sequence, Dict]: # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c068d329a98b..c284beb37ce0 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -62,7 +62,7 @@ def test_invalid_configuration(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -108,7 +108,7 @@ def test_sends_http(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -138,7 +138,7 @@ def test_sends_http(self): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -149,7 +149,7 @@ def test_sends_http(self): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -170,7 +170,7 @@ def test_sends_http(self): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ def test_sends_http(self): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -224,7 +224,7 @@ def test_sends_high_priority_for_encrypted(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -344,7 +344,7 @@ def test_sends_high_priority_for_one_to_one_only(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -430,7 +430,7 @@ def test_sends_high_priority_for_mention(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -507,7 +507,7 @@ def test_sends_high_priority_for_atroom(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -571,9 +571,7 @@ def test_push_unread_count_group_by_room(self): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 - ) + self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self): @@ -585,11 +583,9 @@ def test_push_unread_count_message_count(self): # Carry out our option-value specific test # - # We're counting every unread message, so there should now be 4 since the + # We're counting every unread message, so there should now be 3 since the # last read receipt - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 - ) + self._check_push_attempt(6, 3) def _test_push_unread_count(self): """ @@ -597,8 +593,9 @@ def _test_push_unread_count(self): Note that: * Sending messages will cause push notifications to go out to relevant users - * Sending a read receipt will cause a "badge update" notification to go out to - the user that sent the receipt + * Sending a read receipt will cause the HTTP pusher to check whether the unread + count has changed since the last push notification. If so, a "badge update" + notification goes out to the user that sent the receipt """ # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -616,7 +613,7 @@ def _test_push_unread_count(self): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -642,24 +639,74 @@ def _test_push_unread_count(self): # position in the room. We'll set the read position to this event in a moment first_message_event_id = response["event_id"] - # Advance time a bit (so the pusher will register something has happened) and - # make the push succeed - self.push_attempts[0][0].callback({}) + expected_push_attempts = 1 + self._check_push_attempt(expected_push_attempts, 0) + + self._send_read_request(access_token, first_message_event_id, room_id) + + # Unread count has not changed. Therefore, ensure that read request does not + # trigger a push notification. + self.assertEqual(len(self.push_attempts), 1) + + # Send another message + response2 = self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + second_message_event_id = response2["event_id"] + + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 1) + + self._send_read_request(access_token, second_message_event_id, room_id) + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 0) + + # If we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room. Otherwise, it + # should. Therefore, the last call to _check_push_attempt is done in the + # caller method. + self.helper.send(room_id, body="Hello?", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + def _advance_time_and_make_push_succeed(self, expected_push_attempts): self.pump() + self.push_attempts[expected_push_attempts - 1][0].callback({}) + def _check_push_attempt( + self, expected_push_attempts: int, expected_unread_count_last_push: int + ) -> None: + """ + Makes sure that the last expected push attempt succeeds and checks whether + it contains the expected unread count. + """ + self._advance_time_and_make_push_succeed(expected_push_attempts) # Check our push made it - self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(len(self.push_attempts), expected_push_attempts) + _, push_url, push_body = self.push_attempts[expected_push_attempts - 1] self.assertEqual( - self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + push_url, + "http://example.com/_matrix/push/v1/notify", ) - # Check that the unread count for the room is 0 # # The unread count is zero as the user has no read receipt in the room yet self.assertEqual( - self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + push_body["notification"]["counts"]["unread"], + expected_unread_count_last_push, ) + def _send_read_request(self, access_token, message_event_id, room_id): # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -667,56 +714,8 @@ def _test_push_unread_count(self): # count goes down channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id), {}, access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time and make the push succeed - self.push_attempts[1][0].callback({}) - self.pump() - - # Unread count is still zero as we've read the only message in the room - self.assertEqual(len(self.push_attempts), 2) - self.assertEqual( - self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 - ) - - # Send another message - self.helper.send( - room_id, body="How's the weather today?", tok=other_access_token - ) - - # Advance time and make the push succeed - self.push_attempts[2][0].callback({}) - self.pump() - - # This push should contain an unread count of 1 as there's now been one - # message since our last read receipt - self.assertEqual(len(self.push_attempts), 3) - self.assertEqual( - self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 - ) - - # Since we're grouping by room, sending more messages shouldn't increase the - # unread count, as they're all being sent in the same room - self.helper.send(room_id, body="Hello?", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[3][0].callback({}) - - self.helper.send(room_id, body="Hello??", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[4][0].callback({}) - - self.helper.send(room_id, body="HELLO???", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[5][0].callback({}) - - self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 32501a374b5b..4137d7da4c35 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -14,6 +14,8 @@ from typing import Any, Dict +import frozendict + from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator @@ -191,6 +193,13 @@ def test_event_match_non_body(self): "pattern should only match at the start/end of the value", ) + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "FoobaZ"}), + "patterns should match on frozendicts", + ) + # wildcards should match condition = { "kind": "event_match", diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9fc50f885227..a7a05a564fe9 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -68,7 +68,7 @@ def prepare(self, reactor, clock, hs): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool # Normally we'd pass in the handler to `setup_test_homeserver`, which would # eventually hit "Install @cache_in_self attributes" in tests/utils.py. @@ -233,7 +233,7 @@ def setUp(self): # We may have an attempt to connect to redis for the external cache already. self.connect_any_redis_attempts() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -332,7 +332,7 @@ def make_worker_hs( lambda: self._handle_http_replication_attempt(worker_hs, port), ) - store = worker_hs.get_datastore() + store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool # Set up TCP replication between master and the new worker if we don't diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 83e89383f64f..85be79d19d48 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -30,8 +30,8 @@ def prepare(self, reactor, clock, hs): self.reconnect() - self.master_store = hs.get_datastore() - self.slaved_store = self.worker_hs.get_datastore() + self.master_store = hs.get_datastores().main + self.slaved_store = self.worker_hs.get_datastores().main self.storage = hs.get_storage() def replicate(self): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1f4e09c02b6f..110d11709b73 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -59,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEquals + # whether lists of events match using assertEqual self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super().setUp() diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index cdd052001b6d..50fbff5f324f 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,7 +23,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): """Test replication with many room account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] @@ -69,7 +69,7 @@ def test_update_function_room_account_data_limit(self): def test_update_function_global_account_data_limit(self): """Test replication with many global account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f198a9488746..f9d5da723cce 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -136,7 +136,7 @@ def test_update_function_huge_state_change(self): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events = [ @@ -291,7 +291,7 @@ def test_update_function_state_row_limit(self): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events: List[EventBase] = [] diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 38e292c1ab67..eb0011784518 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -32,7 +32,7 @@ def test_receipt(self): # tell the master to send a new receipt self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) @@ -56,7 +56,7 @@ def test_receipt(self): self.test_handler.on_rdata.reset_mock() self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92a5b53e11e7..ba1a63c0d667 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -204,7 +204,7 @@ def test_send_typing_sharded(self): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 4094a75f363c..8f4f6688ce86 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -50,7 +50,7 @@ def _create_pusher_and_send_msg(self, localpart): # Register a pusher user_dict = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_dict.token_id diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c9f4..5f142e84c359 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -47,7 +47,7 @@ def prepare(self, reactor, clock, hs): self.other_access_token = self.login("otheruser", "pass") self.room_creator = self.hs.get_room_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def default_config(self): conf = super().default_config() @@ -99,7 +99,7 @@ def test_basic(self): persisted_on_1 = False persisted_on_2 = False - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") @@ -166,7 +166,7 @@ def test_vector_clock_token(self): user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create two room on the different workers. self._create_room(room_id1, user_id, access_token) @@ -194,7 +194,7 @@ def test_vector_clock_token(self): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() + worker_store2 = worker_hs2.get_datastores().main assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) actx = worker_store2._stream_id_gen.get_next() diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 1e3fe9c62c00..fb36aa994090 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 71068d16cd18..929bbdc37da0 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 86aff7575c9a..0d47dd0aff73 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8513b1d2df53..8354250ec21e 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 23da0ad73679..95282f078e77 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -50,7 +50,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -465,7 +465,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1909,7 +1909,7 @@ def test_join_public_room(self) -> None: "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1957,7 +1957,7 @@ def test_join_private_room_if_member(self) -> None: "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1980,7 +1980,7 @@ def test_join_private_room_if_member(self) -> None: "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2010,7 +2010,7 @@ def test_join_private_room_if_owner(self) -> None: "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2044,7 +2044,7 @@ def test_context_as_non_admin(self) -> None: % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2074,8 +2074,8 @@ def test_context_as_admin(self) -> None: % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEquals( + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f766bc..2c855bff99c7 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() self.server_notices_manager = self.hs.get_server_notices_manager() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 272637e96575..a60ea0a56305 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -410,7 +410,7 @@ def test_register_mau_limit_reached(self) -> None: even if the MAU limit is reached. """ handler = self.hs.get_registration_handler() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Set monthly active users to the limit store.get_monthly_active_count = Mock( @@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() # create users and get access tokens @@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 89d85b0a1701..6c4462e74a53 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1,6 +1,4 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,16 +15,22 @@ import os import re from email.parser import Parser -from typing import Optional +from typing import Dict, List, Optional +from unittest.mock import Mock import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -73,7 +77,7 @@ async def sendmail( return hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): @@ -100,7 +104,7 @@ def test_basic_password_reset(self): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -139,7 +143,7 @@ def reset(ip): client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -189,7 +193,7 @@ def test_basic_password_reset_canonicalise_email(self): client_secret = "foobar" session_id = self._request_token(email_passwort_reset, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -226,7 +230,7 @@ def test_cant_reset_password_without_clicking_link(self): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to reset password without clicking the link self._reset_password(new_password, session_id, client_secret, expected_code=401) @@ -318,7 +322,7 @@ def _validate_token(self, link): shorthand=False, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -333,7 +337,7 @@ def _validate_token(self, link): shorthand=False, content_is_form=True, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -372,7 +376,7 @@ def _reset_password( }, }, ) - self.assertEquals(expected_code, channel.code, channel.result) + self.assertEqual(expected_code, channel.code, channel.result) class DeactivateTestCase(unittest.HomeserverTestCase): @@ -394,7 +398,7 @@ def test_deactivate_account(self): self.deactivate(user_id, tok) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user has been marked as deactivated. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) @@ -405,7 +409,7 @@ def test_deactivate_account(self): def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main inviter_id = self.register_user("inviter", "test") inviter_tok = self.login("inviter", "test") @@ -486,8 +490,9 @@ def test_GET_whoami(self): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) @@ -505,8 +510,9 @@ def test_GET_whoami_guests(self): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": True, + "is_guest": True, }, ) @@ -521,15 +527,16 @@ def test_GET_whoami_appservices(self): namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) whoami = self._whoami(as_token) self.assertEqual( whoami, { "user_id": user_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) self.assertFalse(hasattr(whoami, "device_id")) @@ -579,7 +586,7 @@ async def sendmail( return self.hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") self.user_id_tok = self.login("kermit", "test") @@ -669,7 +676,7 @@ def test_add_email_if_disabled(self): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -773,7 +780,7 @@ def test_cant_add_email_without_clicking_link(self): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to add email without clicking the link channel = self.make_request( @@ -974,7 +981,7 @@ def _validate_token(self, link): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -1003,7 +1010,7 @@ def _add_email(self, request_email, expected_email): client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -1037,3 +1044,197 @@ def _add_email(self, request_email, expected_email): threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} self.assertIn(expected_email, threepids) + + +class AccountStatusTestCase(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["experimental_features"] = {"msc3720_enabled": True} + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.requester = self.register_user("requester", "password") + self.requester_tok = self.login("requester", "password") + self.server_name = homeserver.config.server.server_name + + def test_missing_mxid(self): + """Tests that not providing any MXID raises an error.""" + self._test_status( + users=None, + expected_status_code=400, + expected_errcode=Codes.MISSING_PARAM, + ) + + def test_invalid_mxid(self): + """Tests that providing an invalid MXID raises an error.""" + self._test_status( + users=["bad:test"], + expected_status_code=400, + expected_errcode=Codes.INVALID_PARAM, + ) + + def test_local_user_not_exists(self): + """Tests that the account status endpoints correctly reports that a user doesn't + exist. + """ + user = "@unknown:" + self.hs.config.server.server_name + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_exists(self): + """Tests that the account status endpoint correctly reports that a user doesn't + exist. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_deactivated(self): + """Tests that the account status endpoint correctly reports a deactivated user.""" + user = self.register_user("someuser", "password") + self.get_success( + self.hs.get_datastores().main.set_user_deactivated_status( + user, deactivated=True + ) + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": True, + }, + }, + expected_failures=[], + ) + + def test_mixed_local_and_remote_users(self): + """Tests that if some users are remote the account status endpoint correctly + merges the remote responses with the local result. + """ + # We use 3 users: one doesn't exist but belongs on the local homeserver, one is + # deactivated and belongs on one remote homeserver, and one belongs to another + # remote homeserver that didn't return any result (the federation code should + # mark that user as a failure). + users = [ + "@unknown:" + self.hs.config.server.server_name, + "@deactivated:remote", + "@failed:otherremote", + "@bad:badremote", + ] + + async def post_json(destination, path, data, *a, **kwa): + if destination == "remote": + return { + "account_statuses": { + users[1]: { + "exists": True, + "deactivated": True, + }, + } + } + if destination == "otherremote": + return {} + if destination == "badremote": + # badremote tries to overwrite the status of a user that doesn't belong + # to it (i.e. users[1]) with false data, which Synapse is expected to + # ignore. + return { + "account_statuses": { + users[3]: { + "exists": False, + }, + users[1]: { + "exists": False, + }, + } + } + + # Register a mock that will return the expected result depending on the remote. + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + + # Check that we've got the correct response from the client-side endpoint. + self._test_status( + users=users, + expected_statuses={ + users[0]: { + "exists": False, + }, + users[1]: { + "exists": True, + "deactivated": True, + }, + users[3]: { + "exists": False, + }, + }, + expected_failures=[users[2]], + ) + + def _test_status( + self, + users: Optional[List[str]], + expected_status_code: int = 200, + expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, + expected_failures: Optional[List[str]] = None, + expected_errcode: Optional[str] = None, + ): + """Send a request to the account status endpoint and check that the response + matches with what's expected. + + Args: + users: The account(s) to request the status of, if any. If set to None, no + `user_id` query parameter will be included in the request. + expected_status_code: The expected HTTP status code. + expected_statuses: The expected account statuses, if any. + expected_failures: The expected failures, if any. + expected_errcode: The expected Matrix error code, if any. + """ + content = {} + if users is not None: + content["user_ids"] = users + + channel = self.make_request( + method="POST", + path=self.url, + content=content, + access_token=self.requester_tok, + ) + + self.assertEqual(channel.code, expected_status_code) + + if expected_statuses is not None: + self.assertEqual(channel.json_body["account_statuses"], expected_statuses) + + if expected_failures is not None: + self.assertEqual(channel.json_body["failures"], expected_failures) + + if expected_errcode is not None: + self.assertEqual(channel.json_body["errcode"], expected_errcode) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 27cb856b0acd..9653f458379f 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -32,11 +37,11 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): + def __init__(self, hs: HomeServer) -> None: super().__init__(hs) - self.recaptcha_attempts = [] + self.recaptcha_attempts: List[Tuple[dict, str]] = [] - def check_auth(self, authdict, clientip): + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -49,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() @@ -60,7 +65,7 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(config=config) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker @@ -100,7 +105,7 @@ def recaptcha( self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - def test_fallback_captcha(self): + def test_fallback_captcha(self) -> None: """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( @@ -131,7 +136,7 @@ def test_fallback_captcha(self): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): + def test_complete_operation_unknown_session(self) -> None: """ Attempting to mark an invalid session as complete should error. """ @@ -164,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns @@ -181,12 +186,12 @@ def default_config(self): return config - def create_resource_dict(self): + def create_resource_dict(self) -> Dict[str, Resource]: resource_dict = super().create_resource_dict() resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.device_id = "dev1" @@ -228,7 +233,7 @@ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: return channel - def test_ui_auth(self): + def test_ui_auth(self) -> None: """ Test user interactive authentication outside of registration. """ @@ -258,7 +263,7 @@ def test_ui_auth(self): }, ) - def test_grandfathered_identifier(self): + def test_grandfathered_identifier(self) -> None: """Check behaviour without "identifier" dict Synapse used to require clients to submit a "user" field for m.login.password @@ -285,7 +290,7 @@ def test_grandfathered_identifier(self): }, ) - def test_can_change_body(self): + def test_can_change_body(self) -> None: """ The client dict can be modified during the user interactive authentication session. @@ -324,7 +329,7 @@ def test_can_change_body(self): }, ) - def test_cannot_change_uri(self): + def test_cannot_change_uri(self) -> None: """ The initial requested URI cannot be modified during the user interactive authentication session. """ @@ -361,7 +366,7 @@ def test_cannot_change_uri(self): ) @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): + def test_can_reuse_session(self) -> None: """ The session can be reused if configured. @@ -408,7 +413,7 @@ def test_can_reuse_session(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): + def test_ui_auth_via_sso(self) -> None: """Test a successful UI Auth flow via SSO This includes: @@ -451,7 +456,7 @@ def test_ui_auth_via_sso(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): + def test_does_not_offer_password_for_sso_user(self) -> None: login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -463,7 +468,7 @@ def test_does_not_offer_password_for_sso_user(self): flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - def test_does_not_offer_sso_for_password_user(self): + def test_does_not_offer_sso_for_password_user(self) -> None: channel = self.delete_device( self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED ) @@ -473,7 +478,7 @@ def test_does_not_offer_sso_for_password_user(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): + def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -490,7 +495,7 @@ def test_offers_both_flows_for_upgraded_user(self): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): + def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) @@ -527,12 +532,13 @@ class RefreshAuthTests(unittest.HomeserverTestCase): auth.register_servlets, account.register_servlets, login.register_servlets, + logout.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) @@ -546,7 +552,7 @@ def use_refresh_token(self, refresh_token: str) -> FakeChannel: {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token) -> bool: + def is_access_token_valid(self, access_token: str) -> bool: """ Checks whether an access token is valid, returning whether it is or not. """ @@ -559,7 +565,7 @@ def is_access_token_valid(self, access_token) -> bool: return code == HTTPStatus.OK - def test_login_issue_refresh_token(self): + def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. """ @@ -589,7 +595,7 @@ def test_login_issue_refresh_token(self): self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) - def test_register_issue_refresh_token(self): + def test_register_issue_refresh_token(self) -> None: """ A register response should include a refresh_token only if asked. """ @@ -625,7 +631,7 @@ def test_register_issue_refresh_token(self): self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) - def test_token_refresh(self): + def test_token_refresh(self) -> None: """ A refresh token can be used to issue a new access token. """ @@ -663,7 +669,7 @@ def test_token_refresh(self): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refreshable_access_token_expiration(self) -> None: """ The access token should have some time as specified in the config. """ @@ -720,7 +726,9 @@ def test_refreshable_access_token_expiration(self): "nonrefreshable_access_token_lifetime": "10m", } ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens( + self, + ) -> None: """ Tests that the expiry times for refreshable and non-refreshable access tokens can be different. @@ -780,7 +788,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) - def test_refresh_token_expiry(self): + def test_refresh_token_expiry(self) -> None: """ The refresh token can be configured to have a limited lifetime. When that lifetime has ended, the refresh token can no longer be used to @@ -832,7 +840,7 @@ def test_refresh_token_expiry(self): "session_lifetime": "3m", } ) - def test_ultimate_session_expiry(self): + def test_ultimate_session_expiry(self) -> None: """ The session can be configured to have an ultimate, limited lifetime. """ @@ -880,7 +888,7 @@ def test_ultimate_session_expiry(self): refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result ) - def test_refresh_token_invalidation(self): + def test_refresh_token_invalidation(self) -> None: """Refresh tokens are invalidated after first use of the next token. A refresh token is considered invalid if: @@ -984,3 +992,90 @@ def test_refresh_token_invalidation(self): self.assertEqual( fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) + + def test_many_token_refresh(self) -> None: + """ + If a refresh is performed many times during a session, there shouldn't be + extra 'cruft' built up over time. + + This test was written specifically to troubleshoot a case where logout + was very slow if a lot of refreshes had been performed for the session. + """ + + def _refresh(refresh_token: str) -> Tuple[str, str]: + """ + Performs one refresh, returning the next refresh token and access token. + """ + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual( + refresh_response.code, HTTPStatus.OK, refresh_response.result + ) + return ( + refresh_response.json_body["refresh_token"], + refresh_response.json_body["access_token"], + ) + + def _table_length(table_name: str) -> int: + """ + Helper to get the size of a table, in rows. + For testing only; trivially vulnerable to SQL injection. + """ + + def _txn(txn: LoggingTransaction) -> int: + txn.execute(f"SELECT COUNT(1) FROM {table_name}") + row = txn.fetchone() + # Query is infallible + assert row is not None + return row[0] + + return self.get_success( + self.hs.get_datastores().main.db_pool.runInteraction( + "_table_length", _txn + ) + ) + + # Before we log in, there are no access tokens. + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + + access_token = login_response.json_body["access_token"] + refresh_token = login_response.json_body["refresh_token"] + + # Now that we have logged in, there should be one access token and one + # refresh token + self.assertEqual(_table_length("access_tokens"), 1) + self.assertEqual(_table_length("refresh_tokens"), 1) + + for _ in range(5): + refresh_token, access_token = _refresh(refresh_token) + + # After 5 sequential refreshes, there should only be the latest two + # refresh/access token pairs. + # (The last one is preserved because it's in use! + # The one before that is preserved because it can still be used to + # replace the last token pair, in case of e.g. a network interruption.) + self.assertEqual(_table_length("access_tokens"), 2) + self.assertEqual(_table_length("refresh_tokens"), 2) + + logout_response = self.make_request( + "POST", "/_matrix/client/v3/logout", {}, access_token=access_token + ) + self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result) + + # Now that we have logged in, there should be no access token + # and no refresh token + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 989e80176820..d1751e1557f7 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -13,9 +13,13 @@ # limitations under the License. from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.localpart = "user" self.password = "pass" self.user = self.register_user(self.localpart, self.password) - def test_check_auth_required(self): + def test_check_auth_required(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) - def test_get_room_version_capabilities(self): + def test_get_room_version_capabilities(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -61,7 +65,7 @@ def test_get_room_version_capabilities(self): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities_password_login(self): + def test_get_change_password_capabilities_password_login(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -71,7 +75,7 @@ def test_get_change_password_capabilities_password_login(self): self.assertTrue(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): + def test_get_change_password_capabilities_localdb_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -85,7 +89,7 @@ def test_get_change_password_capabilities_localdb_disabled(self): self.assertFalse(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): + def test_get_change_password_capabilities_password_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -98,7 +102,7 @@ def test_get_change_password_capabilities_password_disabled(self): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities(self): + def test_get_change_users_attributes_capabilities(self) -> None: """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) @@ -112,7 +116,7 @@ def test_get_change_users_attributes_capabilities(self): self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) @override_config({"enable_set_displayname": False}) - def test_get_set_displayname_capabilities_displayname_disabled(self): + def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -123,7 +127,7 @@ def test_get_set_displayname_capabilities_displayname_disabled(self): self.assertFalse(capabilities["m.set_displayname"]["enabled"]) @override_config({"enable_set_avatar_url": False}) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -134,7 +138,7 @@ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) @override_config({"enable_3pid_changes": False}) - def test_get_change_3pid_capabilities_3pid_disabled(self): + def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -145,7 +149,7 @@ def test_get_change_3pid_capabilities_3pid_disabled(self): self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) - def test_get_does_not_include_msc3244_fields_when_disabled(self): + def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -160,7 +164,7 @@ def test_get_does_not_include_msc3244_fields_when_disabled(self): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - def test_get_does_include_msc3244_fields_when_enabled(self): + def test_get_does_include_msc3244_fields_when_enabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index fcdc565814cf..b1ca81a91191 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -13,11 +13,16 @@ # limitations under the License. import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder from synapse.rest.client import login, room from synapse.rest.consent import consent_resource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["form_secret"] = "123abc" @@ -56,7 +61,7 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(config=config) return hs - def test_render_public_consent(self): + def test_render_public_consent(self) -> None: """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( @@ -66,9 +71,9 @@ def test_render_public_consent(self): "/consent?v=1", shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) - def test_accept_consent(self): + def test_accept_consent(self) -> None: """ A user can use the consent form to accept the terms. """ @@ -92,7 +97,7 @@ def test_accept_consent(self): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and whether we've consented version, consented = channel.result["body"].decode("ascii").split(",") @@ -107,7 +112,7 @@ def test_accept_consent(self): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Fetch the consent page, to get the consent version -- it should have # changed @@ -119,7 +124,7 @@ def test_accept_consent(self): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py new file mode 100644 index 000000000000..a8af4e2435ac --- /dev/null +++ b/tests/rest/client/test_device_lists.py @@ -0,0 +1,159 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus + +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 3d7aa8ec8683..9fa1f82dfee0 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_ephemeral_messages"] = True @@ -35,10 +42,10 @@ def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_message_expiry_no_delay(self): + def test_message_expiry_no_delay(self) -> None: """Tests that sending a message sent with a m.self_destruct_after field set to the past results in that event being deleted right away. """ @@ -61,7 +68,7 @@ def test_message_expiry_no_delay(self): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def test_message_expiry_delay(self): + def test_message_expiry_delay(self) -> None: """Tests that sending a message with a m.self_destruct_after field set to the future results in that event not being deleted right away, but advancing the clock to after that expiry timestamp causes the event to be deleted. @@ -89,7 +96,9 @@ def test_message_expiry_delay(self): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a90294003eac..1b1392fa2fdd 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -16,8 +16,12 @@ from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_registration_captcha"] = False @@ -41,11 +45,11 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -55,7 +59,7 @@ def prepare(self, reactor, clock, hs): self.other_user = self.register_user("other2", "pass") self.other_token = self.login(self.other_user, "pass") - def test_stream_basic_permissions(self): + def test_stream_basic_permissions(self) -> None: # invalid token, expect 401 # note: this is in violation of the original v1 spec, which expected # 403. However, since the v1 spec no longer exists and the v1 @@ -65,18 +69,18 @@ def test_stream_basic_permissions(self): channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.assertEquals(channel.code, 401, msg=channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) # valid token, expect content channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_room_permissions(self): + def test_stream_room_permissions(self) -> None: room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) self.helper.send(room_id, tok=self.other_token) @@ -89,10 +93,10 @@ def test_stream_room_permissions(self): channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down - self.assertEquals( + self.assertEqual( 0, len( [ @@ -111,7 +115,7 @@ def test_stream_room_permissions(self): # left to room (expect no content for room) - def TODO_test_stream_items(self): + def TODO_test_stream_items(self) -> None: # new user, no content # join room, expect 1 item (join) @@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -144,7 +148,7 @@ def prepare(self, hs, reactor, clock): self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - def test_get_event_via_events(self): + def test_get_event_via_events(self) -> None: resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] @@ -153,4 +157,4 @@ def test_get_event_via_events(self): "/events/" + event_id, access_token=self.token, ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 475c6bed3d1a..5c31a54421df 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_add_filter(self): channel = self.make_request( @@ -45,7 +45,7 @@ def test_add_filter(self): self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): channel = self.make_request( @@ -55,7 +55,7 @@ def test_add_filter_for_other_user(self): ) self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine @@ -68,7 +68,7 @@ def test_add_filter_non_local_user(self): self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): filter_id = defer.ensureDeferred( @@ -83,7 +83,7 @@ def test_get_filter(self): ) self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): channel = self.make_request( @@ -91,7 +91,7 @@ def test_get_filter_non_existant(self): ) self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index ad0425ae6591..e067cf825c62 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -25,13 +25,13 @@ class GroupsTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets, groups.register_servlets] @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self): + def test_rooms_limited_by_visibility(self) -> None: group_id = "+spqr:test" # Alice creates a group channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {"group_id": group_id}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {"group_id": group_id}) # Bob creates a private room room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) @@ -45,12 +45,12 @@ def test_rooms_limited_by_visibility(self): channel = self.make_request( "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} ) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {}) # Alice now tries to retrieve the room list of the space. channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals( + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual( channel.json_body, {"chunk": [], "total_room_count_estimate": 0} ) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index becb4e8dccf5..299b9d21e24b 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -13,9 +13,14 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_3pid_lookup"] = False @@ -36,14 +41,14 @@ def make_homeserver(self, reactor, clock): return self.hs - def test_3pid_lookup_disabled(self): + def test_3pid_lookup_disabled(self) -> None: self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] params = { @@ -56,4 +61,4 @@ def test_3pid_lookup_disabled(self): channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d7fa635eae1d..bbc8e7424351 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def test_rejects_device_id_ice_key_outside_of_list(self): + def test_rejects_device_id_ice_key_outside_of_list(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -49,7 +49,7 @@ def test_rejects_device_id_ice_key_outside_of_list(self): channel.result, ) - def test_rejects_device_key_given_as_map_to_bool(self): + def test_rejects_device_key_given_as_map_to_bool(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -73,7 +73,7 @@ def test_rejects_device_key_given_as_map_to_bool(self): channel.result, ) - def test_requires_device_key(self): + def test_requires_device_key(self) -> None: """`device_keys` is required. We should complain if it's missing.""" self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 19f5e465372b..090d2d0a2927 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,6 +20,7 @@ import pymacaroons +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin @@ -27,12 +28,15 @@ from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] @@ -117,7 +121,7 @@ def make_homeserver(self, reactor, clock): } } ) - def test_POST_ratelimiting_per_address(self): + def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -132,10 +136,10 @@ def test_POST_ratelimiting_per_address(self): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -150,7 +154,7 @@ def test_POST_ratelimiting_per_address(self): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -165,7 +169,7 @@ def test_POST_ratelimiting_per_address(self): } } ) - def test_POST_ratelimiting_per_account(self): + def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -177,10 +181,10 @@ def test_POST_ratelimiting_per_account(self): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -195,7 +199,7 @@ def test_POST_ratelimiting_per_account(self): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -210,7 +214,7 @@ def test_POST_ratelimiting_per_account(self): } } ) - def test_POST_ratelimiting_per_account_failed_attempts(self): + def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -222,10 +226,10 @@ def test_POST_ratelimiting_per_account_failed_attempts(self): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -240,16 +244,16 @@ def test_POST_ratelimiting_per_account_failed_attempts(self): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): + def test_soft_logout(self) -> None: self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal params = { @@ -259,22 +263,22 @@ def test_soft_logout(self): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # # test behaviour after deleting the expired device @@ -286,24 +290,26 @@ def test_soft_logout(self): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], False) - def _delete_device(self, access_token, user_id, password, device_id): + def _delete_device( + self, access_token: str, user_id: str, password: str, device_id: str + ) -> None: """Perform the UI-Auth to delete a device""" channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEquals(channel.code, 401, channel.result) + self.assertEqual(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -326,10 +332,10 @@ def _delete_device(self, access_token, user_id, password, device_id): access_token=access_token, content={"auth": auth}, ) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): + def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -337,23 +343,25 @@ def test_session_can_hard_logout_after_being_soft_logged_out(self): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + self, + ) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -361,20 +369,20 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") @@ -432,7 +440,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self): + def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) @@ -459,12 +467,14 @@ def test_get_login_flows(self): ], ) - def test_multi_sso_redirect(self): + def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) @@ -487,7 +497,7 @@ def test_multi_sso_redirect(self): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self): + def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" channel = self.make_request( @@ -514,7 +524,7 @@ def test_multi_sso_redirect_to_cas(self): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self): + def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" channel = self.make_request( "GET", @@ -536,7 +546,7 @@ def test_multi_sso_redirect_to_saml(self): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self): + def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" # pick the default OIDC provider @@ -604,7 +614,7 @@ def test_login_via_oidc(self): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_to_unknown(self): + def test_multi_sso_redirect_to_unknown(self) -> None: """An unknown IdP should cause a 400""" channel = self.make_request( "GET", @@ -612,23 +622,25 @@ def test_multi_sso_redirect_to_unknown(self): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self): + def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self): + def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): + def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -675,7 +687,7 @@ def make_homeserver(self, reactor, clock): cas_user_id = "username" self.user_id = "@%s:test" % cas_user_id - async def get_raw(uri, args): + async def get_raw(uri: str, args: Any) -> bytes: """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 @@ -709,10 +721,10 @@ async def get_raw(uri, args): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self): + def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -754,15 +766,15 @@ def test_cas_redirect_confirm(self): } } ) - def test_cas_redirect_whitelisted(self): + def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): + def test_cas_redirect_login_fallback(self) -> None: self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url): + def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -778,7 +790,7 @@ def _test_redirect(self, redirect_url): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" @@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "algorithm": jwt_algorithm, } - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # If jwt_config has been defined (eg via @override_config), don't replace it. @@ -837,23 +849,23 @@ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self): + def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self): + def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -862,7 +874,7 @@ def test_login_jwt_invalid_signature(self): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self): + def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -870,7 +882,7 @@ def test_login_jwt_expired(self): channel.json_body["error"], "JWT validation failed: Signature has expired" ) - def test_login_jwt_not_before(self): + def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -880,14 +892,14 @@ def test_login_jwt_not_before(self): "JWT validation failed: The token is not yet valid (nbf)", ) - def test_login_no_sub(self): + def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self): + def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) @@ -911,14 +923,14 @@ def test_login_iss(self): 'JWT validation failed: Token is missing the "iss" claim', ) - def test_login_iss_no_config(self): + def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self): + def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) @@ -942,7 +954,7 @@ def test_login_aud(self): 'JWT validation failed: Token is missing the "aud" claim', ) - def test_login_aud_no_config(self): + def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -951,20 +963,20 @@ def test_login_aud_no_config(self): channel.json_body["error"], "JWT validation failed: Invalid audience" ) - def test_login_default_sub(self): + def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self): + def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_no_token(self): + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["jwt_config"] = { "enabled": True, @@ -1042,17 +1054,17 @@ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> s return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self): + def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.service = ApplicationService( @@ -1101,11 +1113,11 @@ def make_homeserver(self, reactor, clock): }, ) - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) + self.hs.get_datastores().main.services_cache.append(self.service) + self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs - def test_login_appservice_user(self): + def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1117,9 +1129,9 @@ def test_login_appservice_user(self): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_login_appservice_user_bot(self): + def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1131,9 +1143,9 @@ def test_login_appservice_user_bot(self): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_login_appservice_wrong_user(self): + def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1145,9 +1157,9 @@ def test_login_appservice_wrong_user(self): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) - def test_login_appservice_wrong_as(self): + def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1159,9 +1171,9 @@ def test_login_appservice_wrong_as(self): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) - def test_login_appservice_no_token(self): + def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ @@ -1173,7 +1185,7 @@ def test_login_appservice_no_token(self): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) @skip_unless(HAS_OIDC, "requires OIDC") @@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): servlets = [login.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -1202,7 +1214,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self): + def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" # do the start of the login flow diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3cf5871899e1..3a74d2e96c49 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -13,11 +13,16 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import account, login, password_policy, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.register_url = "/_matrix/client/r0/register" self.policy = { "enabled": True, @@ -65,12 +70,12 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(config=config) return hs - def test_get_policy(self): + def test_get_policy(self) -> None: """Tests if the /password_policy endpoint returns the configured policy.""" channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual( channel.json_body, { @@ -83,70 +88,70 @@ def test_get_policy(self): channel.result, ) - def test_password_too_short(self): + def test_password_too_short(self) -> None: request_data = json.dumps({"username": "kermit", "password": "shorty"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, ) - def test_password_no_digit(self): + def test_password_no_digit(self) -> None: request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, ) - def test_password_no_symbol(self): + def test_password_no_symbol(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, ) - def test_password_no_uppercase(self): + def test_password_no_uppercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, ) - def test_password_no_lowercase(self): + def test_password_no_lowercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, ) - def test_password_compliant(self): + def test_password_compliant(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) - def test_password_change(self): + def test_password_change(self) -> None: """This doesn't test every possible use case, only that hitting /account/password triggers the password validation code. """ @@ -173,5 +178,5 @@ def test_password_change(self): access_token=tok, ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index c0de4c93a806..27dcfc83d204 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a room admin, moderator and regular user self.admin_user_id = self.register_user("admin", "pass") self.admin_access_token = self.login("admin", "pass") @@ -88,7 +93,7 @@ def prepare(self, reactor, clock, hs): tok=self.admin_access_token, ) - def test_non_admins_cannot_enable_room_encryption(self): + def test_non_admins_cannot_enable_room_encryption(self) -> None: # have the mod try to enable room encryption self.helper.send_state( self.room_id, @@ -104,10 +109,10 @@ def test_non_admins_cannot_enable_room_encryption(self): "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_send_server_acl(self): + def test_non_admins_cannot_send_server_acl(self) -> None: # have the mod try to send a server ACL self.helper.send_state( self.room_id, @@ -118,7 +123,7 @@ def test_non_admins_cannot_send_server_acl(self): "deny": ["*.evil.com", "evil.com"], }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a server ACL @@ -131,10 +136,10 @@ def test_non_admins_cannot_send_server_acl(self): "deny": ["*.evil.com", "evil.com"], }, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_tombstone_room(self): + def test_non_admins_cannot_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -149,7 +154,7 @@ def test_non_admins_cannot_tombstone_room(self): "replacement_room": self.upgraded_room_id, }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a tombstone event @@ -164,17 +169,17 @@ def test_non_admins_cannot_tombstone_room(self): expect_code=403, # expect failure ) - def test_admins_can_enable_room_encryption(self): + def test_admins_can_enable_room_encryption(self) -> None: # have the admin try to enable room encryption self.helper.send_state( self.room_id, "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_send_server_acl(self): + def test_admins_can_send_server_acl(self) -> None: # have the admin try to send a server ACL self.helper.send_state( self.room_id, @@ -185,10 +190,10 @@ def test_admins_can_send_server_acl(self): "deny": ["*.evil.com", "evil.com"], }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_tombstone_room(self): + def test_admins_can_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -203,10 +208,10 @@ def test_admins_can_tombstone_room(self): "replacement_room": self.upgraded_room_id, }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_cannot_set_string_power_levels(self): + def test_cannot_set_string_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -221,7 +226,7 @@ def test_cannot_set_string_power_levels(self): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -230,7 +235,7 @@ def test_cannot_set_string_power_levels(self): body, ) - def test_cannot_set_unsafe_large_power_levels(self): + def test_cannot_set_unsafe_large_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -247,7 +252,7 @@ def test_cannot_set_unsafe_large_power_levels(self): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -256,7 +261,7 @@ def test_cannot_set_unsafe_large_power_levels(self): body, ) - def test_cannot_set_unsafe_small_power_levels(self): + def test_cannot_set_unsafe_small_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -273,7 +278,7 @@ def test_cannot_set_unsafe_small_power_levels(self): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 56fe1a3d0133..0abe378fe4d6 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [presence.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) @@ -45,7 +48,7 @@ def make_homeserver(self, reactor, clock): return hs - def test_put_presence(self): + def test_put_presence(self) -> None: """ PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. @@ -57,11 +60,11 @@ def test_put_presence(self): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): + def test_put_presence_disabled(self) -> None: """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. @@ -72,5 +75,5 @@ def test_put_presence_disabled(self): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index ead883ded886..77c3ced42ec1 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,12 +13,16 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from typing import Any, Dict +from typing import Any, Dict, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") self.other = self.register_user("other", "pass", displayname="Bob") - def test_get_displayname(self): + def test_get_displayname(self) -> None: res = self._get_displayname() self.assertEqual(res, "owner") - def test_set_displayname(self): + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -57,7 +61,7 @@ def test_set_displayname(self): res = self._get_displayname() self.assertEqual(res, "test") - def test_set_displayname_noauth(self): + def test_set_displayname_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -65,7 +69,7 @@ def test_set_displayname_noauth(self): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_displayname_too_long(self): + def test_set_displayname_too_long(self) -> None: """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", @@ -78,11 +82,11 @@ def test_set_displayname_too_long(self): res = self._get_displayname() self.assertEqual(res, "owner") - def test_get_displayname_other(self): + def test_get_displayname_other(self) -> None: res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") + self.assertEqual(res, "Bob") - def test_set_displayname_other(self): + def test_set_displayname_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.other,), @@ -91,11 +95,11 @@ def test_set_displayname_other(self): ) self.assertEqual(channel.code, 400, channel.result) - def test_get_avatar_url(self): + def test_get_avatar_url(self) -> None: res = self._get_avatar_url() self.assertIsNone(res) - def test_set_avatar_url(self): + def test_set_avatar_url(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -107,7 +111,7 @@ def test_set_avatar_url(self): res = self._get_avatar_url() self.assertEqual(res, "http://my.server/pic.gif") - def test_set_avatar_url_noauth(self): + def test_set_avatar_url_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -115,7 +119,7 @@ def test_set_avatar_url_noauth(self): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_avatar_url_too_long(self): + def test_set_avatar_url_too_long(self) -> None: """Attempts to set a stupid avatar_url should get a 400""" channel = self.make_request( "PUT", @@ -128,11 +132,11 @@ def test_set_avatar_url_too_long(self): res = self._get_avatar_url() self.assertIsNone(res) - def test_get_avatar_url_other(self): + def test_get_avatar_url_other(self) -> None: res = self._get_avatar_url(self.other) self.assertIsNone(res) - def test_set_avatar_url_other(self): + def test_set_avatar_url_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.other,), @@ -141,14 +145,14 @@ def test_set_avatar_url_other(self): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name=None): + def _get_displayname(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] - def _get_avatar_url(self, name=None): + def _get_avatar_url(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) @@ -156,7 +160,7 @@ def _get_avatar_url(self, name=None): return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_global(self): + def test_avatar_size_limit_global(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a global profile. """ @@ -187,7 +191,7 @@ def test_avatar_size_limit_global(self): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_per_room(self): + def test_avatar_size_limit_per_room(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a per-room profile. """ @@ -220,7 +224,7 @@ def test_avatar_size_limit_per_room(self): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_global(self): + def test_avatar_allowed_mime_type_global(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a global profile. """ @@ -251,7 +255,7 @@ def test_avatar_allowed_mime_type_global(self): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_per_room(self): + def test_avatar_allowed_mime_type_per_room(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a per-room profile. """ @@ -283,7 +287,7 @@ def test_avatar_allowed_mime_type_per_room(self): ) self.assertEqual(channel.code, 200, channel.result) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: @@ -292,7 +296,7 @@ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( @@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -325,7 +328,7 @@ def make_homeserver(self, reactor, clock): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User owning the requested profile. self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") @@ -337,22 +340,24 @@ def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - def test_no_auth(self): + def test_no_auth(self) -> None: self.try_fetch_profile(401) - def test_not_in_shared_room(self): + def test_not_in_shared_room(self) -> None: self.ensure_requester_left_room() self.try_fetch_profile(403, access_token=self.requester_tok) - def test_in_shared_room(self): + def test_in_shared_room(self) -> None: self.ensure_requester_left_room() self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) - def try_fetch_profile(self, expected_code, access_token=None): + def try_fetch_profile( + self, expected_code: int, access_token: Optional[str] = None + ) -> None: self.request_profile(expected_code, access_token=access_token) self.request_profile( @@ -363,13 +368,18 @@ def try_fetch_profile(self, expected_code, access_token=None): expected_code, url_suffix="/avatar_url", access_token=access_token ) - def request_profile(self, expected_code, url_suffix="", access_token=None): + def request_profile( + self, + expected_code: int, + url_suffix: str = "", + access_token: Optional[str] = None, + ) -> None: channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) self.assertEqual(channel.code, expected_code, channel.result) - def ensure_requester_left_room(self): + def ensure_requester_left_room(self) -> None: try: self.helper.leave( room=self.room_id, user=self.requester, tok=self.requester_tok @@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -397,12 +407,12 @@ def make_homeserver(self, reactor, clock): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User requesting the profile. self.requester = self.register_user("requester", "pass") self.requester_tok = self.login("requester", "pass") - def test_can_lookup_own_profile(self): + def test_can_lookup_own_profile(self) -> None: """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index d0ce91ccd95c..4f875b92898e 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): ] hijack_auth = False - def test_enabled_on_creation(self): + def test_enabled_on_creation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even though a rule with that @@ -56,7 +56,7 @@ def test_enabled_on_creation(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_on_recreation(self): + def test_enabled_on_recreation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even if a rule with that @@ -113,7 +113,7 @@ def test_enabled_on_recreation(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_disable(self): + def test_enabled_disable(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is disabled and enabled when we ask for it. @@ -166,7 +166,7 @@ def test_enabled_disable(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_404_when_get_non_existent(self): + def test_enabled_404_when_get_non_existent(self) -> None: """ Tests that `enabled` gives 404 when the rule doesn't exist. """ @@ -212,7 +212,7 @@ def test_enabled_404_when_get_non_existent(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_get_non_existent_server_rule(self): + def test_enabled_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when the server-default rule doesn't exist. """ @@ -226,7 +226,7 @@ def test_enabled_404_when_get_non_existent_server_rule(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_rule(self): + def test_enabled_404_when_put_non_existent_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a rule that doesn't exist. """ @@ -243,7 +243,7 @@ def test_enabled_404_when_put_non_existent_rule(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_server_rule(self): + def test_enabled_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. """ @@ -260,7 +260,7 @@ def test_enabled_404_when_put_non_existent_server_rule(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_get(self): + def test_actions_get(self) -> None: """ Tests that `actions` gives you what you expect on a fresh rule. """ @@ -289,7 +289,7 @@ def test_actions_get(self): channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] ) - def test_actions_put(self): + def test_actions_put(self) -> None: """ Tests that PUT on actions updates the value you'd get from GET. """ @@ -325,7 +325,7 @@ def test_actions_put(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - def test_actions_404_when_get_non_existent(self): + def test_actions_404_when_get_non_existent(self) -> None: """ Tests that `actions` gives 404 when the rule doesn't exist. """ @@ -365,7 +365,7 @@ def test_actions_404_when_get_non_existent(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_get_non_existent_server_rule(self): + def test_actions_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when the server-default rule doesn't exist. """ @@ -379,7 +379,7 @@ def test_actions_404_when_get_non_existent_server_rule(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_rule(self): + def test_actions_404_when_put_non_existent_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a rule that doesn't exist. """ @@ -396,7 +396,7 @@ def test_actions_404_when_put_non_existent_rule(self): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_server_rule(self): + def test_actions_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. """ diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 433d715f695d..7401b5e0c0fa 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,9 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["rc_message"] = {"per_second": 0.2, "burst_count": 10} @@ -36,7 +42,7 @@ def make_homeserver(self, reactor, clock): return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a couple of users self.mod_user_id = self.register_user("user1", "pass") self.mod_access_token = self.login("user1", "pass") @@ -60,7 +66,9 @@ def prepare(self, reactor, clock, hs): room=self.room_id, user=self.other_user_id, tok=self.other_access_token ) - def _redact_event(self, access_token, room_id, event_id, expect_code=200): + def _redact_event( + self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + ) -> JsonDict: """Helper function to send a redaction event. Returns the json body. @@ -71,13 +79,13 @@ def _redact_event(self, access_token, room_id, event_id, expect_code=200): self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body - def _sync_room_timeline(self, access_token, room_id): + def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] - def test_redact_event_as_moderator(self): + def test_redact_event_as_moderator(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -98,7 +106,7 @@ def test_redact_event_as_moderator(self): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_event_as_normal(self): + def test_redact_event_as_normal(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) normal_msg_id = b["event_id"] @@ -133,7 +141,7 @@ def test_redact_event_as_normal(self): self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-3]["content"], {}) - def test_redact_nonexistent_event(self): + def test_redact_nonexistent_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -158,7 +166,7 @@ def test_redact_nonexistent_event(self): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_create_event(self): + def test_redact_create_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) msg_id = b["event_id"] @@ -178,7 +186,7 @@ def test_redact_create_event(self): self.other_access_token, self.room_id, create_event_id, expect_code=403 ) - def test_redact_event_as_moderator_ratelimit(self): + def test_redact_event_as_moderator_ratelimit(self) -> None: """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 0f1c47dcbb87..9aebf1735a96 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -16,15 +16,21 @@ import datetime import json import os +from typing import Any, Dict, List, Tuple import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.server import HomeServer from synapse.storage._base import db_to_json +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ] url = b"/_matrix/client/r0/register" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_POST_appservice_registration_valid(self): + def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" @@ -56,7 +62,7 @@ def test_POST_appservice_registration_valid(self): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps( {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} ) @@ -65,11 +71,11 @@ def test_POST_appservice_registration_valid(self): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_appservice_registration_no_type(self): + def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" appservice = ApplicationService( @@ -80,16 +86,16 @@ def test_POST_appservice_registration_no_type(self): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.result["code"], b"400", channel.result) - def test_POST_appservice_registration_invalid(self): + def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists request_data = json.dumps( {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} @@ -98,23 +104,23 @@ def test_POST_appservice_registration_invalid(self): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) - def test_POST_bad_password(self): + def test_POST_bad_password(self) -> None: request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): + def test_POST_bad_username(self) -> None: request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): + def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" params = { @@ -131,58 +137,58 @@ def test_POST_user_valid(self): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): + def test_POST_disabled_registration(self) -> None: request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Registration has been disabled") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - def test_POST_guest_registration(self): + def test_POST_guest_registration(self) -> None: self.hs.config.key.macaroon_secret_key = "test" self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_disabled_guest_registration(self): + def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): + def test_POST_ratelimiting_guest(self) -> None: for i in range(0, 6): url = self.url + b"?kind=guest" channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): + def test_POST_ratelimiting(self) -> None: for i in range(0, 6): params = { "username": "kermit" + str(i), @@ -194,23 +200,23 @@ def test_POST_ratelimiting(self): channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) - def test_POST_registration_requires_token(self): + def test_POST_registration_requires_token(self) -> None: username = "kermit" device_id = "frogfone" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -223,7 +229,7 @@ def test_POST_registration_requires_token(self): }, ) ) - params = { + params: JsonDict = { "username": username, "password": "monkey", "device_id": device_id, @@ -231,7 +237,7 @@ def test_POST_registration_requires_token(self): # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -249,7 +255,7 @@ def test_POST_registration_requires_token(self): } request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -265,7 +271,7 @@ def test_POST_registration_requires_token(self): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -276,12 +282,12 @@ def test_POST_registration_requires_token(self): retcols=["pending", "completed"], ) ) - self.assertEquals(res["completed"], 1) - self.assertEquals(res["pending"], 0) + self.assertEqual(res["completed"], 1) + self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_invalid(self): - params = { + def test_POST_registration_token_invalid(self) -> None: + params: JsonDict = { "username": "kermit", "password": "monkey", } @@ -295,28 +301,28 @@ def test_POST_registration_token_invalid(self): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_limit_uses(self): + def test_POST_registration_token_limit_uses(self) -> None: token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that can be used once self.get_success( store.db_pool.simple_insert( @@ -330,8 +336,8 @@ def test_POST_registration_token_limit_uses(self): }, ) ) - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] @@ -354,7 +360,7 @@ def test_POST_registration_token_limit_uses(self): retcol="pending", ) ) - self.assertEquals(pending, 1) + self.assertEqual(pending, 1) # Check auth fails when using token with session2 params2["auth"] = { @@ -363,9 +369,9 @@ def test_POST_registration_token_limit_uses(self): "session": session2, } channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY @@ -378,20 +384,20 @@ def test_POST_registration_token_limit_uses(self): retcols=["pending", "completed"], ) ) - self.assertEquals(res["pending"], 0) - self.assertEquals(res["completed"], 1) + self.assertEqual(res["pending"], 0) + self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_expiry(self): + def test_POST_registration_token_expiry(self) -> None: token = "abcd" now = self.hs.get_clock().time_msec() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that expired yesterday self.get_success( store.db_pool.simple_insert( @@ -405,7 +411,7 @@ def test_POST_registration_token_expiry(self): }, ) ) - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -417,9 +423,9 @@ def test_POST_registration_token_expiry(self): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Update token so it expires tomorrow self.get_success( @@ -436,10 +442,10 @@ def test_POST_registration_token_expiry(self): self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry(self): + def test_POST_registration_token_session_expiry(self) -> None: """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -454,8 +460,8 @@ def test_POST_registration_token_session_expiry(self): ) # Do 2 requests without auth to get two session IDs - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) @@ -504,7 +510,7 @@ def test_POST_registration_token_session_expiry(self): retcol="result", ) ) - self.assertEquals(db_to_json(result2), token) + self.assertEqual(db_to_json(result2), token) # Delete both sessions (mimics expiry) self.get_success( @@ -519,10 +525,10 @@ def test_POST_registration_token_session_expiry(self): retcol="pending", ) ) - self.assertEquals(pending, 0) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry_deleted_token(self): + def test_POST_registration_token_session_expiry_deleted_token(self) -> None: """Test session expiry doesn't break when the token is deleted. 1. Start but don't complete UIA with a registration token @@ -530,7 +536,7 @@ def test_POST_registration_token_session_expiry_deleted_token(self): 3. Expire the session """ token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -545,7 +551,7 @@ def test_POST_registration_token_session_expiry_deleted_token(self): ) # Do request without auth to get a session ID - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -570,9 +576,9 @@ def test_POST_registration_token_session_expiry_deleted_token(self): store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) ) - def test_advertised_flows(self): + def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -593,9 +599,9 @@ def test_advertised_flows(self): }, } ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): + def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -625,9 +631,9 @@ def test_advertised_flows_captcha_and_terms_and_3pids(self): }, } ) - def test_advertised_flows_no_msisdn_email_required(self): + def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -646,7 +652,7 @@ def test_advertised_flows_no_msisdn_email_required(self): }, } ) - def test_request_token_existing_email_inhibit_error(self): + def test_request_token_existing_email_inhibit_error(self) -> None: """Test that requesting a token via this endpoint doesn't leak existing associations if configured that way. """ @@ -657,7 +663,7 @@ def test_request_token_existing_email_inhibit_error(self): # Add a threepid self.get_success( - self.hs.get_datastore().user_add_threepid( + self.hs.get_datastores().main.user_add_threepid( user_id=user_id, medium="email", address=email, @@ -671,7 +677,7 @@ def test_request_token_existing_email_inhibit_error(self): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -685,7 +691,7 @@ def test_request_token_existing_email_inhibit_error(self): }, } ) - def test_reject_invalid_email(self): + def test_reject_invalid_email(self) -> None: """Check that bad emails are rejected""" # Test for email with multiple @ @@ -694,9 +700,9 @@ def test_reject_invalid_email(self): b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -707,8 +713,8 @@ def test_reject_invalid_email(self): b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -720,8 +726,8 @@ def test_reject_invalid_email(self): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -731,7 +737,7 @@ def test_reject_invalid_email(self): "inhibit_user_in_use_error": True, } ) - def test_inhibit_user_in_use_error(self): + def test_inhibit_user_in_use_error(self) -> None: """Tests that the 'inhibit_user_in_use_error' configuration flag behaves correctly. """ @@ -745,7 +751,7 @@ def test_inhibit_user_in_use_error(self): # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): account_validity.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week. config["enable_registration"] = True @@ -791,7 +797,7 @@ def make_homeserver(self, reactor, clock): return self.hs - def test_validity_period(self): + def test_validity_period(self) -> None: self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -799,18 +805,18 @@ def test_validity_period(self): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_manual_renewal(self): + def test_manual_renewal(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -826,14 +832,14 @@ def test_manual_renewal(self): params = {"user_id": user_id} request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_manual_expire(self): + def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -848,17 +854,17 @@ def test_manual_expire(self): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_logging_out_expired_user(self): + def test_logging_out_expired_user(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -873,18 +879,18 @@ def test_logging_out_expired_user(self): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week and renewal emails being sent 2 @@ -935,17 +941,17 @@ def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver(config=config) - async def sendmail(*args, **kwargs): + async def sendmail(*args: Any, **kwargs: Any) -> None: self.email_attempts.append((args, kwargs)) - self.email_attempts = [] + self.email_attempts: List[Tuple[Any, Any]] = [] self.hs.get_send_email_handler()._sendmail = sendmail - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs - def test_renewal_email(self): + def test_renewal_email(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -959,7 +965,7 @@ def test_renewal_email(self): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -977,7 +983,7 @@ def test_renewal_email(self): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -997,14 +1003,14 @@ def test_renewal_email(self): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_renewal_invalid_token(self): + def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1019,7 +1025,7 @@ def test_renewal_invalid_token(self): channel.result["body"], expected_html.encode("utf8"), channel.result ) - def test_manual_email_send(self): + def test_manual_email_send(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1028,11 +1034,11 @@ def test_manual_email_send(self): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1056,7 +1062,7 @@ def test_deactivated_user(self): self.assertEqual(len(self.email_attempts), 0) - def create_user(self): + def create_user(self) -> Tuple[str, str]: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do @@ -1073,7 +1079,7 @@ def create_user(self): ) return user_id, tok - def test_manual_email_send_expired_account(self): + def test_manual_email_send_expired_account(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -1103,7 +1109,7 @@ def test_manual_email_send_expired_account(self): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.validity_period = 10 self.max_delta = self.validity_period * 10.0 / 100.0 @@ -1126,14 +1132,16 @@ def make_homeserver(self, reactor, clock): # We need to set these directly, instead of in the homeserver config dict above. # This is due to account validity-related config options not being read by # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + self.hs.get_datastores().main._account_validity_period = self.validity_period + self.hs.get_datastores().main._account_validity_startup_job_max_delta = ( + self.max_delta + ) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs - def test_background_job(self): + def test_background_job(self) -> None: """ Tests the same thing as test_background_job, except that it sets the startup_job_max_delta parameter and checks that the expiration date is within the @@ -1156,14 +1164,14 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = "/_matrix/client/v1/register/m.login.registration_token/validity" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["registration_requires_token"] = True return config - def test_GET_token_valid(self): + def test_GET_token_valid(self) -> None: token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -1181,22 +1189,22 @@ def test_GET_token_valid(self): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], True) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], True) - def test_GET_token_invalid(self): + def test_GET_token_invalid(self) -> None: token = "1234" channel = self.make_request( b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], False) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], False) @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} ) - def test_GET_ratelimiting(self): + def test_GET_ratelimiting(self) -> None: token = "1234" for i in range(0, 6): @@ -1206,10 +1214,10 @@ def test_GET_ratelimiting(self): ) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1217,4 +1225,4 @@ def test_GET_ratelimiting(self): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index de80aca03705..c8db45719e07 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -18,11 +18,15 @@ from typing import Dict, List, Optional, Tuple from unittest.mock import patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.server import HomeServer from synapse.storage.relations import RelationPaginationToken from synapse.types import JsonDict, StreamToken +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -52,8 +56,8 @@ def default_config(self) -> dict: return config - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -63,13 +67,13 @@ def prepare(self, reactor, clock, hs): res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] - def test_send_relation(self): + def test_send_relation(self) -> None: """Tests that sending a relation using the new /send_relation works creates the right shape of event. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -78,7 +82,7 @@ def test_send_relation(self): "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assert_dict( { @@ -95,7 +99,7 @@ def test_send_relation(self): channel.json_body, ) - def test_deny_invalid_event(self): + def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" channel = self._send_relation( RelationTypes.ANNOTATION, @@ -103,11 +107,11 @@ def test_deny_invalid_event(self): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_relations", values={ "event_id": "bar", @@ -123,9 +127,9 @@ def test_deny_invalid_event(self): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - def test_deny_invalid_room(self): + def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -136,17 +140,17 @@ def test_deny_invalid_room(self): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_deny_double_react(self): + def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_deny_forked_thread(self): + def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" channel = self._send_relation( RelationTypes.THREAD, @@ -154,7 +158,7 @@ def test_deny_forked_thread(self): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] channel = self._send_relation( @@ -163,16 +167,16 @@ def test_deny_forked_thread(self): content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self): + def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) second_annotation_id = channel.json_body["event_id"] channel = self.make_request( @@ -180,11 +184,11 @@ def test_basic_paginate_relations(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the latest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": second_annotation_id, @@ -195,7 +199,7 @@ def test_basic_paginate_relations(self): ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -212,11 +216,11 @@ def test_basic_paginate_relations(self): f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the earliest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": first_annotation_id, @@ -235,7 +239,7 @@ def _stream_token_to_relation_token(self, token: str) -> str: ).to_string(self.store) ) - def test_repeated_paginate_relations(self): + def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. """ @@ -245,7 +249,7 @@ def test_repeated_paginate_relations(self): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) prev_token = "" @@ -260,12 +264,12 @@ def test_repeated_paginate_relations(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -273,7 +277,7 @@ def test_repeated_paginate_relations(self): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -288,12 +292,12 @@ def test_repeated_paginate_relations(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -301,12 +305,12 @@ def test_repeated_paginate_relations(self): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self): + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Send an event after the relation events. self.helper.send(self.room, body="Latest event", tok=self.user_token) @@ -319,7 +323,7 @@ def test_pagination_from_sync_and_messages(self): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] sync_prev_batch = room_timeline["prev_batch"] self.assertIsNotNone(sync_prev_batch) @@ -335,7 +339,7 @@ def test_pagination_from_sync_and_messages(self): f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) messages_end = channel.json_body["end"] self.assertIsNotNone(messages_end) # Ensure the relation event is not in the chunk returned from /messages. @@ -355,14 +359,14 @@ def test_pagination_from_sync_and_messages(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The relation should be in the returned chunk. self.assertIn( annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self): + def test_aggregation_pagination_groups(self) -> None: """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. @@ -386,7 +390,7 @@ def test_aggregation_pagination_groups(self): key=key, access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) idx += 1 idx %= len(access_tokens) @@ -404,7 +408,7 @@ def test_aggregation_pagination_groups(self): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -419,15 +423,15 @@ def test_aggregation_pagination_groups(self): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: break - self.assertEquals(sent_groups, found_groups) + self.assertEqual(sent_groups, found_groups) - def test_aggregation_pagination_within_group(self): + def test_aggregation_pagination_within_group(self) -> None: """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. @@ -449,14 +453,14 @@ def test_aggregation_pagination_within_group(self): key="👍", access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) idx += 1 # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) prev_token = "" found_event_ids: List[str] = [] @@ -473,7 +477,7 @@ def test_aggregation_pagination_within_group(self): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -481,7 +485,7 @@ def test_aggregation_pagination_within_group(self): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -489,7 +493,7 @@ def test_aggregation_pagination_within_group(self): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -506,7 +510,7 @@ def test_aggregation_pagination_within_group(self): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -514,7 +518,7 @@ def test_aggregation_pagination_within_group(self): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -522,21 +526,21 @@ def test_aggregation_pagination_within_group(self): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self): + def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -544,9 +548,9 @@ def test_aggregation(self): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, { "chunk": [ @@ -556,17 +560,17 @@ def test_aggregation(self): }, ) - def test_aggregation_redactions(self): + def test_aggregation_redactions(self) -> None: """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions channel = self.make_request( @@ -575,7 +579,7 @@ def test_aggregation_redactions(self): access_token=self.user_token, content={}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -583,14 +587,14 @@ def test_aggregation_redactions(self): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - def test_aggregation_must_be_annotation(self): + def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be annotations.""" channel = self.make_request( @@ -599,12 +603,12 @@ def test_aggregation_must_be_annotation(self): % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} ) - def test_bundled_aggregations(self): + def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -615,29 +619,29 @@ def test_bundled_aggregations(self): """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -655,7 +659,7 @@ def assert_bundle(event_json: JsonDict) -> None: ) # Check the values of each field. - self.assertEquals( + self.assertEqual( { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, @@ -665,12 +669,12 @@ def assert_bundle(event_json: JsonDict) -> None: relations_dict[RelationTypes.ANNOTATION], ) - self.assertEquals( + self.assertEqual( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, relations_dict[RelationTypes.REFERENCE], ) - self.assertEquals( + self.assertEqual( 2, relations_dict[RelationTypes.THREAD].get("count"), ) @@ -701,7 +705,7 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) # Request the room messages. @@ -710,7 +714,7 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -719,12 +723,12 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -737,7 +741,7 @@ def assert_bundle(event_json: JsonDict) -> None: content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -746,47 +750,47 @@ def assert_bundle(event_json: JsonDict) -> None: ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_aggregation_get_event_for_annotation(self): + def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{annotation_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): + def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{thread_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -801,11 +805,11 @@ def test_aggregation_get_event_for_thread(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1) thread_message = channel.json_body["chunk"][0] - self.assertEquals( + self.assertEqual( thread_message["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -815,7 +819,7 @@ def test_aggregation_get_event_for_thread(self): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_ignore_invalid_room(self): + def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -905,7 +909,7 @@ def test_ignore_invalid_room(self): f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And when fetching aggregations. @@ -914,7 +918,7 @@ def test_ignore_invalid_room(self): f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And for bundled aggregations. @@ -923,11 +927,11 @@ def test_ignore_invalid_room(self): f"/rooms/{room2}/event/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_edit(self): + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -936,7 +940,7 @@ def test_edit(self): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -958,8 +962,8 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["content"], new_body) assert_bundle(channel.json_body) # Request the room messages. @@ -968,7 +972,7 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -977,7 +981,7 @@ def assert_bundle(event_json: JsonDict) -> None: f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync, but limit the timeline so it becomes limited (and includes @@ -988,7 +992,7 @@ def assert_bundle(event_json: JsonDict) -> None: channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -1001,7 +1005,7 @@ def assert_bundle(event_json: JsonDict) -> None: content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -1010,7 +1014,7 @@ def assert_bundle(event_json: JsonDict) -> None: ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_multi_edit(self): + def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ @@ -1024,7 +1028,7 @@ def test_multi_edit(self): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -1032,7 +1036,7 @@ def test_multi_edit(self): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1045,16 +1049,16 @@ def test_multi_edit(self): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) relations_dict = channel.json_body["unsigned"].get("m.relations") self.assertIn(RelationTypes.REPLACE, relations_dict) @@ -1067,7 +1071,7 @@ def test_multi_edit(self): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_reply(self): + def test_edit_reply(self) -> None: """Test that editing a reply works.""" # Create a reply to edit. @@ -1076,7 +1080,7 @@ def test_edit_reply(self): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1086,7 +1090,7 @@ def test_edit_reply(self): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1095,7 +1099,7 @@ def test_edit_reply(self): "/rooms/%s/event/%s" % (self.room, reply), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to see the new body in the dict, as well as the reference # metadata sill intact. @@ -1123,7 +1127,47 @@ def test_edit_reply(self): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_edit(self): + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_edit_thread(self) -> None: + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + + def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} channel = self._send_relation( @@ -1135,7 +1179,7 @@ def test_edit_edit(self): "m.new_content": new_body, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. @@ -1149,7 +1193,7 @@ def test_edit_edit(self): }, parent_id=edit_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1157,9 +1201,9 @@ def test_edit_edit(self): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) # The relations information should not include the edit to the edit. relations_dict = channel.json_body["unsigned"].get("m.relations") @@ -1173,7 +1217,7 @@ def test_edit_edit(self): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self): + def test_relations_redaction_redacts_edits(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1192,7 +1236,7 @@ def test_relations_redaction_redacts_edits(self): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned channel = self.make_request( @@ -1201,10 +1245,10 @@ def test_relations_redaction_redacts_edits(self): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) + self.assertEqual(len(channel.json_body["chunk"]), 1) # Redact the original event channel = self.make_request( @@ -1214,7 +1258,7 @@ def test_relations_redaction_redacts_edits(self): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations channel = self.make_request( @@ -1223,13 +1267,13 @@ def test_relations_redaction_redacts_edits(self): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that no relations are returned self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) - def test_aggregations_redaction_prevents_access_to_aggregations(self): + def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1241,7 +1285,7 @@ def test_aggregations_redaction_prevents_access_to_aggregations(self): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Redact the original channel = self.make_request( @@ -1255,7 +1299,7 @@ def test_aggregations_redaction_prevents_access_to_aggregations(self): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that aggregations returns zero channel = self.make_request( @@ -1264,15 +1308,15 @@ def test_aggregations_redaction_prevents_access_to_aggregations(self): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self): + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1281,18 +1325,18 @@ def test_unknown_relations(self): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, channel.json_body["chunk"][0], ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -1302,7 +1346,7 @@ def test_unknown_relations(self): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) # But unknown relations can be directly queried. @@ -1312,8 +1356,8 @@ def test_unknown_relations(self): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ @@ -1377,18 +1421,18 @@ def _create_user(self, localpart: str) -> Tuple[str, str]: return user_id, access_token - def test_background_update(self): + def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1408,8 +1452,8 @@ def test_background_update(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good], ) @@ -1433,7 +1477,7 @@ def test_background_update(self): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertCountEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index fe5b536d9705..f3bf8d0934e6 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -13,9 +13,14 @@ # limitations under the License. from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest @@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -47,15 +52,15 @@ def make_homeserver(self, reactor, clock): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() - def test_retention_event_purged_with_state_event(self): + def test_retention_event_purged_with_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by a state event. """ @@ -72,7 +77,7 @@ def test_retention_event_purged_with_state_event(self): self._test_retention_event_purged(room_id, one_day_ms * 1.5) - def test_retention_event_purged_with_state_event_outside_allowed(self): + def test_retention_event_purged_with_state_event_outside_allowed(self) -> None: """Tests that the server configuration can override the policy for a room when running the purge jobs. """ @@ -102,7 +107,7 @@ def test_retention_event_purged_with_state_event_outside_allowed(self): # instead of the one specified in the room's policy. self._test_retention_event_purged(room_id, one_day_ms * 0.5) - def test_retention_event_purged_without_state_event(self): + def test_retention_event_purged_without_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by the server's configuration's default retention policy. """ @@ -110,11 +115,11 @@ def test_retention_event_purged_without_state_event(self): self._test_retention_event_purged(room_id, one_day_ms * 2) - def test_visibility(self): + def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) events = [] @@ -152,7 +157,7 @@ def test_visibility(self): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id: str, increment: float): + def _test_retention_event_purged(self, room_id: str, increment: float) -> None: """Run the following test scenario to test the message retention policy support: 1. Send event 1 @@ -186,6 +191,7 @@ def _test_retention_event_purged(self, room_id: str, increment: float): resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") + assert expired_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(expired_event_id) @@ -201,6 +207,7 @@ def _test_retention_event_purged(self, room_id: str, increment: float): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert valid_event_id is not None # Advance the time again. Now our first event should have expired but our second # one should still be kept. @@ -218,7 +225,7 @@ def _test_retention_event_purged(self, room_id: str, increment: float): # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, event_id, expect_none=False): + def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) if expect_none: @@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -254,11 +261,11 @@ def make_homeserver(self, reactor, clock): ) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_no_default_policy(self): + def test_no_default_policy(self) -> None: """Tests that an event doesn't get expired if there is neither a default retention policy nor a policy specific to the room. """ @@ -266,7 +273,7 @@ def test_no_default_policy(self): self._test_retention(room_id) - def test_state_policy(self): + def test_state_policy(self) -> None: """Tests that an event gets correctly expired if there is no default retention policy but there's a policy specific to the room. """ @@ -283,12 +290,15 @@ def test_state_policy(self): self._test_retention(room_id, expected_code_for_first_event=404) - def _test_retention(self, room_id, expected_code_for_first_event=200): + def _test_retention( + self, room_id: str, expected_code_for_first_event: int = 200 + ) -> None: # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert first_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) @@ -304,6 +314,7 @@ def _test_retention(self, room_id, expected_code_for_first_event=200): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") + assert second_event_id is not None # Advance the time by another month. self.reactor.advance(one_day_ms * 30 / 1000) @@ -322,7 +333,9 @@ def _test_retention(self, room_id, expected_code_for_first_event=200): second_event = self.get_event(room_id, second_event_id) self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = 200 + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url, access_token=self.token) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index e9f870403536..44f333a0ee77 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -134,7 +134,7 @@ def _create_test_room(self) -> Tuple[str, str, str, str]: return room_id, event_id_a, event_id_b, event_id_c @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self): + def test_same_state_groups_for_whole_historical_batch(self) -> None: """Make sure that when using the `/batch_send` endpoint to import a bunch of historical messages, it re-uses the same `state_group` across the whole batch. This is an easy optimization to make sure we're getting diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b7f086927bcf..e0b11e726433 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,7 @@ def make_homeserver(self, reactor, clock): async def _insert_client_ip(*args, **kwargs): return None - self.hs.get_datastore().insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip return self.hs @@ -95,7 +95,7 @@ def prepare(self, reactor, clock, hs): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -103,7 +103,7 @@ def prepare(self, reactor, clock, hs): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -125,28 +125,28 @@ def send_msg_path(): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' @@ -156,28 +156,28 @@ def test_topic_perms(self): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -185,25 +185,25 @@ def test_topic_perms(self): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -211,7 +211,7 @@ def test_topic_perms(self): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room=None, members: Iterable = frozenset(), expect_code=None @@ -219,7 +219,7 @@ def _test_get_membership( for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) + self.assertEqual(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -478,16 +478,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self): """ @@ -498,7 +498,7 @@ def test_get_member_list_no_permission_with_at_token(self): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -507,7 +507,7 @@ def test_get_member_list_no_permission_with_at_token(self): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self): """ @@ -520,14 +520,14 @@ def test_get_member_list_no_permission_former_member(self): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self): """ @@ -541,14 +541,14 @@ def test_get_member_list_no_permission_former_member_with_at_token(self): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -560,14 +560,14 @@ def test_get_member_list_no_permission_former_member_with_at_token(self): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" @@ -576,17 +576,17 @@ def test_get_member_list_mixed_memberships(self): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): @@ -598,19 +598,19 @@ def test_post_room_no_keys(self): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): @@ -618,16 +618,16 @@ def test_post_room_known_and_unknown_keys(self): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -635,7 +635,7 @@ def test_post_room_invitees_invalid_mxid(self): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self): @@ -667,7 +667,7 @@ def test_post_room_invitees_ratelimit(self): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) ) # Test that the invites aren't ratelimited anymore. @@ -694,9 +694,9 @@ async def user_may_join_room( "/createRoom", {}, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) - self.assertEquals(join_mock.call_count, 0) + self.assertEqual(join_mock.call_count, 0) class RoomTopicTestCase(RoomBase): @@ -712,54 +712,54 @@ def prepare(self, reactor, clock, hs): def test_invalid_puts(self): # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -775,22 +775,22 @@ def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -799,7 +799,7 @@ def test_invalid_puts(self): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -810,13 +810,13 @@ def test_rooms_members_self(self): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) + self.assertEqual(expected_response, channel.json_body) def test_rooms_members_other(self): self.other_id = "@zzsid1:red" @@ -828,11 +828,11 @@ def test_rooms_members_other(self): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" @@ -847,11 +847,11 @@ def test_rooms_members_other_custom_keys(self): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) class RoomInviteRatelimitTestCase(RoomBase): @@ -937,7 +937,7 @@ async def user_may_join_room( False, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -955,7 +955,7 @@ async def user_may_join_room( True, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -1013,7 +1013,7 @@ def test_join_local_ratelimit_profile_change(self): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1023,10 +1023,10 @@ def test_join_local_ratelimit_profile_change(self): ) channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") + self.assertEqual(channel.json_body["displayname"], "John Doe") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} @@ -1047,7 +1047,7 @@ def test_join_local_ratelimit_idempotent(self): # if all of these requests ended up joining the user to a room. for _ in range(4): channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) @unittest.override_config( { @@ -1060,7 +1060,9 @@ def test_autojoin_rooms(self): user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + rooms = self.get_success( + self.hs.get_datastores().main.get_rooms_for_user(user_id) + ) self.assertEqual(len(rooms), 4) @@ -1076,40 +1078,40 @@ def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -1123,10 +1125,10 @@ def prepare(self, reactor, clock, hs): def test_initial_sync(self): channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) + self.assertEqual(self.room_id, channel.json_body["room_id"]) + self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} @@ -1150,7 +1152,7 @@ def test_initial_sync(self): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): @@ -1166,9 +1168,9 @@ def test_topo_token_is_accepted(self): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -1177,14 +1179,14 @@ def test_stream_token_is_accepted_for_fwd_pagianation(self): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) def test_room_messages_purge(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() # Send a first message in the room, which will be removed by the purge. @@ -2612,7 +2614,7 @@ def test_threepid_invite_spamcheck(self): }, access_token=self.tok, ) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) # Check that the callback was called with the right params. mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) @@ -2634,7 +2636,7 @@ def test_threepid_invite_spamcheck(self): }, access_token=self.tok, ) - self.assertEquals(channel.code, 403) + self.assertEqual(channel.code, 403) # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index e2ed14457fc0..c3942889e113 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase): sync.register_servlets, ] - def test_user_to_user(self): + def test_user_to_user(self) -> None: """A to-device message from one user to another should get delivered""" user1 = self.register_user("u1", "pass") @@ -73,7 +73,7 @@ def test_user_to_user(self): self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): + def test_local_room_key_request(self) -> None: """m.room_key_request has special-casing; test from local user""" user1 = self.register_user("u1", "pass") user1_tok = self.login("u1", "pass", "d1") @@ -128,7 +128,7 @@ def test_local_room_key_request(self): ) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): + def test_remote_room_key_request(self) -> None: """m.room_key_request has special-casing; test from remote user""" user2 = self.register_user("u2", "pass") user2_tok = self.login("u2", "pass", "d2") @@ -199,7 +199,7 @@ def test_remote_room_key_request(self): }, ) - def test_limited_sync(self): + def test_limited_sync(self) -> None: """If a limited sync for to-devices happens the next /sync should respond immediately.""" self.register_user("u1", "pass") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index b0c44af033c0..ae5ada3be7ef 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -14,6 +14,8 @@ from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import ( @@ -23,18 +25,20 @@ room, room_upgrade_rest_servlet, ) +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class _ShadowBannedBase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create two users, one of which is shadow-banned. self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) @@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase): room_upgrade_rest_servlet.register_servlets, ] - def test_invite(self): + def test_invite(self) -> None: """Invites from shadow-banned users don't actually get sent.""" # The create works fine. @@ -77,7 +81,7 @@ def test_invite(self): ) self.assertEqual(invited_rooms, []) - def test_invite_3pid(self): + def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( @@ -96,12 +100,12 @@ def test_invite_3pid(self): {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() - def test_create_room(self): + def test_create_room(self) -> None: """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. channel = self.make_request( @@ -110,7 +114,7 @@ def test_create_room(self): {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) room_id = channel.json_body["room_id"] # But the user wasn't actually invited. @@ -126,7 +130,7 @@ def test_create_room(self): users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - def test_message(self): + def test_message(self) -> None: """Messages from shadow-banned users don't actually get sent.""" room_id = self.helper.create_room_as( @@ -151,7 +155,7 @@ def test_message(self): ) self.assertNotIn(event_id, latest_events) - def test_upgrade(self): + def test_upgrade(self) -> None: """A room upgrade should fail, but look like it succeeded.""" # The create works fine. @@ -165,7 +169,7 @@ def test_upgrade(self): {"new_version": "6"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -177,7 +181,7 @@ def test_upgrade(self): # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) - def test_typing(self): + def test_typing(self) -> None: """Typing notifications should not be propagated into the room.""" # The create works fine. room_id = self.helper.create_room_as( @@ -190,11 +194,11 @@ def test_typing(self): {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # There should be no typing events. event_source = self.hs.get_event_sources().sources.typing - self.assertEquals(event_source.get_current_key(), 0) + self.assertEqual(event_source.get_current_key(), 0) # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) @@ -205,10 +209,10 @@ def test_typing(self): {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # These appear in the room. - self.assertEquals(event_source.get_current_key(), 1) + self.assertEqual(event_source.get_current_key(), 1) events = self.get_success( event_source.get_new_events( user=UserID.from_string(self.other_user_id), @@ -218,7 +222,7 @@ def test_typing(self): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase): room.register_servlets, ] - def test_displayname(self): + def test_displayname(self) -> None: """Profile changes should succeed, but don't end up in a room.""" original_display_name = "banned" new_display_name = "new name" @@ -257,7 +261,7 @@ def test_displayname(self): {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) # The user's display name should be updated. @@ -281,7 +285,7 @@ def test_displayname(self): event.content, {"membership": "join", "displayname": original_display_name} ) - def test_room_displayname(self): + def test_room_displayname(self) -> None: """Changes to state events for a room should be processed, but not end up in the room.""" original_display_name = "banned" new_display_name = "new name" @@ -299,7 +303,7 @@ def test_room_displayname(self): {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) # The display name in the room should not be changed. diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 283eccd53f95..3818b7b14bf3 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, shared_rooms +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): shared_rooms.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user) -> FakeChannel: + def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" @@ -47,14 +51,14 @@ def _get_shared_rooms(self, token, other_user) -> FakeChannel: access_token=token, ) - def test_shared_room_list_public(self): + def test_shared_room_list_public(self) -> None: """ A room should show up in the shared list of rooms between two users if it is public. """ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - def test_shared_room_list_private(self): + def test_shared_room_list_private(self) -> None: """ A room should show up in the shared list of rooms between two users if it is private. @@ -63,7 +67,7 @@ def test_shared_room_list_private(self): room_one_is_public=False, room_two_is_public=False ) - def test_shared_room_list_mixed(self): + def test_shared_room_list_mixed(self) -> None: """ The shared room list between two users should contain both public and private rooms. @@ -72,7 +76,7 @@ def test_shared_room_list_mixed(self): def _check_shared_rooms_with( self, room_one_is_public: bool, room_two_is_public: bool - ): + ) -> None: """Checks that shared public or private rooms between two users appear in their shared room lists """ @@ -91,9 +95,9 @@ def _check_shared_rooms_with( # Check shared rooms from user1's perspective. # We should see the one room in common channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) # Create another room and invite user2 to it room_id_two = self.helper.create_room_as( @@ -104,12 +108,12 @@ def _check_shared_rooms_with( # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) - def test_shared_room_list_after_leave(self): + def test_shared_room_list_after_leave(self) -> None: """ A room should no longer be considered shared if the other user has left it. @@ -125,18 +129,18 @@ def test_shared_room_list_after_leave(self): # Assert user directory is not empty channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) # Check user1's view of shared rooms with user2 channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index cd4af2b1f358..435101395204 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import List, Optional from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import ( EventContentFields, @@ -24,6 +27,9 @@ RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.federation.transport.test_knocking import ( @@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_argless(self): + def test_sync_argless(self) -> None: channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) @@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_filter_labels(self): + def test_sync_filter_labels(self) -> None: """Test that we can filter by a label.""" sync_filter = json.dumps( { @@ -77,7 +83,7 @@ def test_sync_filter_labels(self): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_sync_filter_not_labels(self): + def test_sync_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { @@ -99,7 +105,7 @@ def test_sync_filter_not_labels(self): events[2]["content"]["body"], "with two wrong labels", events[2] ) - def test_sync_filter_labels_not_labels(self): + def test_sync_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label.""" sync_filter = json.dumps( { @@ -118,7 +124,7 @@ def test_sync_filter_labels_not_labels(self): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def _test_sync_filter_labels(self, sync_filter): + def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def test_sync_backwards_typing(self): + def test_sync_backwards_typing(self) -> None: """ If the typing serial goes backwards and the typing handler is then reset (such as when the master restarts and sets the typing serial to 0), we @@ -231,10 +237,10 @@ def test_sync_backwards_typing(self): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. @@ -243,7 +249,7 @@ def test_sync_backwards_typing(self): typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Start typing. channel = self.make_request( @@ -251,11 +257,11 @@ def test_sync_backwards_typing(self): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Should return immediately channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Reset typing serial back to 0, as if the master had. @@ -267,7 +273,7 @@ def test_sync_backwards_typing(self): self.helper.send(room, body="There!", tok=other_access_token) channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # This should time out! But it does not, because our stream token is @@ -275,7 +281,7 @@ def test_sync_backwards_typing(self): # already seen) is new, since it's got a token above our new, now-reset # stream token. channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Clear the typing information, so that it doesn't think everything is @@ -298,8 +304,8 @@ class SyncKnockTestCase( knock.register_servlets, ] - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" @@ -336,7 +342,7 @@ def prepare(self, reactor, clock, hs): ) @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): + def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room channel = self.make_request( @@ -345,7 +351,7 @@ def test_knock_room_state(self): b"{}", self.knocker_tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # We expect to see the knock event in the stripped room state later self.expected_room_state[EventTypes.Member] = { @@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -402,7 +408,7 @@ def prepare(self, reactor, clock, hs): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): + def test_hidden_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -441,8 +447,8 @@ def test_hidden_read_receipts(self): ] ) def test_read_receipt_with_empty_body( - self, name, user_agent: str, expected_status_code: int - ): + self, name: str, user_agent: str, expected_status_code: int + ) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -455,11 +461,11 @@ def test_read_receipt_with_empty_body( ) self.assertEqual(channel.code, expected_status_code) - def _get_read_receipt(self): + def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" # Checks if event is a read receipt - def is_read_receipt(event): + def is_read_receipt(event: JsonDict) -> bool: return event["type"] == "m.receipt" # Sync @@ -477,7 +483,8 @@ def is_read_receipt(event): ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -533,7 +540,7 @@ def prepare(self, reactor, clock, hs): tok=self.tok, ) - def test_unread_counts(self): + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -640,7 +647,7 @@ def test_unread_counts(self): ) self._check_unread_count(5) - def _check_unread_count(self, expected_count: int): + def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" channel = self.make_request( @@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_noop_sync_does_not_tightloop(self): + def test_noop_sync_does_not_tightloop(self) -> None: """If the sync times out, we shouldn't cache the result Essentially a regression test for #8518. @@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_user_with_no_rooms_receives_self_device_list_updates(self): + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: """Tests that a user with no rooms still receives their own device list updates""" device_id = "TESTDEVICE" diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index ac6b86ff6b02..bfc04785b7b2 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -15,12 +15,12 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin -from synapse.rest.client import login, room +from synapse.rest.client import account, login, profile, room from synapse.types import JsonDict, Requester, StateMap from synapse.util.frozenutils import unfreeze @@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): admin.register_servlets, login.register_servlets, room.register_servlets, + profile.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -139,7 +141,7 @@ async def check(ev, state): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) callback.assert_called_once() @@ -157,7 +159,7 @@ async def check(ev, state): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self): """ @@ -193,7 +195,7 @@ async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: access_token=self.tok, ) # Check the error code - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -329,10 +331,10 @@ def test_send_event(self): self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) - self.assertEquals(event.sender, self.user_id) - self.assertEquals(event.room_id, self.room_id) - self.assertEquals(event.type, "m.room.message") - self.assertEquals(event.content, content) + self.assertEqual(event.sender, self.user_id) + self.assertEqual(event.room_id, self.room_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.content, content) @unittest.override_config( { @@ -530,3 +532,216 @@ def _update_power_levels(self, event_default: int = 0): }, tok=self.tok, ) + + def test_on_profile_update(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Change the display name. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/displayname" % self.user_id, + {"displayname": displayname}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + m.assert_called_once() + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertIsNone(profile_info.avatar_url) + + # Change the avatar. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id, + {"avatar_url": avatar_url}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_profile_update_admin(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates triggered by a server admin. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Change a user's profile. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % self.user_id, + {"displayname": displayname, "avatar_url": avatar_url}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called twice (since we update the display name + # and avatar separately). + self.assertEqual(m.call_count, 2) + + # Get the arguments for the last call and check it's about the right user. + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Check that by_admin is True. + self.assertTrue(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_user_deactivation_status_changed(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append( + deactivation_mock, + ) + # Also register a mocked callback for profile updates, to check that the + # deactivation code calls it in a way that let modules know the user is being + # deactivated. + profile_mock = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append( + profile_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID, and with a True + # deactivated flag and a False by_admin flag. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertFalse(args[2]) + + # Check that the profile update callback was called twice (once for the display + # name and once for the avatar URL), and that the "deactivation" boolean is true. + self.assertEqual(profile_mock.call_count, 2) + args = profile_mock.call_args[0] + self.assertTrue(args[3]) + + def test_on_user_deactivation_status_changed_admin(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin as + well as a reactivation. + """ + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + m.assert_called_once() + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with True deactivated + # and by_admin flags. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertTrue(args[2]) + + # Reactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": False, "password": "hackme"}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with a False + # deactivated flag and a True by_admin flag. + self.assertEqual(args[0], user_id) + self.assertFalse(args[1]) + self.assertTrue(args[2]) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index ee0abd5295af..8b2da88e8a58 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -57,7 +57,7 @@ async def get_user_by_access_token(token=None, allow_guest=False): async def _insert_client_ip(*args, **kwargs): return None - hs.get_datastore().insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip return hs @@ -72,9 +72,9 @@ def test_set_typing(self): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=UserID.from_string(self.user_id), @@ -84,7 +84,7 @@ def test_set_typing(self): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -101,7 +101,7 @@ def test_set_not_typing(self): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) def test_typing_timeout(self): channel = self.make_request( @@ -109,19 +109,19 @@ def test_typing_timeout(self): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) self.reactor.advance(36) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a42388b26ffb..b7d0f42daf5f 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -13,11 +13,14 @@ # limitations under the License. from typing import Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -31,8 +34,8 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): room_upgrade_rest_servlet.register_servlets, ] - def prepare(self, reactor, clock, hs: "HomeServer"): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") @@ -60,15 +63,15 @@ def _upgrade_room( access_token=token or self.creator_token, ) - def test_upgrade(self): + def test_upgrade(self) -> None: """ Upgrading a room should work fine. """ channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self): + def test_not_in_room(self) -> None: """ Upgrading a room should work fine. """ @@ -77,15 +80,15 @@ def test_not_in_room(self): roomless_token = self.login(roomless, "pass") channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) - def test_power_levels(self): + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -103,15 +106,15 @@ def test_power_levels(self): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) - def test_power_levels_user_default(self): + def test_power_levels_user_default(self) -> None: """ Another user can upgrade the room if the default power level for users is increased. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -129,15 +132,15 @@ def test_power_levels_user_default(self): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) - def test_power_levels_tombstone(self): + def test_power_levels_tombstone(self) -> None: """ Another user can upgrade the room if they can send the tombstone event. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -155,7 +158,7 @@ def test_power_levels_tombstone(self): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) power_levels = self.helper.get_state( self.room_id, @@ -164,7 +167,7 @@ def test_power_levels_tombstone(self): ) self.assertNotIn(self.other, power_levels["users"]) - def test_space(self): + def test_space(self) -> None: """Test upgrading a space.""" # Create a space. @@ -197,7 +200,7 @@ def test_space(self): # Upgrade the room! channel = self._upgrade_room(room_id=space_id) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) new_space_id = channel.json_body["replacement_room"] diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1c0cb0cf4f34..28663826fcbc 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,6 +19,7 @@ import re import time import urllib.parse +from http import HTTPStatus from typing import ( Any, AnyStr, @@ -40,6 +41,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request @@ -47,15 +49,15 @@ from tests.test_utils.html_parsers import TestHtmlParser -@attr.s +@attr.s(auto_attribs=True) class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() + hs: HomeServer + site: Site + auth_user_id: Optional[str] @overload def create_room_as( @@ -89,7 +91,7 @@ def create_room_as( is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> Optional[str]: @@ -106,9 +108,13 @@ def create_room_as( default room version. tok: The access token to use in the request. expect_code: The expected HTTP response code. + extra_content: Extra keys to include in the body of the /createRoom request. + Note that if is_public is set, the "visibility" key will be overridden. + If room_version is set, the "room_version" key will be overridden. + custom_headers: HTTP headers to include in the request. Returns: - The ID of the newly created room. + The ID of the newly created room, or None if the request failed. """ temp_id = self.auth_user_id self.auth_user_id = room_creator @@ -133,12 +139,19 @@ def create_room_as( assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - if expect_code == 200: + if expect_code == HTTPStatus.OK: return channel.json_body["room_id"] else: return None - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + def invite( + self, + room: str, + src: Optional[str] = None, + targ: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=src, @@ -152,7 +165,7 @@ def join( self, room: str, user: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, ) -> None: @@ -166,7 +179,14 @@ def join( expect_code=expect_code, ) - def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): + def knock( + self, + room: Optional[str] = None, + user: Optional[str] = None, + reason: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: temp_id = self.auth_user_id self.auth_user_id = user path = "/knock/%s" % room @@ -195,7 +215,13 @@ def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): self.auth_user_id = temp_id - def leave(self, room=None, user=None, expect_code=200, tok=None): + def leave( + self, + room: str, + user: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, @@ -205,14 +231,22 @@ def leave(self, room=None, user=None, expect_code=200, tok=None): expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object): + def ban( + self, + room: str, + src: str, + targ: str, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, src=src, targ=targ, + tok=tok, membership=Membership.BAN, - **kwargs, + expect_code=expect_code, ) def change_membership( @@ -224,7 +258,7 @@ def change_membership( extra_data: Optional[dict] = None, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, ) -> None: """ @@ -290,13 +324,13 @@ def change_membership( def send( self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, + room_id: str, + body: Optional[str] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if body is None: body = "body_text_here" @@ -314,14 +348,14 @@ def send( def send_event( self, - room_id, - type, + room_id: str, + type: str, content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -353,11 +387,11 @@ def _read_write_state( room_id: str, event_type: str, body: Optional[Dict[str, Any]], - tok: str, - expect_code: int = 200, + tok: Optional[str], + expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", - ) -> Dict: + ) -> JsonDict: """Read or write some state from a given room Args: @@ -406,9 +440,9 @@ def get_state( room_id: str, event_type: str, tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Gets some state from a room Args: @@ -433,10 +467,10 @@ def send_state( room_id: str, event_type: str, body: Dict[str, Any], - tok: str, - expect_code: int = 200, + tok: Optional[str], + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Set some state in a room Args: @@ -463,8 +497,8 @@ def upload_media( image_data: bytes, tok: str, filename: str = "test.png", - expect_code: int = 200, - ) -> dict: + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: """Upload a piece of test media to the media repo Args: resource: The resource that will handle the upload request @@ -509,7 +543,7 @@ def login_via_oidc(self, remote_user_id: str) -> JsonDict: channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200, channel.result + assert channel.code == HTTPStatus.OK, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -528,7 +562,7 @@ def login_via_oidc(self, remote_user_id: str) -> JsonDict: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == 200 + assert channel.code == HTTPStatus.OK return channel.json_body def auth_via_oidc( @@ -633,11 +667,16 @@ def complete_oidc_auth( (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] - async def mock_req(method: str, uri: str, data=None, headers=None): + async def mock_req( + method: str, + uri: str, + data: Optional[dict] = None, + headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ): (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, + code=HTTPStatus.OK, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), ) @@ -735,7 +774,7 @@ def initiate_sso_ui_auth( self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint ) # that should serve a confirmation page - assert channel.code == 200, channel.text_body + assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4cf1ed5ddff0..cba9be17c4ca 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -94,7 +94,7 @@ def test_ensure_media_is_in_local_cache(self): self.assertTrue(os.path.exists(local_path)) # Asserts the file is under the expected local cache directory - self.assertEquals( + self.assertEqual( os.path.commonprefix([self.primary_base_path, local_path]), self.primary_base_path, ) @@ -243,7 +243,7 @@ def prepare(self, reactor, clock, hs): media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 36c495954ff3..02b96c9e6ecf 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -242,7 +242,7 @@ def default_config(self): return c def prepare(self, reactor, clock, hs): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 36c933b9e91a..50c20c5b921a 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") def test_background_remove_deleted_devices_from_device_inbox(self): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5a52..1f6a9eb07bfc 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -37,7 +37,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main # insert some test data for rid in ("room1", "room2"): @@ -88,18 +88,18 @@ def test_simple(self): res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) + self.assertEqual(res, {"event10"}) # that should result in a single db query - self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache @@ -108,8 +108,8 @@ def test_query_via_event_cache(self): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) class EventCacheTestCase(unittest.HomeserverTestCase): @@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login(self.user, "pass") @@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" self.event_ids = [f"event{i}" for i in range(20)] diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index d326a1d6a642..3ac4646969a4 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -20,7 +20,7 @@ class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7496974da3a8..9abd0cb4468d 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 200b9198f910..4899cd5c3610 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.storage = hs.get_datastore() + self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) self.get_success( diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d697d2bc1e79..272cd3540220 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = "@user:test" def _update_ignore_list( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f554918..ee599f433667 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -88,21 +88,21 @@ def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") - self.assertEquals(service, None) + self.assertEqual(service, None) def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) assert stored_service is not None - self.assertEquals(stored_service.token, self.as_token) - self.assertEquals(stored_service.id, self.as_id) - self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) + self.assertEqual(stored_service.token, self.as_token) + self.assertEqual(stored_service.id, self.as_id) + self.assertEqual(stored_service.url, self.as_url) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() - self.assertEquals(len(services), 3) + self.assertEqual(len(services), 3) class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): @@ -182,7 +182,7 @@ def test_get_appservice_state_none( ) -> None: service = Mock(id="999") state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(None, state) + self.assertEqual(None, state) def test_get_appservice_state_up( self, @@ -194,7 +194,7 @@ def test_get_appservice_state_up( state = self.get_success( defer.ensureDeferred(self.store.get_appservice_state(service)) ) - self.assertEquals(ApplicationServiceState.UP, state) + self.assertEqual(ApplicationServiceState.UP, state) def test_get_appservice_state_down( self, @@ -210,7 +210,7 @@ def test_get_appservice_state_down( ) service = Mock(id=self.as_list[1]["id"]) state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(ApplicationServiceState.DOWN, state) + self.assertEqual(ApplicationServiceState.DOWN, state) def test_get_appservices_by_state_none( self, @@ -218,7 +218,7 @@ def test_get_appservices_by_state_none( services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(0, len(services)) + self.assertEqual(0, len(services)) def test_set_appservices_state_down( self, @@ -235,7 +235,7 @@ def test_set_appservices_state_down( (ApplicationServiceState.DOWN.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_set_appservices_state_multiple_up( self, @@ -258,7 +258,7 @@ def test_set_appservices_state_multiple_up( (ApplicationServiceState.UP.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_create_appservice_txn_first( self, @@ -267,12 +267,12 @@ def test_create_appservice_txn_first( events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) - self.assertEquals(txn.id, 1) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 1) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_older_last_txn( self, @@ -283,11 +283,11 @@ def test_create_appservice_txn_older_last_txn( self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9646) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9646) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_to_date_last_txn( self, @@ -296,11 +296,11 @@ def test_create_appservice_txn_up_to_date_last_txn( events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_fuzzing( self, @@ -320,11 +320,11 @@ def test_create_appservice_txn_up_fuzzing( self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_complete_appservice_txn_first_txn( self, @@ -346,8 +346,8 @@ def test_complete_appservice_txn_first_txn( (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) res = self.get_success( self.db_pool.runQuery( @@ -357,7 +357,7 @@ def test_complete_appservice_txn_first_txn( (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_complete_appservice_txn_existing_in_state_table( self, @@ -379,9 +379,9 @@ def test_complete_appservice_txn_existing_in_state_table( (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) + self.assertEqual(ApplicationServiceState.UP.value, res[0][1]) res = self.get_success( self.db_pool.runQuery( @@ -391,7 +391,7 @@ def test_complete_appservice_txn_existing_in_state_table( (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_get_oldest_unsent_txn_none( self, @@ -399,7 +399,7 @@ def test_get_oldest_unsent_txn_none( service = Mock(id=self.as_list[0]["id"]) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(None, txn) + self.assertEqual(None, txn) def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) @@ -416,9 +416,9 @@ def test_get_oldest_unsent_txn(self) -> None: self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(service, txn.service) - self.assertEquals(10, txn.id) - self.assertEquals(events, txn.events) + self.assertEqual(service, txn.service) + self.assertEqual(10, txn.id) + self.assertEqual(events, txn.events) def test_get_appservices_by_state_single( self, @@ -433,8 +433,8 @@ def test_get_appservices_by_state_single( services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(1, len(services)) - self.assertEquals(self.as_list[0]["id"], services[0].id) + self.assertEqual(1, len(services)) + self.assertEqual(self.as_list[0]["id"], services[0].id) def test_get_appservices_by_state_multiple( self, @@ -455,8 +455,8 @@ def test_get_appservices_by_state_multiple( services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(2, len(services)) - self.assertEquals( + self.assertEqual(2, len(services)) + self.assertEqual( {self.as_list[2]["id"], self.as_list[0]["id"]}, {services[0].id, services[1].id}, ) @@ -467,7 +467,7 @@ def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.service = Mock(id="foo") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_appservice_state(self.service, ApplicationServiceState.UP) ) @@ -476,12 +476,12 @@ def test_get_type_stream_id_for_appservice_no_value(self) -> None: value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6156dfac4e58..39dcc094bd8b 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -24,7 +24,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -42,7 +42,7 @@ def test_do_background_update(self): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", @@ -102,7 +102,7 @@ async def update(progress, count): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -138,7 +138,7 @@ class MockCM: ) def test_controller(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c9cb..a8ffb52c0503 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -103,7 +103,7 @@ def test_select_one_1col(self): ) ) - self.assertEquals("Value", value) + self.assertEqual("Value", value) self.mock_txn.execute.assert_called_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +121,7 @@ def test_select_one_3col(self): ) ) - self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -154,7 +154,7 @@ def test_select_list(self): ) ) - self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index a59c28f89681..ce89c9691206 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -242,7 +242,7 @@ def make_homeserver(self, reactor, clock): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c8ac67e35b67..49ad3c132425 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -35,7 +35,7 @@ def make_homeserver(self, reactor, clock): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_insert_new_client_ip(self): self.reactor.advance(12345678) @@ -666,7 +666,7 @@ def make_homeserver(self, reactor, clock): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) def test_request_with_xforwarded(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index b547bf8d9978..21ffc5a9095b 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,7 +19,7 @@ class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_store_new_device(self): self.get_success( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 43628ce44f41..20bf3ca17b9a 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") @@ -31,7 +31,7 @@ def test_room_to_alias(self): ) ) - self.assertEquals( + self.assertEqual( ["#my-room:test"], (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 7556171d8ac4..fb96ab3a2fd0 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -28,7 +28,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_room_keys_version_delete(self): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 3bf6e337f4e1..0f04493ad06c 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -17,7 +17,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_key_without_device_name(self): now = 1470174257070 diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e3273a93f9a2..401020fd6361 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -30,7 +30,7 @@ class EventChainStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._next_stream_ordering = 1 def test_simple(self): @@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") self.requester = create_requester(self.user_id) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 667ca90a4d3e..645d564d1c40 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -31,7 +31,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): room_id = "@ROOM:local" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5e02d2c7f09e..eb47563bdf03 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -30,7 +30,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.persist_events_store = hs.get_datastores().persist_events def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -57,7 +57,7 @@ def _assert_counts(noitf_count, highlight_count): "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) ) - self.assertEquals( + self.assertEqual( counts, NotifCounts( notify_count=noitf_count, diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f462a8b1c721..ef5e25873c22 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.register_user("user", "pass") self.token = self.login("user", "pass") @@ -329,3 +329,110 @@ def test_do_not_prune_gap_if_not_dummy(self): # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) + + +class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastores().main + + def test_remote_user_rooms_cache_invalidated(self): + """Test that if the server leaves a room the `get_rooms_for_user` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_rooms_for_user` to add the remote user to the cache + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), {room_id}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), set()) + + def test_room_remote_user_cache_invalidated(self): + """Test that if the server leaves a room the `get_users_in_room` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_users_in_room` to add the remote user to the cache + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(set(users), {user_id, remote_user}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(users, []) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 748607828496..6ac4b93f981a 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index a94b5fd721f2..9059095525e0 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,7 +37,7 @@ def decode_verify_key_base64(key_id: str, key_base64: str): class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" @@ -74,7 +74,7 @@ def test_get_server_verify_keys(self): def test_cache(self): """Check that updates correctly invalidate the cache.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4ec5..5806cb0e4bb7 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: super(DataStoreTestCase, self).setUp() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" @@ -38,12 +38,12 @@ def test_get_users_paginate(self) -> None: self.store.get_users_paginate(0, 10, name="bc", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6b4cdd788ff..79648d45dba4 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -45,7 +45,7 @@ def default_config(self): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # Advance the clock a bit reactor.advance(FORTY_DAYS) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index d37736edf8f9..a019d06e09c5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -22,7 +22,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_frank = UserID.from_string("@frank:test") @@ -33,7 +33,7 @@ def test_displayname(self) -> None: self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( "Frank", ( self.get_success( @@ -60,7 +60,7 @@ def test_avatar_url(self) -> None: ) ) - self.assertEquals( + self.assertEqual( "http://my.site/here", ( self.get_success( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 22a77c3cccc5..08cc60237ec1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -30,7 +30,7 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = self.hs.get_storage() def test_purge_history(self): @@ -47,7 +47,7 @@ def test_purge_history(self): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - token_str = self.get_success(token.to_string(self.hs.get_datastore())) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8c95a0a2fb1f..03e9cc7d4ad9 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -30,7 +30,7 @@ def default_config(self): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 97480652824d..a49ac1525ec7 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -20,7 +20,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] @@ -30,7 +30,7 @@ def prepare(self, reactor, clock, hs): def test_register(self): self.get_success(self.store.register_user(self.user_id, self.pwhash)) - self.assertEquals( + self.assertEqual( { # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, @@ -131,7 +131,7 @@ def test_3pid_inhibit_invalid_validation_session_error(self): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Unknown session_id", e) + self.assertEqual(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True @@ -146,4 +146,4 @@ def test_3pid_inhibit_invalid_validation_session_error(self): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Validation token not found or has expired", e) + self.assertEqual(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index cfc8098af67e..0baa54312e33 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -56,7 +56,7 @@ def default_config(self): def test_rolling_back(self): """Test that workers can start if the DB is a newer schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -72,7 +72,7 @@ def test_rolling_back(self): def test_not_upgraded_old_schema_version(self): """Test that workers don't start if the DB has an older schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -92,7 +92,7 @@ def test_not_upgraded_current_schema_version_with_outstanding_deltas(self): Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. """ - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 31ce7f62528f..5b011e18cd69 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#a-room-name:test") @@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() @@ -104,7 +104,7 @@ def STALE_test_room_name(self): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, state[0], @@ -121,7 +121,7 @@ def STALE_test_room_topic(self): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, state[0], diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8971ecccbd6d..8dfc1e1db903 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -13,13 +13,16 @@ # limitations under the License. import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError from synapse.rest.client import login, room from synapse.storage.engines import PostgresEngine -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, skip_unless +from tests.utils import USE_POSTGRES_FOR_TESTS -class NullByteInsertionTest(HomeserverTestCase): +class EventSearchInsertionTest(HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -46,11 +49,11 @@ def test_null_byte(self): self.assertIn("event_id", response) # Check that search works for the message where the null byte was replaced - store = self.hs.get_datastore() + store = self.hs.get_datastores().main result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("hi", result.get("highlights")) self.assertIn("bob", result.get("highlights")) @@ -59,16 +62,126 @@ def test_null_byte(self): result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"])) - self.assertEquals(result.get("count"), 2) + self.assertEqual(result.get("count"), 2) result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) + + def test_non_string(self): + """Test that non-string `value`s are not inserted into `event_search`. + + This is particularly important when using sqlite, since a sqlite column can hold + both strings and integers. When using Postgres, integers are automatically + converted to strings. + + Regression test for #11918. + """ + store = self.hs.get_datastores().main + + # Register a user and create a room + user_id = self.register_user("alice", "password") + access_token = self.login("alice", "password") + room_id = self.helper.create_room_as("alice", tok=access_token) + room_version = self.get_success(store.get_room_version(room_id)) + + # Construct a message with a numeric body to be received over federation + # The message can't be sent using the client API, since Synapse's event + # validation will reject it. + prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + prev_event = self.get_success(store.get_event(prev_event_ids[0])) + prev_state_map = self.get_success( + self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + ) + + event_dict = { + "type": EventTypes.Message, + "content": {"msgtype": "m.text", "body": 2}, + "room_id": room_id, + "sender": user_id, + "depth": prev_event.depth + 1, + "prev_events": prev_event_ids, + "origin_server_ts": self.clock.time_msec(), + } + builder = self.hs.get_event_builder_factory().for_room_version( + room_version, event_dict + ) + event = self.get_success( + builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=self.hs.get_event_auth_handler().compute_auth_events( + builder, + prev_state_map, + for_verification=False, + ), + depth=event_dict["depth"], + ) + ) + + # Receive the event + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.hs.hostname, event + ) + ) + + # The event should not have an entry in the `event_search` table + f = self.get_failure( + store.db_pool.simple_select_one_onecol( + "event_search", + {"room_id": room_id, "event_id": event.event_id}, + "event_id", + ), + StoreError, + ) + self.assertEqual(f.value.code, 404) + + @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") + def test_sqlite_non_string_deletion_background_update(self): + """Test the background update to delete bad rows from `event_search`.""" + store = self.hs.get_datastores().main + + # Populate `event_search` with dummy data + self.get_success( + store.db_pool.simple_insert_many( + "event_search", + keys=["event_id", "room_id", "key", "value"], + values=[ + ("event1", "room_id", "content.body", "hi"), + ("event2", "room_id", "content.body", "2"), + ("event3", "room_id", "content.body", 3), + ], + desc="populate_event_search", + ) + ) + + # Run the background update + store.db_pool.updates._all_done = False + self.get_success( + store.db_pool.simple_insert( + "background_updates", + { + "update_name": "event_search_sqlite_delete_non_strings", + "progress_json": "{}", + }, + ) + ) + self.wait_for_background_updates() + + # The non-string `value`s ought to be gone now. + values = self.get_success( + store.db_pool.simple_select_onecol( + "event_search", + {"room_id": "room_id"}, + "value", + ), + ) + self.assertCountEqual(values, ["hi", "2"]) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b852e..b8f09a8ee081 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -35,7 +35,7 @@ def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,7 +55,7 @@ def test_one_member(self): ) ) - self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + self.assertEqual([self.room], [m.room_id for m in rooms_for_user]) def test_count_known_servers(self): """ @@ -212,7 +212,7 @@ def test__null_byte_in_display_name_properly_handled(self): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() def test_can_rerun_update(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088c81..f88f1c55fc6f 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -28,7 +28,7 @@ class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() @@ -992,3 +992,112 @@ def test_state_filter_difference_simple_cases(self): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ce782c7e1d4d..6a1cf3305455 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -115,7 +115,7 @@ def _filter_messages(self, filter: JsonDict) -> List[EventBase]: ) events, next_key = self.get_success( - self.hs.get_datastore().paginate_room_events( + self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, from_key=from_token.room_key, to_key=None, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index bea9091d3089..e05daa285e4c 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -20,7 +20,7 @@ class TransactionStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_get_set_transactions(self): """Tests that we can successfully get a non-existent entry for diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 48f1e9d8411b..7f1964eb6a3d 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -149,7 +149,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_dir_helper = GetUserDirectoryTables(self.store) def _purge_and_rebuild_user_dir(self) -> None: @@ -415,7 +415,7 @@ async def mocked_process_users(*args: Any, **kwargs: Any) -> int: class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_distributor.py b/tests/test_distributor.py index f8341041ee6e..31546ea52bb0 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -48,7 +48,7 @@ def test_signal_catch(self): observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") - self.assertEquals(mock_logger.warning.call_count, 1) + self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) def test_signal_prereg(self): diff --git a/tests/test_federation.py b/tests/test_federation.py index 2b9804aba005..c39816de855d 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -52,11 +52,13 @@ def setUp(self): ) )[0]["room_id"] - self.store = self.homeserver.get_datastore() + self.store = self.homeserver.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.homeserver.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) )[0] join_event = make_event_from_dict( @@ -185,7 +187,7 @@ def query_user_devices(destination, user_id): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastore() + store = self.homeserver.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e255ea..46bd3075de7b 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -52,7 +52,7 @@ def default_config(self): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated diff --git a/tests/test_state.py b/tests/test_state.py index 76e0e8ca7ff6..e4baa6913746 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Collection, Dict, List, Optional from unittest.mock import Mock from twisted.internet import defer @@ -70,7 +70,7 @@ def create_event( return event -class StateGroupStore: +class _DummyStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -105,6 +105,11 @@ async def get_events(self, event_ids, **kwargs): if e_id in self._event_id_to_event } + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + return {e: False for e in event_ids} + async def get_state_group_delta(self, name): return None, None @@ -157,12 +162,12 @@ def get_leaves(self): class StateTestCase(unittest.TestCase): def setUp(self): - self.store = StateGroupStore() - storage = Mock(main=self.store, state=self.store) + self.dummy_store = _DummyStore() + storage = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", - "get_datastore", + "get_datastores", "get_storage", "get_auth", "get_state_handler", @@ -173,7 +178,7 @@ def setUp(self): ] ) hs.config = default_config("tesths", True) - hs.get_datastore.return_value = self.store + hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) @@ -198,7 +203,7 @@ def test_branch_no_conflict(self): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store: dict[str, EventContext] = {} @@ -206,7 +211,7 @@ def test_branch_no_conflict(self): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] @@ -242,7 +247,7 @@ def test_branch_basic_conflict(self): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -250,7 +255,7 @@ def test_branch_basic_conflict(self): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between B and C @@ -300,7 +305,7 @@ def test_branch_have_banned_conflict(self): edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -308,7 +313,7 @@ def test_branch_have_banned_conflict(self): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between C and D because bans win over other @@ -375,7 +380,7 @@ def test_branch_have_perms_conflict(self): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -383,7 +388,7 @@ def test_branch_have_perms_conflict(self): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # B ends up winning the resolution between B and C because power levels @@ -476,7 +481,7 @@ def test_trivial_annotate_message(self): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -484,7 +489,7 @@ def test_trivial_annotate_message(self): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -510,7 +515,7 @@ def test_trivial_annotate_state(self): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -518,7 +523,7 @@ def test_trivial_annotate_state(self): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -554,8 +559,8 @@ def test_resolve_message_conflict(self): create_event(type="test4", state_key=""), ] - self.store.register_events(old_state_1) - self.store.register_events(old_state_2) + self.dummy_store.register_events(old_state_1) + self.dummy_store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -594,10 +599,10 @@ def test_resolve_state_conflict(self): create_event(type="test4", state_key=""), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -649,10 +654,10 @@ def test_standard_depth_conflict(self): create_event(type="test1", state_key="1", depth=2), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -695,7 +700,7 @@ def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_1, event.room_id, None, @@ -703,10 +708,10 @@ def _get_context( {(e.type, e.state_key): e.event_id for e in old_state_1}, ) ) - self.store.register_event_id_state_group(prev_event_id_1, sg1) + self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_2, event.room_id, None, @@ -714,7 +719,7 @@ def _get_context( {(e.type, e.state_key): e.event_id for e in old_state_2}, ) ) - self.store.register_event_id_state_group(prev_event_id_2, sg2) + self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2) result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 67dcf567cdb8..37fada5c5369 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -54,7 +54,7 @@ def test_ui_auth(self): request_data = json.dumps({"username": "kermit", "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["session"], str) @@ -99,7 +99,7 @@ def test_ui_auth(self): # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms request_data = json.dumps( @@ -117,7 +117,7 @@ def test_ui_auth(self): # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["user_id"], str) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index f2ef1c6051eb..d04bcae0fabe 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -25,7 +25,7 @@ def test_advance_time(self): self.clock.advance_time(20) - self.assertEquals(20, self.clock.time() - start_time) + self.assertEqual(20, self.clock.time() - start_time) def test_later(self): invoked = [0, 0] diff --git a/tests/test_types.py b/tests/test_types.py index 0d0c00d97a06..80888a744d1b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -22,9 +22,9 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): user = UserID.from_string("@1234abcd:test") - self.assertEquals("1234abcd", user.localpart) - self.assertEquals("test", user.domain) - self.assertEquals(True, self.hs.is_mine(user)) + self.assertEqual("1234abcd", user.localpart) + self.assertEqual("test", user.domain) + self.assertEqual(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -33,7 +33,7 @@ def test_pase_empty(self): def test_build(self): user = UserID("5678efgh", "my.domain") - self.assertEquals(user.to_string(), "@5678efgh:my.domain") + self.assertEqual(user.to_string(), "@5678efgh:my.domain") def test_compare(self): userA = UserID.from_string("@userA:my.domain") @@ -48,14 +48,14 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): room = RoomAlias.from_string("#channel:test") - self.assertEquals("channel", room.localpart) - self.assertEquals("test", room.domain) - self.assertEquals(True, self.hs.is_mine(room)) + self.assertEqual("channel", room.localpart) + self.assertEqual("test", room.domain) + self.assertEqual(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") - self.assertEquals(room.to_string(), "#channel:my.domain") + self.assertEqual(room.to_string(), "#channel:my.domain") def test_validate(self): id_string = "#test:domain,test" diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e9ec9e085b53..c654e36ee4f4 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -85,7 +85,9 @@ async def create_event( **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: - room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) + room_version = await hs.get_datastores().main.get_room_version_id( + kwargs["room_id"] + ) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d435..219b5660b117 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -93,7 +93,9 @@ def test_erased_user(self) -> None: events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + self.get_success( + self.hs.get_datastores().main.mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index a71892cb9dbe..326895f4c9f5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -51,7 +51,10 @@ from synapse import events from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite @@ -149,12 +152,12 @@ def tearDown(orig): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and - that the value of each matches according to assertEquals.""" + that the value of each matches according to assertEqual.""" for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: - self.assertEquals(attrs[key], getattr(obj, key)) + self.assertEqual(attrs[key], getattr(obj, key)) except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e @@ -166,7 +169,7 @@ def assert_dict(self, required, actual): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals( + self.assertEqual( required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) @@ -277,7 +280,7 @@ def setUp(self): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( - self.hs.get_datastore().add_access_token_to_user( + self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, "some_fake_token", None, @@ -334,7 +337,7 @@ def wait_on_thread(self, deferred, timeout=10): def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): @@ -501,7 +504,7 @@ async def run_bg_updates(): self.get_success(stor.db_pool.updates.run_background_updates(False)) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) - stor = hs.get_datastore() + stor = hs.get_datastores().main # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": @@ -719,14 +722,16 @@ def add_extremity(self, room_id, event_id): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", ) ) - self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting @@ -772,7 +777,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastore().store_server_verify_keys( + hs.get_datastores().main.store_server_verify_keys( from_server=self.OTHER_SERVER_NAME, ts_added_ms=clock.time_msec(), verify_keys=[ @@ -839,6 +844,24 @@ def make_signed_federation_request( client_ip=client_ip, ) + def add_hashes_and_signatures( + self, + event_dict: JsonDict, + room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) -> JsonDict: + """Adds hashes and signatures to the given event dict + + Returns: + The modified event dict, for convenience + """ + add_hashes_and_signatures( + room_version, + event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + return event_dict + def _auth_header_for_request( origin: str, diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index c613ce3f1055..02b99b466a26 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -31,7 +31,7 @@ def test_hit(self): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(self.successResultOf(cache.get("foo")), 123) + self.assertEqual(self.successResultOf(cache.get("foo")), 123) def test_hit_deferred(self): cache = DeferredCache("test") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index ced3efd93f8b..19741ffcdaf1 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -434,8 +434,8 @@ def func(self, key): a = A() - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual((yield a.func("bar")), "bar") @defer.inlineCallbacks def test_hit(self): @@ -450,10 +450,10 @@ def func(self, key): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual(callcount[0], 1) @defer.inlineCallbacks def test_invalidate(self): @@ -468,13 +468,13 @@ def func(self, key): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) + self.assertEqual(callcount[0], 2) def test_invalidate_missing(self): class A: @@ -499,7 +499,7 @@ def func(self, key): for k in range(0, 12): yield a.func(k) - self.assertEquals(callcount[0], 12) + self.assertEqual(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times @@ -525,8 +525,8 @@ def func(self, key): a.func.prefill(("foo",), 456) - self.assertEquals(a.func("foo").result, 456) - self.assertEquals(callcount[0], 0) + self.assertEqual(a.func("foo").result, 456) + self.assertEqual(callcount[0], 0) @defer.inlineCallbacks def test_invalidate_context(self): @@ -547,19 +547,19 @@ def func2(self, key, cache_context): a = A() yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 1) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) @defer.inlineCallbacks def test_eviction_context(self): @@ -581,22 +581,22 @@ def func2(self, key, cache_context): yield a.func2("foo") yield a.func2("foo2") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func("foo3") - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 3) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 4) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 4) + self.assertEqual(callcount2[0], 3) @defer.inlineCallbacks def test_double_get(self): @@ -619,30 +619,30 @@ def func2(self, key, cache_context): yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 3) class CachedListDescriptorTestCase(unittest.TestCase): @@ -673,14 +673,14 @@ async def list_fn(self, args1, arg2): self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with((30,), 2) + obj.mock.assert_called_once_with({30}, 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -701,7 +701,7 @@ async def list_fn(self, args1, arg2): obj.mock.return_value = {40: "gravy"} iterable = (x for x in [10, 40, 40]) r = yield obj.list_fn(iterable, 2) - obj.mock.assert_called_once_with((40,), 2) + obj.mock.assert_called_once_with({40}, 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) def test_concurrent_lookups(self): @@ -729,7 +729,7 @@ def list_fn(self, args1) -> "Deferred[dict]": d3 = obj.list_fn([10]) # the mock should have been called exactly once - obj.mock.assert_called_once_with((10,)) + obj.mock.assert_called_once_with({10}) obj.mock.reset_mock() # ... and none of the calls should yet be complete @@ -771,7 +771,7 @@ async def list_fn(self, args1, arg2): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ab89cab81256..362014f4cb6f 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock +from twisted.python.failure import Failure from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -21,7 +24,12 @@ PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.async_helpers import ( + ObservableDeferred, + concurrently_execute, + stop_cancellation, + timeout_deferred, +) from tests.unittest import TestCase @@ -171,3 +179,151 @@ def errback(res, deferred_name): ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) + + +class _TestException(Exception): + pass + + +class ConcurrentlyExecuteTest(TestCase): + def test_limits_runners(self): + """If we have more tasks than runners, we should get the limit of runners""" + started = 0 + waiters = [] + processed = [] + + async def callback(v): + # when we first enter, bump the start count + nonlocal started + started += 1 + + # record the fact we got an item + processed.append(v) + + # wait for the goahead before returning + d2 = Deferred() + waiters.append(d2) + await d2 + + # set it going + d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3)) + + # check we got exactly 3 processes + self.assertEqual(started, 3) + self.assertEqual(len(waiters), 3) + + # let one finish + waiters.pop().callback(0) + + # ... which should start another + self.assertEqual(started, 4) + self.assertEqual(len(waiters), 3) + + # we still shouldn't be done + self.assertNoResult(d2) + + # finish the job + while waiters: + waiters.pop().callback(0) + + # check everything got done + self.assertEqual(started, 5) + self.assertCountEqual(processed, [1, 2, 3, 4, 5]) + self.successResultOf(d2) + + def test_preserves_stacktraces(self): + """Test that the stacktrace from an exception thrown in the callback is preserved""" + d1 = Deferred() + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + raise _TestException("bah") + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute" and "callback". + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-1].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + def test_preserves_stacktraces_on_preformed_failure(self): + """Test that the stacktrace on a Failure returned by the callback is preserved""" + d1 = Deferred() + f = Failure(_TestException("bah")) + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + await defer.fail(f) + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute", "callback", + # and some magic from inside ensureDeferred that happens when .fail + # is called. + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-2].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + + def test_succeed(self): + """Test that the new `Deferred` receives the result.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Success should propagate through. + deferred.callback("success") + self.assertTrue(wrapper_deferred.called) + self.assertEqual("success", self.successResultOf(wrapper_deferred)) + + def test_failure(self): + """Test that the new `Deferred` receives the `Failure`.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Failure should propagate through. + deferred.errback(ValueError("abc")) + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, ValueError) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` leaves the original running.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, CancelledError) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled." + ) + + # Now make the inner `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py new file mode 100644 index 000000000000..38e9f58ac6ca --- /dev/null +++ b/tests/util/test_check_dependencies.py @@ -0,0 +1,131 @@ +from contextlib import contextmanager +from typing import Generator, Optional +from unittest.mock import patch + +from synapse.util.check_dependencies import ( + DependencyException, + check_requirements, + metadata, +) + +from tests.unittest import TestCase + + +class DummyDistribution(metadata.Distribution): + def __init__(self, version: str): + self._version = version + + @property + def version(self): + return self._version + + def locate_file(self, path): + raise NotImplementedError() + + def read_text(self, filename): + raise NotImplementedError() + + +old = DummyDistribution("0.1.2") +old_release_candidate = DummyDistribution("0.1.2rc3") +new = DummyDistribution("1.2.3") +new_release_candidate = DummyDistribution("1.2.3rc4") + +# could probably use stdlib TestCase --- no need for twisted here + + +class TestDependencyChecker(TestCase): + @contextmanager + def mock_installed_package( + self, distribution: Optional[DummyDistribution] + ) -> Generator[None, None, None]: + """Pretend that looking up any distribution yields the given `distribution`.""" + + def mock_distribution(name: str): + if distribution is None: + raise metadata.PackageNotFoundError + else: + return distribution + + with patch( + "synapse.util.check_dependencies.metadata.distribution", + mock_distribution, + ): + yield + + def test_mandatory_dependency(self) -> None: + """Complain if a required package is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_checks_ignore_dev_dependencies(self) -> None: + """Bot generic and per-extra checks should ignore dev dependencies.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'mypy'"], + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): + # We're testing that none of these calls raise. + with self.mock_installed_package(None): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(old): + check_requirements() + check_requirements("cool-extra") + with self.mock_installed_package(new): + check_requirements() + check_requirements("cool-extra") + + def test_generic_check_of_optional_dependency(self) -> None: + """Complain if an optional package is old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ): + with self.mock_installed_package(None): + # should not raise + check_requirements() + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_check_for_extra_dependencies(self) -> None: + """Complain if a package required for an extra is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(new): + # should not raise + check_requirements("cool-extra") + + def test_release_candidates_satisfy_dependency(self) -> None: + """ + Tests that release candidates count as far as satisfying a dependency + is concerned. + (Regression test, see #12176.) + """ + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(old_release_candidate): + self.assertRaises(DependencyException, check_requirements) + + with self.mock_installed_package(new_release_candidate): + # should not raise + check_requirements() diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index e6e13ba06cc8..7f60aae5ba0d 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -26,8 +26,8 @@ def test_get_set(self): cache = ExpiringCache("test", clock, max_len=1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): clock = MockClock() @@ -35,13 +35,13 @@ def test_eviction(self): cache["key"] = "value" cache["key2"] = "value2" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache.get("key2"), "value2") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache.get("key2"), "value2") cache["key3"] = "value3" - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), "value2") - self.assertEquals(cache.get("key3"), "value3") + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), "value2") + self.assertEqual(cache.get("key3"), "value3") def test_iterable_eviction(self): clock = MockClock() @@ -51,15 +51,15 @@ def test_iterable_eviction(self): cache["key2"] = [2, 3] cache["key3"] = [4, 5] - self.assertEquals(cache.get("key"), [1]) - self.assertEquals(cache.get("key2"), [2, 3]) - self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key"), [1]) + self.assertEqual(cache.get("key2"), [2, 3]) + self.assertEqual(cache.get("key3"), [4, 5]) cache["key4"] = [6, 7] - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache.get("key3"), [4, 5]) - self.assertEquals(cache.get("key4"), [6, 7]) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key4"), [6, 7]) def test_time_eviction(self): clock = MockClock() @@ -69,13 +69,13 @@ def test_time_eviction(self): clock.advance_time(0.5) cache["key2"] = 2 - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(0.9) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(1) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 621b0f9fcdf0..2ad321e1847c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -17,7 +17,7 @@ class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().name, value) + self.assertEqual(current_context().name, value) def test_with_context(self): with LoggingContext("test"): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7dd0..321fc1776f8c 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -27,37 +27,37 @@ class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): cache = LruCache(2) cache[1] = 1 cache[2] = 2 - self.assertEquals(cache.get(1), 1) - self.assertEquals(cache.get(2), 2) + self.assertEqual(cache.get(1), 1) + self.assertEqual(cache.get(2), 2) cache[3] = 3 - self.assertEquals(cache.get(1), None) - self.assertEquals(cache.get(2), 2) - self.assertEquals(cache.get(3), 3) + self.assertEqual(cache.get(1), None) + self.assertEqual(cache.get(2), 2) + self.assertEqual(cache.get(3), 3) def test_setdefault(self): cache = LruCache(1) - self.assertEquals(cache.setdefault("key", 1), 1) - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.setdefault("key", 2), 1) - self.assertEquals(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 1), 1) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 2), 1) + self.assertEqual(cache.get("key"), 1) cache["key"] = 2 # Make sure overriding works. - self.assertEquals(cache.get("key"), 2) + self.assertEqual(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) cache["key"] = 1 - self.assertEquals(cache.pop("key"), 1) - self.assertEquals(cache.pop("key"), None) + self.assertEqual(cache.pop("key"), 1) + self.assertEqual(cache.pop("key"), None) def test_del_multi(self): cache = LruCache(4, cache_type=TreeCache) @@ -66,23 +66,23 @@ def test_del_multi(self): cache[("vehicles", "car")] = "vroom" cache[("vehicles", "train")] = "chuff" - self.assertEquals(len(cache), 4) + self.assertEqual(len(cache), 4) - self.assertEquals(cache.get(("animal", "cat")), "mew") - self.assertEquals(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("animal", "cat")), "mew") + self.assertEqual(cache.get(("vehicles", "car")), "vroom") cache.del_multi(("animal",)) - self.assertEquals(len(cache), 2) - self.assertEquals(cache.get(("animal", "cat")), None) - self.assertEquals(cache.get(("animal", "dog")), None) - self.assertEquals(cache.get(("vehicles", "car")), "vroom") - self.assertEquals(cache.get(("vehicles", "train")), "chuff") + self.assertEqual(len(cache), 2) + self.assertEqual(cache.get(("animal", "cat")), None) + self.assertEqual(cache.get(("animal", "dog")), None) + self.assertEqual(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". def test_clear(self): cache = LruCache(1) cache["key"] = 1 cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self): @@ -105,10 +105,10 @@ def test_get(self): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_multi_get(self): m = Mock() @@ -124,10 +124,10 @@ def test_multi_get(self): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_set(self): m = Mock() @@ -140,10 +140,10 @@ def test_set(self): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_pop(self): m = Mock() @@ -153,13 +153,13 @@ def test_pop(self): self.assertFalse(m.called) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_del_multi(self): m1 = Mock() @@ -173,17 +173,17 @@ def test_del_multi(self): cache.set(("b", "1"), "value", callbacks=[m3]) cache.set(("b", "2"), "value", callbacks=[m4]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) cache.del_multi(("a",)) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) def test_clear(self): m1 = Mock() @@ -193,13 +193,13 @@ def test_clear(self): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) cache.clear() - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) def test_eviction(self): m1 = Mock(name="m1") @@ -210,33 +210,33 @@ def test_eviction(self): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value", callbacks=[m3]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.get("key2") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key1", "value", callbacks=[m1]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 1) class LruCacheSizedTestCase(unittest.HomeserverTestCase): @@ -247,20 +247,20 @@ def test_evict(self): cache["key3"] = [3] cache["key4"] = [4] - self.assertEquals(cache["key1"], [0]) - self.assertEquals(cache["key2"], [1, 2]) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(len(cache), 5) + self.assertEqual(cache["key1"], [0]) + self.assertEqual(cache["key2"], [1, 2]) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(len(cache), 5) cache["key5"] = [5, 6] - self.assertEquals(len(cache), 4) - self.assertEquals(cache.get("key1"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(cache["key5"], [5, 6]) + self.assertEqual(len(cache), 4) + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(cache["key5"], [5, 6]) def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e1bebdc8335..26cb71c640c3 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -24,7 +24,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request @@ -38,7 +38,7 @@ def test_new_destination(self): def test_limiter(self): """General test case which walks through the process of a failing request""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index a10071c70fc9..0774625b8551 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # limitations under the License. from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.util.async_helpers import ReadWriteLock @@ -83,3 +84,32 @@ def test_rwlock(self): self.assertTrue(d.called) with d.result: pass + + def test_lock_handoff_to_nonblocking_writer(self): + """Test a writer handing the lock to another writer that completes instantly.""" + rwlock = ReadWriteLock() + key = "key" + + unblock: "Deferred[None]" = Deferred() + + async def blocking_write(): + with await rwlock.write(key): + await unblock + + async def nonblocking_write(): + with await rwlock.write(key): + pass + + d1 = defer.ensureDeferred(blocking_write()) + d2 = defer.ensureDeferred(nonblocking_write()) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Unblock the first writer. The second writer will complete without blocking. + unblock.callback(None) + self.assertTrue(d1.called) + self.assertTrue(d2.called) + + # The `ReadWriteLock` should operate as normal. + d3 = defer.ensureDeferred(nonblocking_write()) + self.assertTrue(d3.called) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 60663720536d..567cb184684e 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -23,61 +23,61 @@ def test_get_set_onelevel(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.get(("a",)), "A") - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 2) + self.assertEqual(cache.get(("a",)), "A") + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 2) def test_pop_onelevel(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.pop(("a",)), "A") - self.assertEquals(cache.pop(("a",)), None) - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a",)), "A") + self.assertEqual(cache.pop(("a",)), None) + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 1) def test_get_set_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 3) + self.assertEqual(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 3) def test_pop_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.pop(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.pop(("b", "a")), "BA") - self.assertEquals(cache.pop(("b", "a")), None) - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.pop(("b", "a")), "BA") + self.assertEqual(cache.pop(("b", "a")), None) + self.assertEqual(len(cache), 1) def test_pop_mixedlevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), "AA") popped = cache.pop(("a",)) - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), None) - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), None) + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 1) - self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) def test_clear(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) def test_contains(self): cache = TreeCache() diff --git a/tests/utils.py b/tests/utils.py index c06fc320f3df..ef99c72e0b50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence - store = hs.get_datastore() + store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() diff --git a/tox.ini b/tox.ini index 32679e9106c8..04b972e2c58c 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 minversion = 2.3.2 +# the tox-venv plugin makes tox use python's built-in `venv` module rather than +# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`, +# etc, and ends up being rather unreliable. +requires = tox-venv + [base] deps = python-subunit @@ -119,6 +124,9 @@ usedevelop = false deps = Automat == 0.8.0 lxml + # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on + # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed. + markupsafe < 2.1 {[base]deps} commands = @@ -158,17 +166,7 @@ commands = [testenv:check_isort] extras = lint -commands = isort -c --df --sp setup.cfg {[base]lint_targets} - -[testenv:check-newsfragment] -skip_install = true -usedevelop = false -deps = towncrier>=18.6.0rc1 -commands = - python -m towncrier.check --compare-with=origin/develop - -[testenv:check-sampleconfig] -commands = {toxinidir}/scripts-dev/generate_sample_config --check +commands = isort -c --df {[base]lint_targets} [testenv:combine] skip_install = true