diff --git a/.ci/scripts/calculate_jobs.py b/.ci/scripts/calculate_jobs.py index 661887e20985..50e11e6504ff 100755 --- a/.ci/scripts/calculate_jobs.py +++ b/.ci/scripts/calculate_jobs.py @@ -47,9 +47,10 @@ def set_output(key: str, value: str): "database": "sqlite", "extras": "all", } - for version in ("3.9", "3.10", "3.11", "3.12.0-rc.1") + for version in ("3.9", "3.10", "3.11") ) + trial_postgres_tests = [ { "python-version": "3.8", diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 7b839f59c1d9..ec6391cf8fd4 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -57,8 +57,8 @@ jobs: # `pip install matrix-synapse[all]` as closely as possible. - run: poetry update --no-dev - run: poetry run pip list > after.txt && (diff -u before.txt after.txt || true) - - name: Remove unhelpful options from mypy config - run: sed -e '/warn_unused_ignores = True/d' -e '/warn_redundant_casts = True/d' -i mypy.ini + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: needs: check_repo diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 7d629a4ed097..67ccc03f6e2d 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -54,8 +54,8 @@ jobs: poetry remove twisted poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }} poetry install --no-interaction --extras "all test" - - name: Remove unhelpful options from mypy config - run: sed -e '/warn_unused_ignores = True/d' -e '/warn_redundant_casts = True/d' -i mypy.ini + - name: Remove warn_unused_ignores from mypy config + run: sed '/warn_unused_ignores = True/d' -i mypy.ini - run: poetry run mypy trial: diff --git a/Cargo.lock b/Cargo.lock index 4d60f8dcb62a..61c0f1bd0402 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.75" +version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" [[package]] name = "arc-swap" @@ -291,9 +291,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.4" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" +checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" dependencies = [ "aho-corasick", "memchr", @@ -303,9 +303,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.7" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" +checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" dependencies = [ "aho-corasick", "memchr", @@ -314,9 +314,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "ryu" @@ -332,18 +332,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.184" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "2c911f4b04d7385c9035407a4eff5903bf4fe270fa046fda448b69e797f4fff0" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.184" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "c1df27f5b29406ada06609b2e2f77fb34f6dbb104a457a671cc31dbed237e09e" dependencies = [ "proc-macro2", "quote", @@ -352,9 +352,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "076066c5f1078eac5b722a31827a8832fe108bed65dfa75e233c89f8206e976c" dependencies = [ "itoa", "ryu", diff --git a/changelog.d/15816.feature b/changelog.d/15816.feature deleted file mode 100644 index 9248dd6792cc..000000000000 --- a/changelog.d/15816.feature +++ /dev/null @@ -1 +0,0 @@ -Add configuration setting for CAS protocol version. Contributed by AurĂ©lien Grimpard. diff --git a/changelog.d/16008.doc b/changelog.d/16008.doc deleted file mode 100644 index 1142224951c4..000000000000 --- a/changelog.d/16008.doc +++ /dev/null @@ -1 +0,0 @@ -Update links to the matrix.org blog. diff --git a/changelog.d/16099.misc b/changelog.d/16099.misc deleted file mode 100644 index d0e281136668..000000000000 --- a/changelog.d/16099.misc +++ /dev/null @@ -1 +0,0 @@ -Prepare unit tests for Python 3.12. diff --git a/changelog.d/16113.feature b/changelog.d/16113.feature deleted file mode 100644 index 69fdaaebacc1..000000000000 --- a/changelog.d/16113.feature +++ /dev/null @@ -1 +0,0 @@ -Suppress notifications from message edits per [MSC3958](https://github.com/matrix-org/matrix-spec-proposals/pull/3958). diff --git a/changelog.d/16121.misc b/changelog.d/16121.misc deleted file mode 100644 index f325d2a31dbd..000000000000 --- a/changelog.d/16121.misc +++ /dev/null @@ -1 +0,0 @@ -Attempt to fix the twisted trunk job. diff --git a/changelog.d/16135.misc b/changelog.d/16135.misc deleted file mode 100644 index cba8733d0201..000000000000 --- a/changelog.d/16135.misc +++ /dev/null @@ -1 +0,0 @@ -Describe which rate limiter was hit in logs. diff --git a/changelog.d/16136.feature b/changelog.d/16136.feature deleted file mode 100644 index 4ad98a88c309..000000000000 --- a/changelog.d/16136.feature +++ /dev/null @@ -1 +0,0 @@ -Return a `Retry-After` with `M_LIMIT_EXCEEDED` error responses. diff --git a/changelog.d/16155.bugfix b/changelog.d/16155.bugfix deleted file mode 100644 index 8b2dc0400672..000000000000 --- a/changelog.d/16155.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch). diff --git a/changelog.d/16168.doc b/changelog.d/16168.doc deleted file mode 100644 index 7dadb047bef9..000000000000 --- a/changelog.d/16168.doc +++ /dev/null @@ -1 +0,0 @@ -Document which admin APIs are disabled when experimental [MSC3861](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) support is enabled. diff --git a/changelog.d/16170.misc b/changelog.d/16170.misc deleted file mode 100644 index c950b5436705..000000000000 --- a/changelog.d/16170.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify presence code when using workers. diff --git a/changelog.d/16171.misc b/changelog.d/16171.misc deleted file mode 100644 index 4d709cb56e19..000000000000 --- a/changelog.d/16171.misc +++ /dev/null @@ -1 +0,0 @@ -Track per-device information in the presence code. diff --git a/changelog.d/16172.misc b/changelog.d/16172.misc deleted file mode 100644 index 4d709cb56e19..000000000000 --- a/changelog.d/16172.misc +++ /dev/null @@ -1 +0,0 @@ -Track per-device information in the presence code. diff --git a/changelog.d/16175.misc b/changelog.d/16175.misc deleted file mode 100644 index 308fbc225923..000000000000 --- a/changelog.d/16175.misc +++ /dev/null @@ -1 +0,0 @@ -Stop using the `event_txn_id` table. diff --git a/changelog.d/16178.doc b/changelog.d/16178.doc deleted file mode 100644 index ea21e19240bd..000000000000 --- a/changelog.d/16178.doc +++ /dev/null @@ -1 +0,0 @@ -Document `exclude_rooms_from_sync` configuration option. diff --git a/changelog.d/16179.misc b/changelog.d/16179.misc deleted file mode 100644 index 8d04954ab97a..000000000000 --- a/changelog.d/16179.misc +++ /dev/null @@ -1 +0,0 @@ -Use `AsyncMock` instead of custom code. diff --git a/changelog.d/16180.misc b/changelog.d/16180.misc deleted file mode 100644 index 8d04954ab97a..000000000000 --- a/changelog.d/16180.misc +++ /dev/null @@ -1 +0,0 @@ -Use `AsyncMock` instead of custom code. diff --git a/changelog.d/16183.misc b/changelog.d/16183.misc deleted file mode 100644 index 305d5baa6e03..000000000000 --- a/changelog.d/16183.misc +++ /dev/null @@ -1 +0,0 @@ -Improve error reporting of invalid data passed to `/_matrix/key/v2/query`. diff --git a/changelog.d/16184.misc b/changelog.d/16184.misc deleted file mode 100644 index 3c0baddfe1c6..000000000000 --- a/changelog.d/16184.misc +++ /dev/null @@ -1 +0,0 @@ -Task scheduler: add replication notify for new task to launch ASAP. diff --git a/changelog.d/16185.bugfix b/changelog.d/16185.bugfix deleted file mode 100644 index e62c9c7a0d8b..000000000000 --- a/changelog.d/16185.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a spec compliance issue where requests to the `/publicRooms` federation API would specify `include_all_networks` as a string. diff --git a/changelog.d/16186.misc b/changelog.d/16186.misc deleted file mode 100644 index 93ceaeafc9b9..000000000000 --- a/changelog.d/16186.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/16187.misc b/changelog.d/16187.misc deleted file mode 100644 index 989147274a70..000000000000 --- a/changelog.d/16187.misc +++ /dev/null @@ -1 +0,0 @@ -Bump black version to 23.7.0. diff --git a/changelog.d/16188.misc b/changelog.d/16188.misc deleted file mode 100644 index 93ceaeafc9b9..000000000000 --- a/changelog.d/16188.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/16201.misc b/changelog.d/16201.misc deleted file mode 100644 index 93ceaeafc9b9..000000000000 --- a/changelog.d/16201.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/16205.bugfix b/changelog.d/16205.bugfix deleted file mode 100644 index 97ac92a14889..000000000000 --- a/changelog.d/16205.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix inaccurate error message while attempting to ban or unban a user with the same or higher PL by spliting the conditional statements. Contributed by @leviosacz. \ No newline at end of file diff --git a/changelog.d/16210.bugfix b/changelog.d/16210.bugfix deleted file mode 100644 index 39c35a1fe144..000000000000 --- a/changelog.d/16210.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix rare bug that broke looping calls, which could lead to e.g. linearly increasing memory usage. Introduced in v1.90.0. diff --git a/changelog.d/16211.bugfix b/changelog.d/16211.bugfix deleted file mode 100644 index ab1816386c1c..000000000000 --- a/changelog.d/16211.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where uploading images would fail if we could not generate thumbnails for them. diff --git a/changelog.d/16212.misc b/changelog.d/16212.misc deleted file mode 100644 index 19cf9b102d37..000000000000 --- a/changelog.d/16212.misc +++ /dev/null @@ -1 +0,0 @@ -Log the details of background update failures. diff --git a/changelog.d/16213.misc b/changelog.d/16213.misc deleted file mode 100644 index 8c14f5fd51ad..000000000000 --- a/changelog.d/16213.misc +++ /dev/null @@ -1 +0,0 @@ -Fix the latest-deps CI job. diff --git a/changelog.d/16220.bugfix b/changelog.d/16220.bugfix deleted file mode 100644 index dcfac6bda110..000000000000 --- a/changelog.d/16220.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a performance regression introduced in Synapse 1.91.0 where event persistence would cause excessive CPU usage over time. diff --git a/changelog.d/16241.misc b/changelog.d/16241.misc deleted file mode 100644 index 0fc5f34c5ce1..000000000000 --- a/changelog.d/16241.misc +++ /dev/null @@ -1 +0,0 @@ -Cache device resync requests over replication. diff --git a/docs/admin_api/account_validity.md b/docs/admin_api/account_validity.md index dfa69e515bfc..87d8f7150e8c 100644 --- a/docs/admin_api/account_validity.md +++ b/docs/admin_api/account_validity.md @@ -1,7 +1,5 @@ # Account validity API -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - This API allows a server administrator to manage the validity of an account. To use it, you must enable the account validity feature (under `account_validity`) in Synapse's configuration. diff --git a/docs/admin_api/register_api.md b/docs/admin_api/register_api.md index e9a235ada5e2..dd2830f3a18a 100644 --- a/docs/admin_api/register_api.md +++ b/docs/admin_api/register_api.md @@ -1,7 +1,5 @@ # Shared-Secret Registration -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - This API allows for the creation of users in an administrative and non-interactive way. This is generally used for bootstrapping a Synapse instance with administrator accounts. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 8032e05497ad..99abfea3a0fb 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -218,7 +218,7 @@ The following parameters should be set in the URL: - `name` - Is optional and filters to only return users with user ID localparts **or** displaynames that contain this value. - `guests` - string representing a bool - Is optional and if `false` will **exclude** guest users. - Defaults to `true` to include guest users. This parameter is not supported when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + Defaults to `true` to include guest users. - `admins` - Optional flag to filter admins. If `true`, only admins are queried. If `false`, admins are excluded from the query. When the flag is absent (the default), **both** admins and non-admins are included in the search results. - `deactivated` - string representing a bool - Is optional and if `true` will **include** deactivated users. @@ -390,8 +390,6 @@ The following actions are **NOT** performed. The list may be incomplete. ## Reset password -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - Changes the password of another user. This will automatically log the user out of all their devices. The api is: @@ -415,8 +413,6 @@ The parameter `logout_devices` is optional and defaults to `true`. ## Get whether a user is a server administrator or not -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - The api is: ``` @@ -434,8 +430,6 @@ A response body like the following is returned: ## Change whether a user is a server administrator or not -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - Note that you cannot demote yourself. The api is: @@ -729,8 +723,6 @@ delete largest/smallest or newest/oldest files first. ## Login as a user -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - Get an access token that can be used to authenticate as that user. Useful for when admins wish to do actions on behalf of a user. diff --git a/docs/development/releases.md b/docs/development/releases.md index 6e83c81e27eb..c9a8c6994597 100644 --- a/docs/development/releases.md +++ b/docs/development/releases.md @@ -12,7 +12,7 @@ Note that this schedule might be modified depending on the availability of the Synapse team, e.g. releases may be skipped to avoid holidays. Release announcements can be found in the -[release category of the Matrix blog](https://matrix.org/category/releases). +[release category of the Matrix blog](https://matrix.org/blog/category/releases). ## Bugfix releases @@ -34,4 +34,4 @@ be held to be released together. In some cases, a pre-disclosure of a security release will be issued as a notice to Synapse operators that there is an upcoming security release. These can be -found in the [security category of the Matrix blog](https://matrix.org/category/security). +found in the [security category of the Matrix blog](https://matrix.org/blog/category/security). diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md index ba95bcf03801..c5130859d426 100644 --- a/docs/usage/administration/admin_api/registration_tokens.md +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -1,7 +1,5 @@ # Registration Tokens -**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - This API allows you to manage tokens which can be used to authenticate registration requests, as proposed in [MSC3231](https://github.com/matrix-org/matrix-doc/blob/main/proposals/3231-token-authenticated-registration.md) diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 0b1725816e3d..743c51d76adf 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3420,7 +3420,6 @@ Has the following sub-options: to style the login flow according to the identity provider in question. See the [spec](https://spec.matrix.org/latest/) for possible options here. * `server_url`: The URL of the CAS authorization endpoint. -* `protocol_version`: The CAS protocol version, defaults to none (version 3 is required if you want to use "required_attributes"). * `displayname_attribute`: The attribute of the CAS response to use as the display name. If no name is given here, no displayname will be set. * `required_attributes`: It is possible to configure Synapse to only allow logins if CAS attributes @@ -3434,7 +3433,6 @@ Example configuration: cas_config: enabled: true server_url: "https://cas-server.com" - protocol_version: 3 displayname_attribute: name required_attributes: userGroup: "staff" @@ -3867,19 +3865,6 @@ Example configuration: ```yaml forget_rooms_on_leave: false ``` ---- -### `exclude_rooms_from_sync` -A list of rooms to exclude from sync responses. This is useful for server -administrators wishing to group users into a room without these users being able -to see it from their client. - -By default, no room is excluded. - -Example configuration: -```yaml -exclude_rooms_from_sync: - - !foo:example.com -``` --- ## Opentracing diff --git a/mypy.ini b/mypy.ini index fb5f44c939d8..311a951aa8de 100644 --- a/mypy.ini +++ b/mypy.ini @@ -87,9 +87,18 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True +[mypy-service_identity.*] +ignore_missing_imports = True + [mypy-srvlookup.*] ignore_missing_imports = True # https://github.com/twisted/treq/pull/366 [mypy-treq.*] ignore_missing_imports = True + +[mypy-incremental.*] +ignore_missing_imports = True + +[mypy-setuptools_rust.*] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 0688d5d92e3c..e62c10da9f76 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "alabaster" @@ -148,33 +148,36 @@ lxml = ["lxml"] [[package]] name = "black" -version = "23.7.0" +version = "23.3.0" description = "The uncompromising code formatter." optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, - {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, - {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, - {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, - {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, - {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, - {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, - {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, - {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, - {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, - {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, + {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, + {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, + {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, + {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, + {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, + {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, + {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, + {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, + {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, + {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, + {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, + {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, + {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, ] [package.dependencies] @@ -541,13 +544,13 @@ files = [ [[package]] name = "elementpath" -version = "4.1.5" +version = "4.1.0" description = "XPath 1.0/2.0/3.0/3.1 parsers and selectors for ElementTree and lxml" optional = true python-versions = ">=3.7" files = [ - {file = "elementpath-4.1.5-py3-none-any.whl", hash = "sha256:2ac1a2fb31eb22bbbf817f8cf6752f844513216263f0e3892c8e79782fe4bb55"}, - {file = "elementpath-4.1.5.tar.gz", hash = "sha256:c2d6dc524b29ef751ecfc416b0627668119d8812441c555d7471da41d4bacb8d"}, + {file = "elementpath-4.1.0-py3-none-any.whl", hash = "sha256:2b1b524223d70fd6dd63a36b9bc32e4919c96a272c2d1454094c4d85086bc6f8"}, + {file = "elementpath-4.1.0.tar.gz", hash = "sha256:dbd7eba3cf0b3b4934f627ba24851a3e0798ef2bc9104555a4cd831f2e6e8e14"}, ] [package.extras] @@ -1445,43 +1448,43 @@ files = [ [[package]] name = "mypy" -version = "1.4.1" +version = "1.0.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.7" files = [ - {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, - {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, - {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, - {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, - {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, - {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, - {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, - {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, - {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, - {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, - {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, - {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, - {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, - {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, - {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, - {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, - {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, - {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, - {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, - {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, -] - -[package.dependencies] -mypy-extensions = ">=1.0.0" + {file = "mypy-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:71a808334d3f41ef011faa5a5cd8153606df5fc0b56de5b2e89566c8093a0c9a"}, + {file = "mypy-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:920169f0184215eef19294fa86ea49ffd4635dedfdea2b57e45cb4ee85d5ccaf"}, + {file = "mypy-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27a0f74a298769d9fdc8498fcb4f2beb86f0564bcdb1a37b58cbbe78e55cf8c0"}, + {file = "mypy-1.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:65b122a993d9c81ea0bfde7689b3365318a88bde952e4dfa1b3a8b4ac05d168b"}, + {file = "mypy-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5deb252fd42a77add936b463033a59b8e48eb2eaec2976d76b6878d031933fe4"}, + {file = "mypy-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2013226d17f20468f34feddd6aae4635a55f79626549099354ce641bc7d40262"}, + {file = "mypy-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:48525aec92b47baed9b3380371ab8ab6e63a5aab317347dfe9e55e02aaad22e8"}, + {file = "mypy-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c96b8a0c019fe29040d520d9257d8c8f122a7343a8307bf8d6d4a43f5c5bfcc8"}, + {file = "mypy-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:448de661536d270ce04f2d7dddaa49b2fdba6e3bd8a83212164d4174ff43aa65"}, + {file = "mypy-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:d42a98e76070a365a1d1c220fcac8aa4ada12ae0db679cb4d910fabefc88b994"}, + {file = "mypy-1.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64f48c6176e243ad015e995de05af7f22bbe370dbb5b32bd6988438ec873919"}, + {file = "mypy-1.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fdd63e4f50e3538617887e9aee91855368d9fc1dea30da743837b0df7373bc4"}, + {file = "mypy-1.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:dbeb24514c4acbc78d205f85dd0e800f34062efcc1f4a4857c57e4b4b8712bff"}, + {file = "mypy-1.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a2948c40a7dd46c1c33765718936669dc1f628f134013b02ff5ac6c7ef6942bf"}, + {file = "mypy-1.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5bc8d6bd3b274dd3846597855d96d38d947aedba18776aa998a8d46fabdaed76"}, + {file = "mypy-1.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:17455cda53eeee0a4adb6371a21dd3dbf465897de82843751cf822605d152c8c"}, + {file = "mypy-1.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e831662208055b006eef68392a768ff83596035ffd6d846786578ba1714ba8f6"}, + {file = "mypy-1.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e60d0b09f62ae97a94605c3f73fd952395286cf3e3b9e7b97f60b01ddfbbda88"}, + {file = "mypy-1.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:0af4f0e20706aadf4e6f8f8dc5ab739089146b83fd53cb4a7e0e850ef3de0bb6"}, + {file = "mypy-1.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24189f23dc66f83b839bd1cce2dfc356020dfc9a8bae03978477b15be61b062e"}, + {file = "mypy-1.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93a85495fb13dc484251b4c1fd7a5ac370cd0d812bbfc3b39c1bafefe95275d5"}, + {file = "mypy-1.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f546ac34093c6ce33f6278f7c88f0f147a4849386d3bf3ae193702f4fe31407"}, + {file = "mypy-1.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c6c2ccb7af7154673c591189c3687b013122c5a891bb5651eca3db8e6c6c55bd"}, + {file = "mypy-1.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b5a824b58c7c822c51bc66308e759243c32631896743f030daf449fe3677f3"}, + {file = "mypy-1.0.1-py3-none-any.whl", hash = "sha256:eda5c8b9949ed411ff752b9a01adda31afe7eae1e53e946dbdf9db23865e66c4"}, + {file = "mypy-1.0.1.tar.gz", hash = "sha256:28cea5a6392bb43d266782983b5a4216c25544cd7d80be681a155ddcdafd152d"}, +] + +[package.dependencies] +mypy-extensions = ">=0.4.3" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=4.1.0" +typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] @@ -1502,17 +1505,17 @@ files = [ [[package]] name = "mypy-zope" -version = "1.0.0" +version = "0.9.1" description = "Plugin for mypy to support zope interfaces" optional = false python-versions = "*" files = [ - {file = "mypy-zope-1.0.0.tar.gz", hash = "sha256:be815c2fcb5333aa87e8ec682029ad3214142fe2a05ea383f9ff2d77c98008b7"}, - {file = "mypy_zope-1.0.0-py3-none-any.whl", hash = "sha256:9732e9b2198f2aec3343b38a51905ff49d44dc9e39e8e8bc6fc490b232388209"}, + {file = "mypy-zope-0.9.1.tar.gz", hash = "sha256:4c87dbc71fec35f6533746ecdf9d400cd9281338d71c16b5676bb5ed00a97ca2"}, + {file = "mypy_zope-0.9.1-py3-none-any.whl", hash = "sha256:733d4399affe9e61e332ce9c4049418d6775c39b473e4b9f409d51c207c1b71a"}, ] [package.dependencies] -mypy = ">=1.0.0,<1.5.0" +mypy = ">=1.0.0,<1.1.0" "zope.interface" = "*" "zope.schema" = "*" @@ -1607,13 +1610,13 @@ files = [ [[package]] name = "phonenumbers" -version = "8.13.19" +version = "8.13.18" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." optional = false python-versions = "*" files = [ - {file = "phonenumbers-8.13.19-py2.py3-none-any.whl", hash = "sha256:ba542f20f6dc83be8f127f240f9b5b7e7c1dec42aceff1879400d4dc0c781d81"}, - {file = "phonenumbers-8.13.19.tar.gz", hash = "sha256:38180247697240ccedd74dec4bfbdbc22bb108b9c5f991f270ca3e41395e6f96"}, + {file = "phonenumbers-8.13.18-py2.py3-none-any.whl", hash = "sha256:3d802739a22592e4127139349937753dee9b6a20bdd5d56847cd885bdc766b1f"}, + {file = "phonenumbers-8.13.18.tar.gz", hash = "sha256:b360c756252805d44b447b5bca6d250cf6bd6c69b6f0f4258f3bfe5ab81bef69"}, ] [[package]] @@ -1741,22 +1744,24 @@ twisted = ["twisted"] [[package]] name = "psycopg2" -version = "2.9.7" +version = "2.9.6" description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = true python-versions = ">=3.6" files = [ - {file = "psycopg2-2.9.7-cp310-cp310-win32.whl", hash = "sha256:1a6a2d609bce44f78af4556bea0c62a5e7f05c23e5ea9c599e07678995609084"}, - {file = "psycopg2-2.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:b22ed9c66da2589a664e0f1ca2465c29b75aaab36fa209d4fb916025fb9119e5"}, - {file = "psycopg2-2.9.7-cp311-cp311-win32.whl", hash = "sha256:44d93a0109dfdf22fe399b419bcd7fa589d86895d3931b01fb321d74dadc68f1"}, - {file = "psycopg2-2.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:91e81a8333a0037babfc9fe6d11e997a9d4dac0f38c43074886b0d9dead94fe9"}, - {file = "psycopg2-2.9.7-cp37-cp37m-win32.whl", hash = "sha256:d1210fcf99aae6f728812d1d2240afc1dc44b9e6cba526a06fb8134f969957c2"}, - {file = "psycopg2-2.9.7-cp37-cp37m-win_amd64.whl", hash = "sha256:e9b04cbef584310a1ac0f0d55bb623ca3244c87c51187645432e342de9ae81a8"}, - {file = "psycopg2-2.9.7-cp38-cp38-win32.whl", hash = "sha256:d5c5297e2fbc8068d4255f1e606bfc9291f06f91ec31b2a0d4c536210ac5c0a2"}, - {file = "psycopg2-2.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:8275abf628c6dc7ec834ea63f6f3846bf33518907a2b9b693d41fd063767a866"}, - {file = "psycopg2-2.9.7-cp39-cp39-win32.whl", hash = "sha256:c7949770cafbd2f12cecc97dea410c514368908a103acf519f2a346134caa4d5"}, - {file = "psycopg2-2.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:b6bd7d9d3a7a63faae6edf365f0ed0e9b0a1aaf1da3ca146e6b043fb3eb5d723"}, - {file = "psycopg2-2.9.7.tar.gz", hash = "sha256:f00cc35bd7119f1fed17b85bd1007855194dde2cbd8de01ab8ebb17487440ad8"}, + {file = "psycopg2-2.9.6-cp310-cp310-win32.whl", hash = "sha256:f7a7a5ee78ba7dc74265ba69e010ae89dae635eea0e97b055fb641a01a31d2b1"}, + {file = "psycopg2-2.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:f75001a1cbbe523e00b0ef896a5a1ada2da93ccd752b7636db5a99bc57c44494"}, + {file = "psycopg2-2.9.6-cp311-cp311-win32.whl", hash = "sha256:53f4ad0a3988f983e9b49a5d9765d663bbe84f508ed655affdb810af9d0972ad"}, + {file = "psycopg2-2.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:b81fcb9ecfc584f661b71c889edeae70bae30d3ef74fa0ca388ecda50b1222b7"}, + {file = "psycopg2-2.9.6-cp36-cp36m-win32.whl", hash = "sha256:11aca705ec888e4f4cea97289a0bf0f22a067a32614f6ef64fcf7b8bfbc53744"}, + {file = "psycopg2-2.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:36c941a767341d11549c0fbdbb2bf5be2eda4caf87f65dfcd7d146828bd27f39"}, + {file = "psycopg2-2.9.6-cp37-cp37m-win32.whl", hash = "sha256:869776630c04f335d4124f120b7fb377fe44b0a7645ab3c34b4ba42516951889"}, + {file = "psycopg2-2.9.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a8ad4a47f42aa6aec8d061fdae21eaed8d864d4bb0f0cade5ad32ca16fcd6258"}, + {file = "psycopg2-2.9.6-cp38-cp38-win32.whl", hash = "sha256:2362ee4d07ac85ff0ad93e22c693d0f37ff63e28f0615a16b6635a645f4b9214"}, + {file = "psycopg2-2.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:d24ead3716a7d093b90b27b3d73459fe8cd90fd7065cf43b3c40966221d8c394"}, + {file = "psycopg2-2.9.6-cp39-cp39-win32.whl", hash = "sha256:1861a53a6a0fd248e42ea37c957d36950da00266378746588eab4f4b5649e95f"}, + {file = "psycopg2-2.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:ded2faa2e6dfb430af7713d87ab4abbfc764d8d7fb73eafe96a24155f906ebf5"}, + {file = "psycopg2-2.9.6.tar.gz", hash = "sha256:f15158418fd826831b28585e2ab48ed8df2d0d98f502a2b4fe619e7d5ca29011"}, ] [[package]] @@ -2077,7 +2082,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2085,15 +2089,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2110,7 +2107,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2118,7 +2114,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2334,28 +2329,28 @@ files = [ [[package]] name = "ruff" -version = "0.0.286" +version = "0.0.277" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.286-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:8e22cb557e7395893490e7f9cfea1073d19a5b1dd337f44fd81359b2767da4e9"}, - {file = "ruff-0.0.286-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:68ed8c99c883ae79a9133cb1a86d7130feee0397fdf5ba385abf2d53e178d3fa"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8301f0bb4ec1a5b29cfaf15b83565136c47abefb771603241af9d6038f8981e8"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:acc4598f810bbc465ce0ed84417ac687e392c993a84c7eaf3abf97638701c1ec"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88c8e358b445eb66d47164fa38541cfcc267847d1e7a92dd186dddb1a0a9a17f"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0433683d0c5dbcf6162a4beb2356e820a593243f1fa714072fec15e2e4f4c939"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddb61a0c4454cbe4623f4a07fef03c5ae921fe04fede8d15c6e36703c0a73b07"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:47549c7c0be24c8ae9f2bce6f1c49fbafea83bca80142d118306f08ec7414041"}, - {file = "ruff-0.0.286-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:559aa793149ac23dc4310f94f2c83209eedb16908a0343663be19bec42233d25"}, - {file = "ruff-0.0.286-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d73cfb1c3352e7aa0ce6fb2321f36fa1d4a2c48d2ceac694cb03611ddf0e4db6"}, - {file = "ruff-0.0.286-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:3dad93b1f973c6d1db4b6a5da8690c5625a3fa32bdf38e543a6936e634b83dc3"}, - {file = "ruff-0.0.286-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26afc0851f4fc3738afcf30f5f8b8612a31ac3455cb76e611deea80f5c0bf3ce"}, - {file = "ruff-0.0.286-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9b6b116d1c4000de1b9bf027131dbc3b8a70507788f794c6b09509d28952c512"}, - {file = "ruff-0.0.286-py3-none-win32.whl", hash = "sha256:556e965ac07c1e8c1c2d759ac512e526ecff62c00fde1a046acb088d3cbc1a6c"}, - {file = "ruff-0.0.286-py3-none-win_amd64.whl", hash = "sha256:5d295c758961376c84aaa92d16e643d110be32add7465e197bfdaec5a431a107"}, - {file = "ruff-0.0.286-py3-none-win_arm64.whl", hash = "sha256:1d6142d53ab7f164204b3133d053c4958d4d11ec3a39abf23a40b13b0784e3f0"}, - {file = "ruff-0.0.286.tar.gz", hash = "sha256:f1e9d169cce81a384a26ee5bb8c919fe9ae88255f39a1a69fd1ebab233a85ed2"}, + {file = "ruff-0.0.277-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:3250b24333ef419b7a232080d9724ccc4d2da1dbbe4ce85c4caa2290d83200f8"}, + {file = "ruff-0.0.277-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3e60605e07482183ba1c1b7237eca827bd6cbd3535fe8a4ede28cbe2a323cb97"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7baa97c3d7186e5ed4d5d4f6834d759a27e56cf7d5874b98c507335f0ad5aadb"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:74e4b206cb24f2e98a615f87dbe0bde18105217cbcc8eb785bb05a644855ba50"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:479864a3ccd8a6a20a37a6e7577bdc2406868ee80b1e65605478ad3b8eb2ba0b"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:468bfb0a7567443cec3d03cf408d6f562b52f30c3c29df19927f1e0e13a40cd7"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f32ec416c24542ca2f9cc8c8b65b84560530d338aaf247a4a78e74b99cd476b4"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:14a7b2f00f149c5a295f188a643ac25226ff8a4d08f7a62b1d4b0a1dc9f9b85c"}, + {file = "ruff-0.0.277-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9879f59f763cc5628aa01c31ad256a0f4dc61a29355c7315b83c2a5aac932b5"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f612e0a14b3d145d90eb6ead990064e22f6f27281d847237560b4e10bf2251f3"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:323b674c98078be9aaded5b8b51c0d9c424486566fb6ec18439b496ce79e5998"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3a43fbe026ca1a2a8c45aa0d600a0116bec4dfa6f8bf0c3b871ecda51ef2b5dd"}, + {file = "ruff-0.0.277-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:734165ea8feb81b0d53e3bf523adc2413fdb76f1264cde99555161dd5a725522"}, + {file = "ruff-0.0.277-py3-none-win32.whl", hash = "sha256:88d0f2afb2e0c26ac1120e7061ddda2a566196ec4007bd66d558f13b374b9efc"}, + {file = "ruff-0.0.277-py3-none-win_amd64.whl", hash = "sha256:6fe81732f788894a00f6ade1fe69e996cc9e485b7c35b0f53fb00284397284b2"}, + {file = "ruff-0.0.277-py3-none-win_arm64.whl", hash = "sha256:2d4444c60f2e705c14cd802b55cd2b561d25bf4311702c463a002392d3116b22"}, + {file = "ruff-0.0.277.tar.gz", hash = "sha256:2dab13cdedbf3af6d4427c07f47143746b6b95d9e4a254ac369a0edb9280a0d2"}, ] [[package]] @@ -2390,13 +2385,13 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "1.30.0" +version = "1.29.2" description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = "*" files = [ - {file = "sentry-sdk-1.30.0.tar.gz", hash = "sha256:7dc873b87e1faf4d00614afd1058bfa1522942f33daef8a59f90de8ed75cd10c"}, - {file = "sentry_sdk-1.30.0-py2.py3-none-any.whl", hash = "sha256:2e53ad63f96bb9da6570ba2e755c267e529edcf58580a2c0d2a11ef26e1e678b"}, + {file = "sentry-sdk-1.29.2.tar.gz", hash = "sha256:a99ee105384788c3f228726a88baf515fe7b5f1d2d0f215a03d194369f158df7"}, + {file = "sentry_sdk-1.29.2-py2.py3-none-any.whl", hash = "sha256:3e17215d8006612e2df02b0e73115eb8376c37e3f586d8436fa41644e605074d"}, ] [package.dependencies] @@ -2419,7 +2414,6 @@ httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] -opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] @@ -2473,19 +2467,18 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "setuptools-rust" -version = "1.7.0" +version = "1.6.0" description = "Setuptools Rust extension plugin" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-rust-1.7.0.tar.gz", hash = "sha256:c7100999948235a38ae7e555fe199aa66c253dc384b125f5d85473bf81eae3a3"}, - {file = "setuptools_rust-1.7.0-py3-none-any.whl", hash = "sha256:071099885949132a2180d16abf907b60837e74b4085047ba7e9c0f5b365310c1"}, + {file = "setuptools-rust-1.6.0.tar.gz", hash = "sha256:c86e734deac330597998bfbc08da45187e6b27837e23bd91eadb320732392262"}, + {file = "setuptools_rust-1.6.0-py3-none-any.whl", hash = "sha256:e28ae09fb7167c44ab34434eb49279307d611547cb56cb9789955cdb54a1aed9"}, ] [package.dependencies] semantic-version = ">=2.8.2,<3" setuptools = ">=62.4" -tomli = {version = ">=1.2.1", markers = "python_version < \"3.11\""} typing-extensions = ">=3.7.4.3" [[package]] @@ -3009,13 +3002,13 @@ files = [ [[package]] name = "types-psycopg2" -version = "2.9.21.11" +version = "2.9.21.10" description = "Typing stubs for psycopg2" optional = false python-versions = "*" files = [ - {file = "types-psycopg2-2.9.21.11.tar.gz", hash = "sha256:d5077eacf90e61db8c0b8eea2fdc9d4a97d7aaa16865fb4bd7034a7571520b4d"}, - {file = "types_psycopg2-2.9.21.11-py3-none-any.whl", hash = "sha256:7a323d7744bc8a882fb5a6f63448e903fc70d3dc0d6da9ec1f9c6c4dc10a7102"}, + {file = "types-psycopg2-2.9.21.10.tar.gz", hash = "sha256:c2600892312ae1c34e12f145749795d93dc4eac3ef7dbf8a9c1bfd45385e80d7"}, + {file = "types_psycopg2-2.9.21.10-py3-none-any.whl", hash = "sha256:918224a0731a3650832e46633e720703b5beef7693a064e777d9748654fcf5e5"}, ] [[package]] @@ -3034,13 +3027,13 @@ cryptography = ">=35.0.0" [[package]] name = "types-pyyaml" -version = "6.0.12.11" +version = "6.0.12.10" description = "Typing stubs for PyYAML" optional = false python-versions = "*" files = [ - {file = "types-PyYAML-6.0.12.11.tar.gz", hash = "sha256:7d340b19ca28cddfdba438ee638cd4084bde213e501a3978738543e27094775b"}, - {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, + {file = "types-PyYAML-6.0.12.10.tar.gz", hash = "sha256:ebab3d0700b946553724ae6ca636ea932c1b0868701d4af121630e78d695fc97"}, + {file = "types_PyYAML-6.0.12.10-py3-none-any.whl", hash = "sha256:662fa444963eff9b68120d70cda1af5a5f2aa57900003c2006d7626450eaae5f"}, ] [[package]] @@ -3214,22 +3207,22 @@ files = [ [[package]] name = "xmlschema" -version = "2.4.0" +version = "2.2.2" description = "An XML Schema validator and decoder" optional = true python-versions = ">=3.7" files = [ - {file = "xmlschema-2.4.0-py3-none-any.whl", hash = "sha256:dc87be0caaa61f42649899189aab2fd8e0d567f2cf548433ba7b79278d231a4a"}, - {file = "xmlschema-2.4.0.tar.gz", hash = "sha256:d74cd0c10866ac609e1ef94a5a69b018ad16e39077bc6393408b40c6babee793"}, + {file = "xmlschema-2.2.2-py3-none-any.whl", hash = "sha256:557f3632b54b6ff10576736bba62e43db84eb60f6465a83818576cd9ffcc1799"}, + {file = "xmlschema-2.2.2.tar.gz", hash = "sha256:0caa96668807b4b51c42a0fe2b6610752bc59f069615df3e34dcfffb962973fd"}, ] [package.dependencies] -elementpath = ">=4.1.5,<5.0.0" +elementpath = ">=4.0.0,<5.0.0" [package.extras] -codegen = ["elementpath (>=4.1.5,<5.0.0)", "jinja2"] -dev = ["Sphinx", "coverage", "elementpath (>=4.1.5,<5.0.0)", "flake8", "jinja2", "lxml", "lxml-stubs", "memory-profiler", "mypy", "sphinx-rtd-theme", "tox"] -docs = ["Sphinx", "elementpath (>=4.1.5,<5.0.0)", "jinja2", "sphinx-rtd-theme"] +codegen = ["elementpath (>=4.0.0,<5.0.0)", "jinja2"] +dev = ["Sphinx", "coverage", "elementpath (>=4.0.0,<5.0.0)", "flake8", "jinja2", "lxml", "lxml-stubs", "memory-profiler", "mypy", "sphinx-rtd-theme", "tox"] +docs = ["Sphinx", "elementpath (>=4.0.0,<5.0.0)", "jinja2", "sphinx-rtd-theme"] [[package]] name = "zipp" @@ -3350,4 +3343,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "4a3a82becd89b91e76e2bc2f8ba72123f665c517d9b841d9a34cd01b83a1adc3" +content-hash = "0a8c6605e7e1d0ac7188a5d02b47a029bfb0f917458b87cb40755911442383d8" diff --git a/pyproject.toml b/pyproject.toml index c1f95e945847..2a4ff1ea01c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py38', 'py39', 'py310', 'py311'] +target-version = ['py37', 'py38', 'py39', 'py310'] # black ignores everything in .gitignore by default, see # https://black.readthedocs.io/en/stable/usage_and_configuration/file_collection_and_discovery.html#gitignore # Use `extend-exclude` if you want to exclude something in addition to this. @@ -306,13 +306,10 @@ all = [ ] [tool.poetry.dev-dependencies] -# We pin development dependencies in poetry.lock so that our tests don't start -# failing on new releases. Keeping lower bounds loose here means that dependabot -# can bump versions without having to update the content-hash in the lockfile. -# This helps prevents merge conflicts when running a batch of dependabot updates. +# We pin black so that our tests don't start failing on new releases. isort = ">=5.10.1" -black = ">=22.7.0" -ruff = "0.0.286" +black = ">=22.3.0" +ruff = "0.0.277" # Typechecking lxml-stubs = ">=0.4.0" diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 14071105a05b..6e1eab2a3b29 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -197,6 +197,7 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 59fd27665aee..00baceda91fa 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -228,7 +228,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ // We don't want to notify on edits *unless* the edit directly mentions a // user, which is handled above. PushRule { - rule_id: Cow::Borrowed("global/override/.m.rule.suppress_edits"), + rule_id: Cow::Borrowed("global/override/.org.matrix.msc3958.suppress_edits"), priority_class: 5, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventPropertyIs( EventPropertyIsCondition { diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 5b9bf9b26ae1..48e670478bf7 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -564,7 +564,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 8e91f506cc42..829fb79d0e5b 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -527,6 +527,7 @@ pub struct FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3958_suppress_edits_enabled: bool, } #[pymethods] @@ -538,6 +539,7 @@ impl FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3958_suppress_edits_enabled: bool, ) -> Self { Self { push_rules, @@ -545,6 +547,7 @@ impl FilteredPushRules { msc1767_enabled, msc3381_polls_enabled, msc3664_enabled, + msc3958_suppress_edits_enabled, } } @@ -581,6 +584,12 @@ impl FilteredPushRules { return false; } + if !self.msc3958_suppress_edits_enabled + && rule.rule_id == "global/override/.org.matrix.msc3958.suppress_edits" + { + return false; + } + true }) .map(|r| { diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 1f432d4ecfbf..d573a37b9aff 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -46,6 +46,7 @@ class FilteredPushRules: msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3958_suppress_edits_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/__init__.py b/synapse/__init__.py index 4a9bbc4d57b7..2f9c22a83352 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -21,14 +21,9 @@ import sys from typing import Any, Dict -from PIL import ImageFile - from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool -# Allow truncated JPEG images to be thumbnailed. -ImageFile.LOAD_TRUNCATED_IMAGES = True - # Check that we're not running on an unsupported Python version. # # Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index ab2b29cf1b49..49242800b858 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -482,10 +482,7 @@ def r( do_backward[0] = False if forward_rows or backward_rows: - assert txn.description is not None - headers: Optional[List[str]] = [ - column[0] for column in txn.description - ] + headers = [column[0] for column in txn.description] else: headers = None @@ -547,7 +544,6 @@ async def handle_search_table( def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() - assert txn.description is not None headers = [column[0] for column in txn.description] return headers, rows @@ -923,8 +919,7 @@ async def _setup_sent_transactions(self) -> Tuple[int, int, int]: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - assert txn.description is not None - headers = [column[0] for column in txn.description] + headers: List[str] = [column[0] for column in txn.description] ts_ind = headers.index("ts") diff --git a/synapse/api/errors.py b/synapse/api/errors.py index fdb2955be82b..7ffd72c42cd4 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -16,7 +16,6 @@ """Contains exceptions and error codes.""" import logging -import math import typing from enum import Enum from http import HTTPStatus @@ -211,11 +210,6 @@ def __init__( def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) - @property - def debug_context(self) -> Optional[str]: - """Override this to add debugging context that shouldn't be sent to clients.""" - return None - class InvalidAPICallError(SynapseError): """You called an existing API endpoint, but fed that endpoint @@ -509,31 +503,19 @@ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled.""" - include_retry_after_header = False - def __init__( self, - limiter_name: str, code: int = 429, + msg: str = "Too Many Requests", retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - headers = ( - {"Retry-After": str(math.ceil(retry_after_ms / 1000))} - if self.include_retry_after_header and retry_after_ms is not None - else None - ) - super().__init__(code, "Too Many Requests", errcode, headers=headers) + super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms - self.limiter_name = limiter_name def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) - @property - def debug_context(self) -> Optional[str]: - return self.limiter_name - class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 887b214d64a3..511790c7c5e4 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -61,16 +61,12 @@ class Ratelimiter: """ def __init__( - self, - store: DataStore, - clock: Clock, - cfg: RatelimitSettings, + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int ): self.clock = clock - self.rate_hz = cfg.per_second - self.burst_count = cfg.burst_count + self.rate_hz = rate_hz + self.burst_count = burst_count self.store = store - self._limiter_name = cfg.key # An ordered dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -309,8 +305,7 @@ async def ratelimit( if not allowed: raise LimitExceededError( - limiter_name=self._limiter_name, - retry_after_ms=int(1000 * (time_allowed - time_now_s)), + retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) @@ -327,9 +322,7 @@ def __init__( # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - store=self.store, - clock=self.clock, - cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = rc_message @@ -339,7 +332,8 @@ def __init__( self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - cfg=rc_admin_redaction, + rate_hz=rc_admin_redaction.per_second, + burst_count=rc_admin_redaction.burst_count, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 69a831812759..1d268a1817cd 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -186,9 +186,9 @@ def parse_size(value: Union[str, int]) -> int: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: # noqa: E721 + if type(value) is int: return value - elif isinstance(value, str): + elif type(value) is str: sizes = {"K": 1024, "M": 1024 * 1024} size = 1 suffix = value[-1] @@ -218,9 +218,9 @@ def parse_duration(value: Union[str, int]) -> int: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: # noqa: E721 + if type(value) is int: return value - elif isinstance(value, str): + elif type(value) is str: second = 1000 minute = 60 * second hour = 60 * minute diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index a70dfbf41f93..919f81a9b716 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -34,7 +34,7 @@ class AppServiceConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.app_service_config_files = config.get("app_service_config_files", []) if not isinstance(self.app_service_config_files, list) or not all( - isinstance(x, str) for x in self.app_service_config_files + type(x) is str for x in self.app_service_config_files ): raise ConfigError( "Expected '%s' to be a list of AS config files:" diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 6e2d9addbf4c..c4e63e74118c 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -18,7 +18,7 @@ from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict -from ._base import Config, ConfigError +from ._base import Config from ._util import validate_config @@ -41,16 +41,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: public_baseurl = self.root.server.public_baseurl self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" - self.cas_protocol_version = cas_config.get("protocol_version") - if ( - self.cas_protocol_version is not None - and self.cas_protocol_version not in [1, 2, 3] - ): - raise ConfigError( - "Unsupported CAS protocol version %s (only versions 1, 2, 3 are supported)" - % (self.cas_protocol_version,), - ("cas_config", "protocol_version"), - ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") required_attributes = cas_config.get("required_attributes") or {} self.cas_required_attributes = _parsed_required_attributes_def( @@ -64,7 +54,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: else: self.cas_server_url = None self.cas_service_url = None - self.cas_protocol_version = None self.cas_displayname_attribute = None self.cas_required_attributes = [] diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index cabe0d4397cd..277ea4675b29 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -18,7 +18,6 @@ import attr import attr.validators -from synapse.api.errors import LimitExceededError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError from synapse.config._base import Config, RootConfig @@ -384,6 +383,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) + # MSC3959: Do not generate notifications for edits. + self.msc3958_supress_edit_notifs = experimental.get( + "msc3958_supress_edit_notifs", False + ) + # MSC3967: Do not require UIA when first uploading cross signing keys self.msc3967_enabled = experimental.get("msc3967_enabled", False) @@ -407,11 +411,3 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) - - # MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling - # - # This is a bit hacky, but the most reasonable way to *alway* include the - # headers. - LimitExceededError.include_retry_after_header = experimental.get( - "msc4041_enabled", False - ) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4efbaeac0d7f..a5514e70a21d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional import attr @@ -21,47 +21,16 @@ from ._base import Config -@attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitSettings: - key: str - per_second: float - burst_count: int - - @classmethod - def parse( - cls, - config: Dict[str, Any], - key: str, + def __init__( + self, + config: Dict[str, float], defaults: Optional[Dict[str, float]] = None, - ) -> "RatelimitSettings": - """Parse config[key] as a new-style rate limiter config. - - The key may refer to a nested dictionary using a full stop (.) to separate - each nested key. For example, use the key "a.b.c" to parse the following: - - a: - b: - c: - per_second: 10 - burst_count: 200 - - If this lookup fails, we'll fallback to the defaults. - """ + ): defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} - rl_config = config - for part in key.split("."): - rl_config = rl_config.get(part, {}) - - # By this point we should have hit the rate limiter parameters. - # We don't actually check this though! - rl_config = cast(Dict[str, float], rl_config) - - return cls( - key=key, - per_second=rl_config.get("per_second", defaults["per_second"]), - burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), - ) + self.per_second = config.get("per_second", defaults["per_second"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) @attr.s(auto_attribs=True) @@ -80,14 +49,15 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: - self.rc_message = RatelimitSettings.parse( - config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} + self.rc_message = RatelimitSettings( + config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} ) else: self.rc_message = RatelimitSettings( - key="rc_messages", - per_second=config.get("rc_messages_per_second", 0.2), - burst_count=config.get("rc_message_burst_count", 10.0), + { + "per_second": config.get("rc_messages_per_second", 0.2), + "burst_count": config.get("rc_message_burst_count", 10.0), + } ) # Load the new-style federation config, if it exists. Otherwise, fall @@ -109,59 +79,51 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: } ) - self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) + self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) - self.rc_registration_token_validity = RatelimitSettings.parse( - config, - "rc_registration_token_validity", + self.rc_registration_token_validity = RatelimitSettings( + config.get("rc_registration_token_validity", {}), defaults={"per_second": 0.1, "burst_count": 5}, ) # It is reasonable to login with a bunch of devices at once (i.e. when # setting up an account), but it is *not* valid to continually be # logging into new devices. - self.rc_login_address = RatelimitSettings.parse( - config, - "rc_login.address", + rc_login_config = config.get("rc_login", {}) + self.rc_login_address = RatelimitSettings( + rc_login_config.get("address", {}), defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_account = RatelimitSettings.parse( - config, - "rc_login.account", + self.rc_login_account = RatelimitSettings( + rc_login_config.get("account", {}), defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_failed_attempts = RatelimitSettings.parse( - config, - "rc_login.failed_attempts", - {}, + self.rc_login_failed_attempts = RatelimitSettings( + rc_login_config.get("failed_attempts", {}) ) self.federation_rr_transactions_per_room_per_second = config.get( "federation_rr_transactions_per_room_per_second", 50 ) + rc_admin_redaction = config.get("rc_admin_redaction") self.rc_admin_redaction = None - if "rc_admin_redaction" in config: - self.rc_admin_redaction = RatelimitSettings.parse( - config, "rc_admin_redaction", {} - ) + if rc_admin_redaction: + self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) - self.rc_joins_local = RatelimitSettings.parse( - config, - "rc_joins.local", + self.rc_joins_local = RatelimitSettings( + config.get("rc_joins", {}).get("local", {}), defaults={"per_second": 0.1, "burst_count": 10}, ) - self.rc_joins_remote = RatelimitSettings.parse( - config, - "rc_joins.remote", + self.rc_joins_remote = RatelimitSettings( + config.get("rc_joins", {}).get("remote", {}), defaults={"per_second": 0.01, "burst_count": 10}, ) # Track the rate of joins to a given room. If there are too many, temporarily # prevent local joins and remote joins via this server. - self.rc_joins_per_room = RatelimitSettings.parse( - config, - "rc_joins_per_room", + self.rc_joins_per_room = RatelimitSettings( + config.get("rc_joins_per_room", {}), defaults={"per_second": 1, "burst_count": 10}, ) @@ -170,37 +132,31 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # * For requests received over federation this is keyed by the origin. # # Note that this isn't exposed in the configuration as it is obscure. - self.rc_key_requests = RatelimitSettings.parse( - config, - "rc_key_requests", + self.rc_key_requests = RatelimitSettings( + config.get("rc_key_requests", {}), defaults={"per_second": 20, "burst_count": 100}, ) - self.rc_3pid_validation = RatelimitSettings.parse( - config, - "rc_3pid_validation", + self.rc_3pid_validation = RatelimitSettings( + config.get("rc_3pid_validation") or {}, defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_room = RatelimitSettings.parse( - config, - "rc_invites.per_room", + self.rc_invites_per_room = RatelimitSettings( + config.get("rc_invites", {}).get("per_room", {}), defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_invites_per_user = RatelimitSettings.parse( - config, - "rc_invites.per_user", + self.rc_invites_per_user = RatelimitSettings( + config.get("rc_invites", {}).get("per_user", {}), defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_issuer = RatelimitSettings.parse( - config, - "rc_invites.per_issuer", + self.rc_invites_per_issuer = RatelimitSettings( + config.get("rc_invites", {}).get("per_issuer", {}), defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_third_party_invite = RatelimitSettings.parse( - config, - "rc_third_party_invite", + self.rc_third_party_invite = RatelimitSettings( + config.get("rc_third_party_invite", {}), defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 2ac9f8b309cf..3a260a492bea 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -669,18 +669,12 @@ def _is_membership_change_allowed( errcode=Codes.INSUFFICIENT_POWER, ) elif Membership.BAN == membership: - if user_level < ban_level: + if user_level < ban_level or user_level <= target_level: raise UnstableSpecAuthError( 403, "You don't have permission to ban", errcode=Codes.INSUFFICIENT_POWER, ) - elif user_level <= target_level: - raise UnstableSpecAuthError( - 403, - "You don't have permission to ban this user", - errcode=Codes.INSUFFICIENT_POWER, - ) elif room_version.knock_join_rule and Membership.KNOCK == membership: if join_rule != JoinRules.KNOCK and ( not room_version.knock_restricted_join_rule @@ -852,11 +846,11 @@ def _check_power_levels( "kick", "invite", }: - if type(v) is not int: # noqa: E721 + if type(v) is not int: raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - type(v) is int for v in v.values() # noqa: E721 + type(v) is int for v in v.values() ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 53af423a5a98..52acb219556f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -702,7 +702,7 @@ def _copy_power_level_value_as_integer( :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if type(old_value) is int: # noqa: E721 + if type(old_value) is int: power_levels[key] = old_value return @@ -730,7 +730,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if type(value) is int: # noqa: E721 + if type(value) is int: if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 34625dd7a185..9278f1a1aa65 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -151,7 +151,7 @@ def _validate_retention(self, event: EventBase) -> None: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if type(min_lifetime) is not int: # noqa: E721 + if type(min_lifetime) is not int: raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -159,7 +159,7 @@ def _validate_retention(self, event: EventBase) -> None: ) if max_lifetime is not None: - if type(max_lifetime) is not int: # noqa: E721 + if type(max_lifetime) is not int: raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index d4e7dd45a9b8..31e0260b8312 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if type(depth) is not int: # noqa: E721 + if type(depth) is not int: raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 607013f121bf..89bd597409c6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1891,7 +1891,7 @@ def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": ) origin_server_ts = d.get("origin_server_ts") - if type(origin_server_ts) is not int: # noqa: E721 + if type(origin_server_ts) is not int: raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 5ce3f345cbeb..0b17f713ea94 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -475,11 +475,13 @@ async def get_public_rooms( See synapse.federation.federation_client.FederationClient.get_public_rooms for more information. """ - path = _create_v1_path("/publicRooms") - if search_filter: # this uses MSC2197 (Search Filtering over Federation) - data: Dict[str, Any] = {"include_all_networks": include_all_networks} + path = _create_v1_path("/publicRooms") + + data: Dict[str, Any] = { + "include_all_networks": "true" if include_all_networks else "false" + } if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -503,15 +505,17 @@ async def get_public_rooms( ) raise else: + path = _create_v1_path("/publicRooms") + args: Dict[str, Union[str, Iterable[str]]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: - args["third_party_instance_id"] = third_party_instance_id + args["third_party_instance_id"] = (third_party_instance_id,) if limit: - args["limit"] = str(limit) + args["limit"] = [str(limit)] if since_token: - args["since"] = since_token + args["since"] = [since_token] try: response = await self.client.get_json( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2b0c50513095..59ecafa6a094 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -218,17 +218,19 @@ def __init__(self, hs: "HomeServer"): self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, + rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, ) # The number of seconds to keep a UI auth session active. self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout - # Ratelimiter for failed /login attempts + # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, + rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, ) self._clock = self.hs.get_clock() diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index a85054545356..5c71637038b6 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -67,7 +67,6 @@ def __init__(self, hs: "HomeServer"): self._cas_server_url = hs.config.cas.cas_server_url self._cas_service_url = hs.config.cas.cas_service_url - self._cas_protocol_version = hs.config.cas.cas_protocol_version self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute self._cas_required_attributes = hs.config.cas.cas_required_attributes @@ -122,10 +121,7 @@ async def _validate_ticket( Returns: The parsed CAS response. """ - if self._cas_protocol_version == 3: - uri = self._cas_server_url + "/p3/proxyValidate" - else: - uri = self._cas_server_url + "/proxyValidate" + uri = self._cas_server_url + "/proxyValidate" args = { "ticket": ticket, "service": self._build_service_param(service_args), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 798c7039f9b4..17ff8821d974 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -90,7 +90,8 @@ def __init__(self, hs: "HomeServer"): self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_key_requests, + rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, + burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d12803bf0f31..33359f6ed748 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -67,7 +67,6 @@ async def get_stream( context = await presence_handler.user_syncing( requester.user.to_string(), - requester.device_id, affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 472879c964cc..3031384d25bb 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -66,12 +66,14 @@ def __init__(self, hs: "HomeServer"): self._3pid_validation_ratelimiter_ip = Ratelimiter( store=self.store, clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_3pid_validation, + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( store=self.store, clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_3pid_validation, + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) async def ratelimit_request_token_requests( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be18cdefff..a74db1dccffa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -379,7 +379,7 @@ def maybe_schedule_expiry(self, event: EventBase) -> None: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is not int or event.is_state(): # noqa: E721 + if type(expiry_ts) is not int or event.is_state(): return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -908,6 +908,19 @@ async def get_event_id_from_transaction( if existing_event_id: return existing_event_id + # Some requsters don't have device IDs (appservice, guests, and access + # tokens minted with the admin API), fallback to checking the access token + # ID, which should be close enough. + if requester.access_token_id: + existing_event_id = ( + await self.store.get_event_id_from_transaction_id_and_token_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + ) + return existing_event_id async def get_event_from_transaction( @@ -1461,23 +1474,23 @@ async def handle_new_client_event( # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - # - # Note: mypy gets confused if we inline dl and check with twisted#11770. - # Some kind of bug in mypy's deduction? - deferreds = ( - run_in_background( - self._persist_events, - requester=requester, - events_and_context=events_and_context, - ratelimit=ratelimit, - extra_users=extra_users, - ), - run_in_background( - self.cache_joined_hosts_for_events, events_and_context - ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ) + event, context = events_and_context[0] result, _ = await make_deferred_yieldable( - gather_results(deferreds, consumeErrors=True) + gather_results( + ( + run_in_background( + self._persist_events, + requester=requester, + events_and_context=events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, + ), + run_in_background( + self.cache_joined_hosts_for_events, events_and_context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ), + consumeErrors=True, + ) ).addErrback(unwrapFirstError) return result @@ -1908,10 +1921,7 @@ async def persist_and_notify_client_events( # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_as_background_process( - "bump_presence_active_time", - self._bump_active_time, - requester.user, - requester.device_id, + "bump_presence_active_time", self._bump_active_time, requester.user ) async def _notify() -> None: @@ -1948,10 +1958,10 @@ async def _maybe_kick_guest_users( logger.info("maybe_kick_guest_users %r", current_state) await self.hs.get_room_member_handler().kick_guest_users(current_state) - async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None: + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() - await presence.bump_presence_active_time(user, device_id) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2f841863ae74..e8e9db4b91a6 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -23,7 +23,6 @@ """ import abc import contextlib -import itertools import logging from bisect import bisect from contextlib import contextmanager @@ -152,13 +151,15 @@ def __init__(self, hs: "HomeServer"): self._federation_queue = PresenceFederationQueue(hs, self) + self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + self.VALID_PRESENCE: Tuple[str, ...] = ( PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, ) - if hs.config.experimental.msc3026_enabled: + if self._busy_presence_enabled: self.VALID_PRESENCE += (PresenceState.BUSY,) active_presence = self.store.take_presence_startup_info() @@ -166,11 +167,7 @@ def __init__(self, hs: "HomeServer"): @abc.abstractmethod async def user_syncing( - self, - user_id: str, - device_id: Optional[str], - affect_presence: bool, - presence_state: str, + self, user_id: str, affect_presence: bool, presence_state: str ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -181,7 +178,6 @@ async def user_syncing( Args: user_id: the user that is starting a sync - device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -189,17 +185,15 @@ async def user_syncing( """ @abc.abstractmethod - def get_currently_syncing_users_for_replication( - self, - ) -> Iterable[Tuple[str, Optional[str]]]: - """Get an iterable of syncing users and devices on this worker, to send to the presence handler + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + """Get an iterable of syncing users on this worker, to send to the presence handler This is called when a replication connection is established. It should return - a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands - to inform the process handling presence about those users/devices. + a list of user ids, which are then sent as USER_SYNC commands to inform the + process handling presence about those users. Returns: - An iterable of tuples of user ID and device ID. + An iterable of user_id strings. """ async def get_state(self, target_user: UserID) -> UserPresenceState: @@ -260,39 +254,28 @@ async def current_state_for_user(self, user_id: str) -> UserPresenceState: async def set_state( self, target_user: UserID, - device_id: Optional[str], state: JsonDict, + ignore_status_msg: bool = False, force_notify: bool = False, - is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. - device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. - is_sync: True if this update was from a sync, which results in - *not* overriding a previously set BUSY status, updating the - user's last_user_sync_ts, and ignoring the "status_msg" field of - the `state` dict. """ @abc.abstractmethod - async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] - ) -> None: + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ async def update_external_syncs_row( # noqa: B027 (no-op by design) - self, - process_id: str, - user_id: str, - device_id: Optional[str], - is_syncing: bool, - sync_time_msec: int, + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int ) -> None: """Update the syncing users for an external process as a delta. @@ -303,7 +286,6 @@ async def update_external_syncs_row( # noqa: B027 (no-op by design) syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing - device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -399,9 +381,7 @@ async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: # We set force_notify=True here so that this presence update is guaranteed to # increment the presence stream ID (which resending the current user's presence # otherwise would not do). - await self.set_state( - UserID.from_string(user_id), None, state, force_notify=True - ) + await self.set_state(UserID.from_string(user_id), state, force_notify=True) async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: raise NotImplementedError( @@ -434,18 +414,16 @@ def __init__(self, hs: "HomeServer"): hs.config.worker.writers.presence, ) - # The number of ongoing syncs on this process, by (user ID, device ID). + # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. - self._user_device_to_num_current_syncs: Dict[ - Tuple[str, Optional[str]], int - ] = {} + self._user_to_num_current_syncs: Dict[str, int] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped - # syncing but we haven't notified the presence writer of that yet - self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {} + # user_id -> last_sync_ms. Lists the users that have stopped syncing but + # we haven't notified the presence writer of that yet + self.users_going_offline: Dict[str, int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -468,54 +446,42 @@ async def _on_shutdown(self) -> None: ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync( - self, - user_id: str, - device_id: Optional[str], - is_syncing: bool, - last_sync_ms: int, - ) -> None: + def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: if self._presence_enabled: self.hs.get_replication_command_handler().send_user_sync( - self.instance_id, user_id, device_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None: + def mark_as_coming_online(self, user_id: str) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. """ - going_offline = self._user_devices_going_offline.pop((user_id, device_id), None) + going_offline = self.users_going_offline.pop(user_id, None) if not going_offline: # Safe to skip because we haven't yet told the presence writer they # were offline - self.send_user_sync(user_id, device_id, True, self.clock.time_msec()) + self.send_user_sync(user_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None: + def mark_as_going_offline(self, user_id: str) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing notification to the presence writer """ - self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec() + self.users_going_offline[user_id] = self.clock.time_msec() def send_stop_syncing(self) -> None: """Check if there are any users who have stopped syncing a while ago and haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() - for (user_id, device_id), last_sync_ms in list( - self._user_devices_going_offline.items() - ): + for user_id, last_sync_ms in list(self.users_going_offline.items()): if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: - self._user_devices_going_offline.pop((user_id, device_id), None) - self.send_user_sync(user_id, device_id, False, last_sync_ms) + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) async def user_syncing( - self, - user_id: str, - device_id: Optional[str], - affect_presence: bool, - presence_state: str, + self, user_id: str, affect_presence: bool, presence_state: str ) -> ContextManager[None]: """Record that a user is syncing. @@ -525,32 +491,36 @@ async def user_syncing( if not affect_presence or not self._presence_enabled: return _NullContextManager() - # Note that this causes last_active_ts to be incremented which is not - # what the spec wants. - await self.set_state( - UserID.from_string(user_id), - device_id, - state={"presence": presence_state}, - is_sync=True, - ) + prev_state = await self.current_state_for_user(user_id) + if prev_state.state != PresenceState.BUSY: + # We set state here but pass ignore_status_msg = True as we don't want to + # cause the status message to be cleared. + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants: see comment in the BasePresenceHandler version + # of this function. + await self.set_state( + UserID.from_string(user_id), + {"presence": presence_state}, + ignore_status_msg=True, + ) - curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) - self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 + curr_sync = self._user_to_num_current_syncs.get(user_id, 0) + self._user_to_num_current_syncs[user_id] = curr_sync + 1 - # If this is the first in-flight sync, notify replication - if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1: - self.mark_as_coming_online(user_id, device_id) + # If we went from no in flight sync to some, notify replication + if self._user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end() -> None: # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if (user_id, device_id) in self._user_device_to_num_current_syncs: - self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 + if user_id in self._user_to_num_current_syncs: + self._user_to_num_current_syncs[user_id] -= 1 - # If there are no more in-flight syncs, notify replication - if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0: - self.mark_as_going_offline(user_id, device_id) + # If we went from one in flight sync to non, notify replication + if self._user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) @contextlib.contextmanager def _user_syncing() -> Generator[None, None, None]: @@ -617,34 +587,28 @@ async def process_replication_rows( # If this is a federation sender, notify about presence updates. await self.maybe_send_presence_to_interested_destinations(state_to_notify) - def get_currently_syncing_users_for_replication( - self, - ) -> Iterable[Tuple[str, Optional[str]]]: + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: return [ - user_id_device_id - for user_id_device_id, count in self._user_device_to_num_current_syncs.items() + user_id + for user_id, count in self._user_to_num_current_syncs.items() if count > 0 ] async def set_state( self, target_user: UserID, - device_id: Optional[str], state: JsonDict, + ignore_status_msg: bool = False, force_notify: bool = False, - is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. - device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. - is_sync: True if this update was from a sync, which results in - *not* overriding a previously set BUSY status, updating the - user's last_user_sync_ts, and ignoring the "status_msg" field of - the `state` dict. """ presence = state["presence"] @@ -661,15 +625,12 @@ async def set_state( await self._set_state_client( instance_name=self._presence_writer_instance, user_id=user_id, - device_id=device_id, state=state, + ignore_status_msg=ignore_status_msg, force_notify=force_notify, - is_sync=is_sync, ) - async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] - ) -> None: + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -680,9 +641,7 @@ async def bump_presence_active_time( # Proxy request to instance that writes presence user_id = user.to_string() await self._bump_active_client( - instance_name=self._presence_writer_instance, - user_id=user_id, - device_id=device_id, + instance_name=self._presence_writer_instance, user_id=user_id ) @@ -744,23 +703,17 @@ def __init__(self, hs: "HomeServer"): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self._user_device_to_num_current_syncs: Dict[ - Tuple[str, Optional[str]], int - ] = {} + self.user_to_num_current_syncs: Dict[str, int] = {} # Keeps track of the number of *ongoing* syncs on other processes. - # # While any sync is ongoing on another process the user will never # go offline. - # # Each process has a unique identifier and an update frequency. If # no update is received from that process within the update period then # we assume that all the sync requests on that process have stopped. - # Stored as a dict from process_id to set of (user_id, device_id), and - # a dict of process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs: Dict[ - str, Set[Tuple[str, Optional[str]]] - ] = {} + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs: Dict[str, Set[str]] = {} self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -965,10 +918,7 @@ async def _handle_timeouts(self) -> None: # that were syncing on that process to see if they need to be timed # out. users_to_check.update( - user_id - for user_id, device_id in self.external_process_to_current_syncs.pop( - process_id, () - ) + self.external_process_to_current_syncs.pop(process_id, ()) ) self.external_process_last_updated_ms.pop(process_id) @@ -981,15 +931,11 @@ async def _handle_timeouts(self) -> None: syncing_user_ids = { user_id - for (user_id, _), count in self._user_device_to_num_current_syncs.items() + for user_id, count in self.user_to_num_current_syncs.items() if count } - syncing_user_ids.update( - user_id - for user_id, _ in itertools.chain( - *self.external_process_to_current_syncs.values() - ) - ) + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) changes = handle_timeouts( states, @@ -1000,9 +946,7 @@ async def _handle_timeouts(self) -> None: return await self._update_states(changes) - async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] - ) -> None: + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -1025,7 +969,6 @@ async def bump_presence_active_time( async def user_syncing( self, user_id: str, - device_id: Optional[str], affect_presence: bool = True, presence_state: str = PresenceState.ONLINE, ) -> ContextManager[None]: @@ -1037,8 +980,7 @@ async def user_syncing( when users disconnect/reconnect. Args: - user_id: the user that is starting a sync - device_id: the user's device that is starting a sync + user_id affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -1047,21 +989,52 @@ async def user_syncing( if not affect_presence or not self._presence_enabled: return _NullContextManager() - curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) - self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 - # Note that this causes last_active_ts to be incremented which is not - # what the spec wants. - await self.set_state( - UserID.from_string(user_id), - device_id, - state={"presence": presence_state}, - is_sync=True, - ) + prev_state = await self.current_state_for_user(user_id) + + # If they're busy then they don't stop being busy just by syncing, + # so just update the last sync time. + if prev_state.state != PresenceState.BUSY: + # XXX: We set_state separately here and just update the last_active_ts above + # This keeps the logic as similar as possible between the worker and single + # process modes. Using set_state will actually cause last_active_ts to be + # updated always, which is not what the spec calls for, but synapse has done + # this for... forever, I think. + await self.set_state( + UserID.from_string(user_id), + {"presence": presence_state}, + ignore_status_msg=True, + ) + # Retrieve the new state for the logic below. This should come from the + # in-memory cache. + prev_state = await self.current_state_for_user(user_id) + + # To keep the single process behaviour consistent with worker mode, run the + # same logic as `update_external_syncs_row`, even though it looks weird. + if prev_state.state == PresenceState.OFFLINE: + await self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) + # otherwise, set the new presence state & update the last sync time, + # but don't update last_active_ts as this isn't an indication that + # they've been active (even though it's probably been updated by + # set_state above) + else: + await self._update_states( + [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())] + ) async def _end() -> None: try: - self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 + self.user_to_num_current_syncs[user_id] -= 1 prev_state = await self.current_state_for_user(user_id) await self._update_states( @@ -1083,19 +1056,12 @@ def _user_syncing() -> Generator[None, None, None]: return _user_syncing() - def get_currently_syncing_users_for_replication( - self, - ) -> Iterable[Tuple[str, Optional[str]]]: + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: # since we are the process handling presence, there is nothing to do here. return [] async def update_external_syncs_row( - self, - process_id: str, - user_id: str, - device_id: Optional[str], - is_syncing: bool, - sync_time_msec: int, + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int ) -> None: """Update the syncing users for an external process as a delta. @@ -1104,7 +1070,6 @@ async def update_external_syncs_row( syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing - device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -1115,27 +1080,31 @@ async def update_external_syncs_row( process_id, set() ) - # USER_SYNC is sent when a user's device starts or stops syncing on - # a remote # process. (But only for the initial and last sync for that - # device.) - # - # When a device *starts* syncing it also calls set_state(...) which - # will update the state, last_active_ts, and last_user_sync_ts. - # Simply ensure the user & device is tracked as syncing in this case. - # - # When a device *stops* syncing, update the last_user_sync_ts and mark - # them as no longer syncing. Note this doesn't quite match the - # monolith behaviour, which updates last_user_sync_ts at the end of - # every sync, not just the last in-flight sync. - if is_syncing and (user_id, device_id) not in process_presence: - process_presence.add((user_id, device_id)) - elif not is_syncing and (user_id, device_id) in process_presence: - new_state = prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec + updates = [] + if is_syncing and user_id not in process_presence: + if prev_state.state == PresenceState.OFFLINE: + updates.append( + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + ) + ) + else: + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) + process_presence.add(user_id) + elif user_id in process_presence: + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) ) - await self._update_states([new_state]) - process_presence.discard((user_id, device_id)) + if not is_syncing: + process_presence.discard(user_id) + + if updates: + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() @@ -1149,9 +1118,7 @@ async def update_external_syncs_clear(self, process_id: str) -> None: process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users( - {user_id for user_id, device_id in process_presence} - ) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() await self._update_states( @@ -1236,22 +1203,18 @@ async def incoming_presence(self, origin: str, content: JsonDict) -> None: async def set_state( self, target_user: UserID, - device_id: Optional[str], state: JsonDict, + ignore_status_msg: bool = False, force_notify: bool = False, - is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. - device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. - is_sync: True if this update was from a sync, which results in - *not* overriding a previously set BUSY status, updating the - user's last_user_sync_ts, and ignoring the "status_msg" field of - the `state` dict. """ status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1264,28 +1227,19 @@ async def set_state( return user_id = target_user.to_string() - now = self.clock.time_msec() prev_state = await self.current_state_for_user(user_id) - # Syncs do not override a previous presence of busy. - # - # TODO: This is a hack for lack of multi-device support. Unfortunately - # removing this requires coordination with clients. - if prev_state.state == PresenceState.BUSY and is_sync: - presence = PresenceState.BUSY - new_fields = {"state": presence} - if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: - new_fields["last_active_ts"] = now - - if is_sync: - new_fields["last_user_sync_ts"] = now - else: - # Syncs do not override the status message. + if not ignore_status_msg: new_fields["status_msg"] = status_msg + if presence == PresenceState.ONLINE or ( + presence == PresenceState.BUSY and self._busy_presence_enabled + ): + new_fields["last_active_ts"] = self.clock.time_msec() + await self._update_states( [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index de0f04e3fe48..1d8d4a72e7a2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -112,7 +112,8 @@ def __init__(self, hs: "HomeServer"): self._join_rate_limiter_local = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_joins_local, + rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, + burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) # Tracks joins from local users to rooms this server isn't a member of. # I.e. joins this server makes by requesting /make_join /send_join from @@ -120,7 +121,8 @@ def __init__(self, hs: "HomeServer"): self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_joins_remote, + rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, + burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) # TODO: find a better place to keep this Ratelimiter. # It needs to be @@ -133,7 +135,8 @@ def __init__(self, hs: "HomeServer"): self._join_rate_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_joins_per_room, + rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, ) # Ratelimiter for invites, keyed by room (across all issuers, all @@ -141,7 +144,8 @@ def __init__(self, hs: "HomeServer"): self._invites_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_invites_per_room, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -149,7 +153,8 @@ def __init__(self, hs: "HomeServer"): self._invites_per_recipient_limiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_invites_per_user, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -157,13 +162,15 @@ def __init__(self, hs: "HomeServer"): self._invites_per_issuer_limiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_invites_per_issuer, + rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, ) self._third_party_invite_limiter = Ratelimiter( store=self.store, clock=self.clock, - cfg=hs.config.ratelimiting.rc_third_party_invite, + rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, + burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, ) self.request_ratelimiter = hs.get_request_ratelimiter() diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index dd559b4c450f..dad3e23470fb 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -35,7 +35,6 @@ UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache @@ -95,9 +94,7 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname self._federation_client = hs.get_federation_client() self._ratelimiter = Ratelimiter( - store=self._store, - clock=hs.get_clock(), - cfg=RatelimitSettings("", per_second=5, burst_count=10), + store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 ) # If a user tries to fetch the same page multiple times in quick succession, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 05e21509deac..804cc6e81e00 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -23,11 +23,9 @@ import twisted from twisted.internet.defer import Deferred -from twisted.internet.endpoints import HostnameEndpoint -from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory +from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory -from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable from synapse.types import ISynapseReactor @@ -99,7 +97,6 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: **kwargs, ) - factory: IProtocolFactory if _is_old_twisted: # before twisted 21.2, we have to override the ESMTPSender protocol to disable # TLS @@ -113,13 +110,22 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: factory = build_sender_factory(hostname=smtphost if enable_tls else None) if force_tls: - factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) - - endpoint = HostnameEndpoint( - reactor, smtphost, smtpport, timeout=30, bindAddress=None - ) - - await make_deferred_yieldable(endpoint.connect(factory)) + reactor.connectSSL( + smtphost, + smtpport, + factory, + optionsForClientTLS(smtphost), + timeout=30, + bindAddress=None, + ) + else: + reactor.connectTCP( + smtphost, + smtpport, + factory, + timeout=30, + bindAddress=None, + ) await make_deferred_yieldable(d) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 11342ccac8a3..583c03447c17 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -243,7 +243,7 @@ def _validate(v: Any) -> bool: return ( isinstance(v, list) and len(v) == 2 - and type(v[0]) == int # noqa: E721 + and type(v[0]) == int and isinstance(v[1], dict) ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 3bbf91298e3d..5109cec983c9 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -115,13 +115,7 @@ def return_json_error( if exc.headers is not None: for header, value in exc.headers.items(): request.setHeader(header, value) - error_ctx = exc.debug_context - if error_ctx: - logger.info( - "%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx - ) - else: - logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): error_code = HTTP_STATUS_REQUEST_CANCELLED error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 98c6038ff23f..b78d6e17c93c 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -44,7 +44,6 @@ "processName", "relativeCreated", "stack_info", - "taskName", "thread", "threadName", } diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 64c6ae451208..f62bea968fe4 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -809,24 +809,23 @@ def run_in_background( # type: ignore[misc] # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain # value. Convert it to a `Deferred`. - d: "defer.Deferred[R]" if isinstance(res, typing.Coroutine): # Wrap the coroutine in a `Deferred`. - d = defer.ensureDeferred(res) + res = defer.ensureDeferred(res) elif isinstance(res, defer.Deferred): - d = res + pass elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. - d = defer.ensureDeferred(_unwrap_awaitable(res)) + res = defer.ensureDeferred(_unwrap_awaitable(res)) else: # `res` is a plain value. Wrap it in a `Deferred`. - d = defer.succeed(res) + res = defer.succeed(res) - if d.called and not d.paused: + if res.called and not res.paused: # The function should have maintained the logcontext, so we can # optimise out the messing about - return d + return res # The function may have reset the context before returning, so # we need to restore it now. @@ -844,8 +843,8 @@ def run_in_background( # type: ignore[misc] # which is supposed to have a single entry and exit point. But # by spawning off another deferred, we are effectively # adding a new exit point.) - d.addBoth(_set_context_cb, ctx) - return d + res.addBoth(_set_context_cb, ctx) + return res T = TypeVar("T") @@ -878,7 +877,7 @@ def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T] ResultT = TypeVar("ResultT") -def _set_context_cb(result: ResultT, context: LoggingContextOrSentinel) -> ResultT: +def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" set_current_context(context) return result diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5c3045e197e9..be910128aa4e 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -910,10 +910,10 @@ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> async def _wrapper( *args: P.args, **kwargs: P.kwargs ) -> Any: # Return type is RInner - # type-ignore: func() returns R, but mypy doesn't know that R is - # Awaitable here. - with wrapping_logic(func, *args, **kwargs): # type: ignore[arg-type] - return await func(*args, **kwargs) + with wrapping_logic(func, *args, **kwargs): + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. + return await func(*args, **kwargs) # type: ignore[misc] else: # The other case here handles sync functions including those decorated with @@ -980,7 +980,8 @@ def trace_with_opname( See the module's doc string for usage examples. """ - @contextlib.contextmanager + # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 + @contextlib.contextmanager # type: ignore[arg-type] def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -1023,7 +1024,8 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func - @contextlib.contextmanager + # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 + @contextlib.contextmanager # type: ignore[arg-type] def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 1b7b014f9ac2..4b750c700b89 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -214,10 +214,7 @@ async def create_content( user_id=auth_user, ) - try: - await self._generate_thumbnails(None, media_id, media_id, media_type) - except Exception as e: - logger.info("Failed to generate thumbnails: %s", e) + await self._generate_thumbnails(None, media_id, media_id, media_type) return MXCUri(self.server_name, media_id) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py index 2ce842c98d4a..5ad9eec80b97 100644 --- a/synapse/media/oembed.py +++ b/synapse/media/oembed.py @@ -204,7 +204,7 @@ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if type(val) is int: # noqa: E721 + if type(val) is int: open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index d8979813b335..2bfa58ceee5d 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -78,7 +78,7 @@ def __init__(self, input_path: str): image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int # noqa: E721 + assert type(image_orientation) is int self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 554634579ed0..990c079c815b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -379,7 +379,7 @@ async def _action_for_event_by_user( keys = list(notification_levels.keys()) for key in keys: level = notification_levels.get(key, SENTINEL) - if level is not SENTINEL and type(level) is not int: # noqa: E721 + if level is not SENTINEL and type(level) is not int: try: notification_levels[key] = int(level) except (TypeError, ValueError): @@ -472,11 +472,7 @@ async def _action_for_event_by_user( def _is_simple_value(value: Any) -> bool: - return ( - isinstance(value, (bool, str)) - or type(value) is int # noqa: E721 - or value is None - ) + return isinstance(value, (bool, str)) or type(value) is int or value is None def _flatten_dict( diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 209833d28753..73f3de364205 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): NAME = "multi_user_device_resync" PATH_ARGS = () - CACHE = True + CACHE = False def __init__(self, hs: "HomeServer"): super().__init__(hs) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 6c9e79fb07c9..db16aac9c206 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request @@ -51,14 +51,14 @@ def __init__(self, hs: "HomeServer"): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override] - return {"device_id": device_id} + async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] + return {} async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( - UserID.from_string(user_id), content.get("device_id") + UserID.from_string(user_id) ) return (200, {}) @@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, - "force_notify": false, - "is_sync": false + "ignore_status_msg": false, + "force_notify": false } 200 OK @@ -95,16 +95,14 @@ def __init__(self, hs: "HomeServer"): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, - device_id: Optional[str], state: JsonDict, + ignore_status_msg: bool = False, force_notify: bool = False, - is_sync: bool = False, ) -> JsonDict: return { - "device_id": device_id, "state": state, + "ignore_status_msg": ignore_status_msg, "force_notify": force_notify, - "is_sync": is_sync, } async def _handle_request( # type: ignore[override] @@ -112,10 +110,9 @@ async def _handle_request( # type: ignore[override] ) -> Tuple[int, JsonDict]: await self._presence_handler.set_state( UserID.from_string(user_id), - content.get("device_id"), content["state"], + content["ignore_status_msg"], content["force_notify"], - content.get("is_sync", False), ) return (200, {}) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index e616b5e1c8ad..10f5c98ff8a9 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -267,38 +267,27 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" def __init__( - self, - instance_id: str, - user_id: str, - device_id: Optional[str], - is_syncing: bool, - last_sync_ms: int, + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int ): self.instance_id = instance_id self.user_id = user_id - self.device_id = device_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": - device_id: Optional[str] - instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4) - - if device_id == "None": - device_id = None + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self) -> str: return " ".join( ( self.instance_id, self.user_id, - str(self.device_id), "start" if self.is_syncing else "end", str(self.last_sync_ms), ) @@ -463,17 +452,6 @@ def to_line(self) -> str: return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) -class NewActiveTaskCommand(_SimpleCommand): - """Sent to inform instance handling background tasks that a new active task is available to run. - - Format:: - - NEW_ACTIVE_TASK "" - """ - - NAME = "NEW_ACTIVE_TASK" - - _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -488,7 +466,6 @@ class NewActiveTaskCommand(_SimpleCommand): RemoteServerUpCommand, ClearUserSyncsCommand, LockReleasedCommand, - NewActiveTaskCommand, ) # Map of command name to command type. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d9045d7b73f5..38adcbe1d0e8 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -40,7 +40,6 @@ Command, FederationAckCommand, LockReleasedCommand, - NewActiveTaskCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -239,10 +238,6 @@ def __init__(self, hs: "HomeServer"): if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - self._task_scheduler = None - if hs.config.worker.run_background_tasks: - self._task_scheduler = hs.get_task_scheduler() - if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. @@ -428,11 +423,7 @@ def on_USER_SYNC( if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( - cmd.instance_id, - cmd.user_id, - cmd.device_id, - cmd.is_syncing, - cmd.last_sync_ms, + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) else: return None @@ -672,15 +663,6 @@ def on_LOCK_RELEASED( cmd.instance_name, cmd.lock_name, cmd.lock_key ) - async def on_NEW_ACTIVE_TASK( - self, conn: IReplicationConnection, cmd: NewActiveTaskCommand - ) -> None: - """Called when get a new NEW_ACTIVE_TASK command.""" - if self._task_scheduler: - task = await self._task_scheduler.get_task(cmd.data) - if task: - await self._task_scheduler._launch_task(task) - def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -703,9 +685,9 @@ def new_connection(self, connection: IReplicationConnection) -> None: ) now = self._clock.time_msec() - for user_id, device_id in currently_syncing: + for user_id in currently_syncing: connection.send_command( - UserSyncCommand(self._instance_id, user_id, device_id, True, now) + UserSyncCommand(self._instance_id, user_id, True, now) ) def lost_connection(self, connection: IReplicationConnection) -> None: @@ -757,16 +739,11 @@ def send_federation_ack(self, token: int) -> None: self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( - self, - instance_id: str, - user_id: str, - device_id: Optional[str], - is_syncing: bool, - last_sync_ms: int, + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( - UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms) + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) ) def send_user_ip( @@ -799,10 +776,6 @@ def on_lock_released( if instance_name == self._instance_name: self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) - def send_new_active_task(self, task_id: str) -> None: - """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" - self.send_command(NewActiveTaskCommand(task_id)) - UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 94170715fb77..55e752fda85a 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -157,7 +157,7 @@ async def on_POST( logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if type(ts) is not int: # noqa: E721 + if type(ts) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index ffce92d45ee1..95e751288b03 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if type(length) is not int: # noqa: E721 + if type(length) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,8 +163,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -173,16 +172,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) expiry_time = body.get("expiry_time", None) - if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if ( - type(expiry_time) is int # noqa: E721 - and expiry_time < self.clock.time_msec() - ): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -287,7 +283,7 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 + or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -298,16 +294,13 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi if "expiry_time" in body: expiry_time = body["expiry_time"] - if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if ( - type(expiry_time) is int # noqa: E721 - and expiry_time < self.clock.time_msec() - ): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 625a47ec1a5a..240e6254b0bd 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1172,17 +1172,14 @@ async def on_POST( messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if ( - type(messages_per_second) is not int # noqa: E721 - or messages_per_second < 0 - ): + if type(messages_per_second) is not int or messages_per_second < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if type(burst_count) is not int or burst_count < 0: # noqa: E721 + if type(burst_count) is not int or burst_count < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7be327e26f08..d724c6892067 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -120,12 +120,14 @@ def __init__(self, hs: "HomeServer"): self._address_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - cfg=self.hs.config.ratelimiting.rc_login_address, + rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - cfg=self.hs.config.ratelimiting.rc_login_account, + rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index d189a923b5bf..b1629f94a5f8 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.ratelimiting import Ratelimiter -from synapse.config.ratelimiting import RatelimitSettings from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -67,18 +66,15 @@ def __init__(self, hs: "HomeServer"): self.token_timeout = hs.config.auth.login_via_existing_token_timeout self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth - # Ratelimit aggressively to a maximum of 1 request per minute. + # Ratelimit aggressively to a maxmimum of 1 request per minute. # # This endpoint can be used to spawn additional sessions and could be # abused by a malicious client to create many sessions. self._ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - cfg=RatelimitSettings( - key="", - per_second=1 / 60, - burst_count=1, - ), + rate_hz=1 / 60, + burst_count=1, ) @interactive_auth_handler diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index d578faa96984..8e193330f8bc 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -97,7 +97,7 @@ async def on_PUT( raise SynapseError(400, "Unable to parse state") if self._use_presence: - await self.presence_handler.set_state(user, requester.device_id, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 1707e519723a..4f96e51eeb93 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -52,9 +52,7 @@ async def on_POST( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await self.presence_handler.bump_presence_active_time( - requester.user, requester.device_id - ) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 869a37445950..316e7b99821e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -94,9 +94,7 @@ async def on_POST( Codes.INVALID_PARAM, ) - await self.presence_handler.bump_presence_active_time( - requester.user, requester.device_id - ) + await self.presence_handler.bump_presence_active_time(requester.user) if receipt_type == ReceiptTypes.FULLY_READ: await self.read_marker_handler.received_client_read_marker( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 132623462adc..77e3b91b7999 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -376,7 +376,8 @@ def __init__(self, hs: "HomeServer"): self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_registration_token_validity, + rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, + burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index ee93e459f6bf..ac1a63ca2745 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -55,7 +55,7 @@ async def on_POST( "Param 'reason' must be a string", Codes.BAD_JSON, ) - if type(body.get("score", 0)) is not int: # noqa: E721 + if type(body.get("score", 0)) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 553938ce9d13..dc498001e450 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1229,9 +1229,7 @@ async def on_PUT( content = parse_json_object_from_request(request) - await self.presence_handler.bump_presence_active_time( - requester.user, requester.device_id - ) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 42bdd3bb108b..d7854ed4fd9d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -205,7 +205,6 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: context = await self.presence_handler.user_syncing( user.to_string(), - requester.device_id, affect_presence=affect_presence, presence_state=set_presence, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 0aaa838d0478..981fd1f58a68 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -16,7 +16,6 @@ import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple -from pydantic import Extra, StrictInt, StrictStr from signedjson.sign import sign_json from twisted.web.server import Request @@ -25,10 +24,9 @@ from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, - parse_and_validate_json_object_from_request, parse_integer, + parse_json_object_from_request, ) -from synapse.rest.models import RequestBodyModel from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder @@ -40,13 +38,6 @@ logger = logging.getLogger(__name__) -class _KeyQueryCriteriaDataModel(RequestBodyModel): - class Config: - extra = Extra.allow - - minimum_valid_until_ts: Optional[StrictInt] - - class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported @@ -105,9 +96,6 @@ class RemoteKey(RestServlet): CATEGORY = "Federation requests" - class PostBody(RequestBodyModel): - server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] - def __init__(self, hs: "HomeServer"): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main @@ -149,29 +137,24 @@ async def on_GET( ) minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") - query = { - server: { - key_id: _KeyQueryCriteriaDataModel( - minimum_valid_until_ts=minimum_valid_until_ts - ) - } - } + arguments = {} + if minimum_valid_until_ts is not None: + arguments["minimum_valid_until_ts"] = minimum_valid_until_ts + query = {server: {key_id: arguments}} else: query = {server: {}} return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - content = parse_and_validate_json_object_from_request(request, self.PostBody) + content = parse_json_object_from_request(request) - query = content.server_keys + query = content["server_keys"] return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], - query_remote_on_cache_miss: bool = False, + self, query: JsonDict, query_remote_on_cache_miss: bool = False ) -> JsonDict: logger.info("Handling query for keys %r", query) @@ -213,10 +196,8 @@ async def query_keys( else: ts_added_ms = key_result.added_ts ts_valid_until_ms = key_result.valid_until_ts - req_key = query.get(server_name, {}).get( - key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) - ) - req_valid_until = req_key.minimum_valid_until_ts + req_key = query.get(server_name, {}).get(key_id, {}) + req_valid_until = req_key.get("minimum_valid_until_ts") if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( diff --git a/synapse/server.py b/synapse/server.py index 71ead524d684..7cdd3ea3c2e1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -408,7 +408,8 @@ def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( store=self.get_datastores().main, clock=self.get_clock(), - cfg=self.config.ratelimiting.rc_registration, + rate_hz=self.config.ratelimiting.rc_registration.per_second, + burst_count=self.config.ratelimiting.rc_registration.burst_count, ) @cache_in_self @@ -913,7 +914,6 @@ def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager: """Usage metrics shared between phone home stats and the prometheus exporter.""" return CommonUsageMetricsManager(self) - @cache_in_self def get_worker_locks_handler(self) -> WorkerLocksHandler: return WorkerLocksHandler(self) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 7619f405fa09..ddca0af1da39 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -405,14 +405,14 @@ async def run_background_updates(self, sleep: bool) -> None: try: result = await self.do_next_background_update(sleep) back_to_back_failures = 0 - except Exception as e: - logger.exception("Error doing update: %s", e) + except Exception: back_to_back_failures += 1 if back_to_back_failures >= 5: self._aborted = True raise RuntimeError( "5 back-to-back background update failures; aborting." ) + logger.exception("Error doing update") else: if result: logger.info( diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 55ac313f33b0..a1c8fb0f46a4 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,7 +31,6 @@ Iterator, List, Optional, - Sequence, Tuple, Type, TypeVar, @@ -359,21 +358,7 @@ def rowcount(self) -> int: return self.txn.rowcount @property - def description( - self, - ) -> Optional[ - Sequence[ - Tuple[ - str, - Optional[Any], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - ] - ] - ]: + def description(self) -> Any: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0c1ed752406f..c1353b18c1cd 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -978,12 +978,26 @@ def _persist_transaction_ids_txn( """Persist the mapping from transaction IDs to event IDs (if defined).""" inserted_ts = self._clock.time_msec() + to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = [] to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = [] for event, _ in events_and_contexts: txn_id = getattr(event.internal_metadata, "txn_id", None) + token_id = getattr(event.internal_metadata, "token_id", None) device_id = getattr(event.internal_metadata, "device_id", None) if txn_id is not None: + if token_id is not None: + to_insert_token_id.append( + ( + event.event_id, + event.room_id, + event.sender, + token_id, + txn_id, + inserted_ts, + ) + ) + if device_id is not None: to_insert_device_id.append( ( @@ -996,7 +1010,26 @@ def _persist_transaction_ids_txn( ) ) - # Synapse relies on the device_id to scope transactions for events.. + # Synapse usually relies on the device_id to scope transactions for events, + # except for users without device IDs (appservice, guests, and access + # tokens minted with the admin API) which use the access token ID instead. + # + # TODO https://github.com/matrix-org/synapse/issues/16042 + if to_insert_token_id: + self.db_pool.simple_insert_many_txn( + txn, + table="event_txn_id", + keys=( + "event_id", + "room_id", + "user_id", + "token_id", + "txn_id", + "inserted_ts", + ), + values=to_insert_token_id, + ) + if to_insert_device_id: self.db_pool.simple_insert_many_txn( txn, @@ -1638,7 +1671,7 @@ def _update_metadata_tables_txn( if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is int and not event.is_state(): # noqa: E721 + if type(expiry_ts) is int and not event.is_state(): self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2006,10 +2039,10 @@ def _store_retention_policy_for_room_txn( ): if ( "min_lifetime" in event.content - and type(event.content["min_lifetime"]) is not int # noqa: E721 + and type(event.content["min_lifetime"]) is not int ) or ( "max_lifetime" in event.content - and type(event.content["max_lifetime"]) is not int # noqa: E721 + and type(event.content["max_lifetime"]) is not int ): # Ignore the event if one of the value isn't an integer. return diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 1eb313040ed9..7e7648c95112 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -2022,6 +2022,25 @@ def get_next_event_to_expire_txn( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + async def get_event_id_from_transaction_id_and_token_id( + self, room_id: str, user_id: str, token_id: int, txn_id: str + ) -> Optional[str]: + """Look up if we have already persisted an event for the transaction ID, + returning the event ID if so. + """ + return await self.db_pool.simple_select_one_onecol( + table="event_txn_id", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "token_id": token_id, + "txn_id": txn_id, + }, + retcol="event_id", + allow_none=True, + desc="get_event_id_from_transaction_id_and_token_id", + ) + async def get_event_id_from_transaction_id_and_device_id( self, room_id: str, user_id: str, device_id: str, txn_id: str ) -> Optional[str]: @@ -2053,35 +2072,29 @@ async def get_already_persisted_events( """ mapping = {} - txn_id_to_event: Dict[Tuple[str, str, str, str], str] = {} + txn_id_to_event: Dict[Tuple[str, int, str], str] = {} for event in events: - device_id = getattr(event.internal_metadata, "device_id", None) + token_id = getattr(event.internal_metadata, "token_id", None) txn_id = getattr(event.internal_metadata, "txn_id", None) - if device_id and txn_id: + if token_id and txn_id: # Check if this is a duplicate of an event in the given events. - existing = txn_id_to_event.get( - (event.room_id, event.sender, device_id, txn_id) - ) + existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) if existing: mapping[event.event_id] = existing continue # Check if this is a duplicate of an event we've already # persisted. - existing = await self.get_event_id_from_transaction_id_and_device_id( - event.room_id, event.sender, device_id, txn_id + existing = await self.get_event_id_from_transaction_id_and_token_id( + event.room_id, event.sender, token_id, txn_id ) if existing: mapping[event.event_id] = existing - txn_id_to_event[ - (event.room_id, event.sender, device_id, txn_id) - ] = existing + txn_id_to_event[(event.room_id, token_id, txn_id)] = existing else: - txn_id_to_event[ - (event.room_id, event.sender, device_id, txn_id) - ] = event.event_id + txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id return mapping diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 5a01ec213759..54d40e7a3ab0 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary -from twisted.internet.task import LoopingCall +from twisted.internet.interfaces import IReactorCore from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -26,7 +26,6 @@ LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.types import ISynapseReactor from synapse.util import Clock from synapse.util.stringutils import random_string @@ -359,7 +358,7 @@ class Lock: def __init__( self, - reactor: ISynapseReactor, + reactor: IReactorCore, clock: Clock, store: LockStore, read_write: bool, @@ -378,25 +377,19 @@ def __init__( self._table = "worker_read_write_locks" if read_write else "worker_locks" - # We might be called from a non-main thread, so we defer setting up the - # looping call. - self._looping_call: Optional[LoopingCall] = None - reactor.callFromThread(self._setup_looping_call) - - self._dropped = False - - def _setup_looping_call(self) -> None: - self._looping_call = self._clock.looping_call( + self._looping_call = clock.looping_call( self._renew, _RENEWAL_INTERVAL_MS, - self._store, - self._clock, - self._read_write, - self._lock_name, - self._lock_key, - self._token, + store, + clock, + read_write, + lock_name, + lock_key, + token, ) + self._dropped = False + @staticmethod @wrap_as_background_process("Lock._renew") async def _renew( @@ -466,7 +459,7 @@ async def release(self) -> None: if self._dropped: return - if self._looping_call and self._looping_call.running: + if self._looping_call.running: self._looping_call.stop() await self._store.db_pool.simple_delete( @@ -493,9 +486,8 @@ def __del__(self) -> None: # We should not be dropped without the lock being released (unless # we're shutting down), but if we are then let's at least stop # renewing the lock. - if self._looping_call and self._looping_call.running: - # We might be called from a non-main thread. - self._reactor.callFromThread(self._looping_call.stop) + if self._looping_call.running: + self._looping_call.stop() if self._reactor.running: logger.error( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index bec0dc2afeeb..c13c0bc7d725 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -88,6 +88,7 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, + msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) return filtered_rules diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 422f11f59e9e..649d3c8e9f96 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 81 # remember to update the list below when updating +SCHEMA_VERSION = 80 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -114,15 +114,19 @@ Changes in SCHEMA_VERSION = 80 - The event_txn_id_device_id is always written to for new events. - Add tables for the task scheduler. - -Changes in SCHEMA_VERSION = 81 - - The event_txn_id is no longer written to for new events. """ SCHEMA_COMPAT_VERSION = ( - # The `event_txn_id_device_id` must be written to for new events. - 80 + # Queries against `event_stream_ordering` columns in membership tables must + # be disambiguated. + # + # The threads_id column must written to with non-null values for the + # event_push_actions, event_push_actions_staging, and event_push_summary tables. + # + # insertions to the column `full_user_id` of tables profiles and user_filters can no + # longer be null + 76 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 029eedcc6fae..bf7bd351e0cc 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -470,7 +470,7 @@ def __init__(self) -> None: def deferred(self, key: KT) -> "defer.Deferred[VT]": if not self._deferred: self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - return self._deferred.observe().addCallback(lambda res: res[key]) + return self._deferred.observe().addCallback(lambda res: res.get(key)) def add_invalidation_callback( self, key: KT, callback: Optional[Callable[[], None]] diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index f7cead9e1206..114130a08fe2 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -51,9 +51,9 @@ def dependencies(self) -> Iterable[str]: DEV_EXTRAS = {"lint", "mypy", "test", "dev"} -ALL_EXTRAS = metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra") -assert ALL_EXTRAS is not None -RUNTIME_EXTRAS = set(ALL_EXTRAS) - DEV_EXTRAS +RUNTIME_EXTRAS = ( + set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS +) VERSION = metadata.version(DISTRIBUTION_NAME) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index f693ba2a8c0c..cde4a0780fe7 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -291,8 +291,7 @@ def _on_enter(self, request_id: object) -> "defer.Deferred[None]": if self.metrics_name: rate_limit_reject_counter.labels(self.metrics_name).inc() raise LimitExceededError( - limiter_name="rc_federation", - retry_after_ms=int(self.window_size / self.sleep_limit), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 9e89aeb74891..4aea64b338b4 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -57,13 +57,14 @@ class TaskScheduler: the code launching the task. You can also specify the `result` (and/or an `error`) when returning from the function. - The reconciliation loop runs every minute, so this is not a precise scheduler. - There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already - full. In this regard, please take great care that scheduled tasks can actually finished. - For now there is no mechanism to stop a running task if it is stuck. + The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting + to launch now, the launch will still not happen before the next loop run. Tasks will be run on the worker specified with `run_background_tasks_on` config, or the main one by default. + There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already + full. In this regard, please take great care that scheduled tasks can actually finished. + For now there is no mechanism to stop a running task if it is stuck. """ # Precision of the scheduler, evaluation of tasks to run will only happen @@ -84,7 +85,7 @@ def __init__(self, hs: "HomeServer"): self._actions: Dict[ str, Callable[ - [ScheduledTask], + [ScheduledTask, bool], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], ] = {} @@ -97,13 +98,11 @@ def __init__(self, hs: "HomeServer"): "handle_scheduled_tasks", self._handle_scheduled_tasks, ) - else: - self.replication_client = hs.get_replication_command_handler() def register_action( self, function: Callable[ - [ScheduledTask], + [ScheduledTask, bool], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], action_name: str, @@ -116,9 +115,10 @@ def register_action( calling `schedule_task` but rather in an `__init__` method. Args: - function: The function to be executed for this action. The parameter - passed to the function when launched is the `ScheduledTask` being run. - The function should return a tuple of new `status`, `result` + function: The function to be executed for this action. The parameters + passed to the function when launched are the `ScheduledTask` being run, + and a `first_launch` boolean to signal if it's a resumed task or the first + launch of it. The function should return a tuple of new `status`, `result` and `error` as specified in `ScheduledTask`. action_name: The name of the action to be associated with the function """ @@ -171,12 +171,6 @@ async def schedule_task( ) await self._store.insert_scheduled_task(task) - if status == TaskStatus.ACTIVE: - if self._run_background_tasks: - await self._launch_task(task) - else: - self.replication_client.send_new_active_task(task.id) - return task.id async def update_task( @@ -271,13 +265,21 @@ async def delete_task(self, id: str) -> None: Args: id: id of the task to delete """ - task = await self.get_task(id) - if task is None: - raise Exception(f"Task {id} does not exist") - if task.status == TaskStatus.ACTIVE: - raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") + if self.task_is_running(id): + raise Exception(f"Task {id} is currently running and can't be deleted") await self._store.delete_scheduled_task(id) + def task_is_running(self, id: str) -> bool: + """Check if a task is currently running. + + Can only be called from the worker handling the task scheduling. + + Args: + id: id of the task to check + """ + assert self._run_background_tasks + return id in self._running_tasks + async def _handle_scheduled_tasks(self) -> None: """Main loop taking care of launching tasks and cleaning up old ones.""" await self._launch_scheduled_tasks() @@ -286,11 +288,29 @@ async def _handle_scheduled_tasks(self) -> None: async def _launch_scheduled_tasks(self) -> None: """Retrieve and launch scheduled tasks that should be running at that time.""" for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): - await self._launch_task(task) + if not self.task_is_running(task.id): + if ( + len(self._running_tasks) + < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + ): + await self._launch_task(task, first_launch=False) + else: + if ( + self._clock.time_msec() + > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS + ): + logger.warn( + f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" + ) for task in await self.get_tasks( statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() ): - await self._launch_task(task) + if ( + not self.task_is_running(task.id) + and len(self._running_tasks) + < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + ): + await self._launch_task(task, first_launch=True) running_tasks_gauge.set(len(self._running_tasks)) @@ -300,27 +320,27 @@ async def _clean_scheduled_tasks(self) -> None: statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] ): # FAILED and COMPLETE tasks should never be running - assert task.id not in self._running_tasks + assert not self.task_is_running(task.id) if ( self._clock.time_msec() > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS ): await self._store.delete_scheduled_task(task.id) - async def _launch_task(self, task: ScheduledTask) -> None: + async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: """Launch a scheduled task now. Args: task: the task to launch + first_launch: `True` if it's the first time is launched, `False` otherwise """ - assert self._run_background_tasks - assert task.action in self._actions + function = self._actions[task.action] async def wrapper() -> None: try: - (status, result, error) = await function(task) + (status, result, error) = await function(task, first_launch) except Exception: f = Failure() logger.error( @@ -340,20 +360,6 @@ async def wrapper() -> None: ) self._running_tasks.remove(task.id) - if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: - return - - if ( - self._clock.time_msec() - > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS - ): - logger.warn( - f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" - ) - - if task.id in self._running_tasks: - return - self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) description = f"{task.id}-{task.action}" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index dcd01d56885c..ce96574915fd 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import pymacaroons @@ -35,6 +35,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import simple_async_mock from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -59,16 +60,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = AsyncMock(return_value=None) - self.store.is_support_user = AsyncMock(return_value=False) + self.store.insert_client_ip = simple_async_mock(None) + self.store.is_support_user = simple_async_mock(False) def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = AsyncMock(return_value=user_info) - self.store.mark_access_token_as_used = AsyncMock(return_value=None) - self.store.get_user_locked_status = AsyncMock(return_value=False) + self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = simple_async_mock(False) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -77,7 +78,7 @@ def test_get_user_by_req_user_valid_token(self) -> None: self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self) -> None: - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -90,7 +91,7 @@ def test_get_user_by_req_user_bad_token(self) -> None: def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = AsyncMock(return_value=user_info) + self.store.get_user_by_access_token = simple_async_mock(user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -105,7 +106,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None: token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -124,7 +125,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "192.168.10.10" @@ -143,7 +144,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "131.111.8.42" @@ -157,7 +158,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -171,7 +172,7 @@ def test_get_user_by_req_appservice_bad_token(self) -> None: def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -189,8 +190,8 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -209,7 +210,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -233,10 +234,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) # This also needs to just return a truth-y value - self.store.get_device = AsyncMock(return_value={"hidden": False}) + self.store.get_device = simple_async_mock({"hidden": False}) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -265,10 +266,10 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) # This also needs to just return a falsey value - self.store.get_device = AsyncMock(return_value=None) + self.store.get_device = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -282,8 +283,8 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: - self.store.get_user_by_access_token = AsyncMock( - return_value=TokenLookupResult( + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -291,9 +292,9 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non token_used=True, ) ) - self.store.insert_client_ip = AsyncMock(return_value=None) - self.store.mark_access_token_as_used = AsyncMock(return_value=None) - self.store.get_user_locked_status = AsyncMock(return_value=False) + self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = simple_async_mock(False) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -303,8 +304,8 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True - self.store.get_user_by_access_token = AsyncMock( - return_value=TokenLookupResult( + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -312,9 +313,9 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: token_used=True, ) ) - self.store.get_user_locked_status = AsyncMock(return_value=False) - self.store.insert_client_ip = AsyncMock(return_value=None) - self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = simple_async_mock(False) + self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -323,7 +324,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self) -> None: - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -341,8 +342,8 @@ def test_get_user_from_macaroon(self) -> None: ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = AsyncMock(return_value=None) + self.store.get_user_by_id = simple_async_mock({"is_guest": True}) + self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -372,7 +373,7 @@ def test_blocking_mau(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) + self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure( self.auth_blocking.check_auth_blocking(), ResourceLimitError @@ -382,27 +383,25 @@ def test_blocking_mau(self) -> None: self.assertEqual(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = AsyncMock( - return_value=small_number_of_users - ) + self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.get_monthly_active_count = simple_async_mock(100) # Support users allowed self.get_success( self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) ) - self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.get_monthly_active_count = simple_async_mock(100) # Bots not allowed self.get_failure( self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError, ) - self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.get_monthly_active_count = simple_async_mock(100) # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) @@ -413,9 +412,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips( self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False - self.store.get_monthly_active_count = AsyncMock(return_value=100) - self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) - self.store.is_trial_user = AsyncMock(return_value=False) + self.store.get_monthly_active_count = simple_async_mock(100) + self.store.user_last_seen_monthly_active = simple_async_mock() + self.store.is_trial_user = simple_async_mock() appservice = ApplicationService( "abcd", @@ -444,9 +443,9 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True - self.store.get_monthly_active_count = AsyncMock(return_value=100) - self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) - self.store.is_trial_user = AsyncMock(return_value=False) + self.store.get_monthly_active_count = simple_async_mock(100) + self.store.user_last_seen_monthly_active = simple_async_mock() + self.store.is_trial_user = simple_async_mock() appservice = ApplicationService( "abcd", @@ -474,7 +473,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = AsyncMock(return_value=2) + self.store.get_monthly_active_count = simple_async_mock(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py deleted file mode 100644 index 8e159029d9b0..000000000000 --- a/tests/api/test_errors.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2023 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.errors import LimitExceededError - -from tests import unittest - - -class LimitExceededErrorTestCase(unittest.TestCase): - def test_key_appears_in_context_but_not_error_dict(self) -> None: - err = LimitExceededError("needle") - serialised = json.dumps(err.error_dict(None)) - self.assertIn("needle", err.debug_context) - self.assertNotIn("needle", serialised) - - # Create a sub-class to avoid mutating the class-level property. - class LimitExceededErrorHeaders(LimitExceededError): - include_retry_after_header = True - - def test_limit_exceeded_header(self) -> None: - err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) - self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) - assert err.headers is not None - self.assertEqual(err.headers.get("Retry-After"), "1") - - def test_limit_exceeded_rounding(self) -> None: - err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) - self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) - assert err.headers is not None - self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index a24638c9eff7..fa6c1c02ce95 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,6 +1,5 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService -from synapse.config.ratelimiting import RatelimitSettings from synapse.types import create_requester from tests import unittest @@ -11,7 +10,8 @@ def test_allowed_via_can_do_action(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -43,11 +43,8 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> Non limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - key="", - per_second=0.1, - burst_count=1, - ), + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -79,11 +76,8 @@ def test_allowed_appservice_via_can_requester_do_action(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - key="", - per_second=0.1, - burst_count=1, - ), + rate_hz=0.1, + burst_count=1, ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -107,7 +101,8 @@ def test_allowed_via_ratelimit(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), + rate_hz=0.1, + burst_count=1, ) # Shouldn't raise @@ -133,7 +128,8 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), + rate_hz=0.1, + burst_count=1, ) # First attempt should be allowed @@ -181,7 +177,8 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), + rate_hz=0.1, + burst_count=1, ) # First attempt should be allowed @@ -211,7 +208,8 @@ def test_pruning(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), + rate_hz=0.1, + burst_count=1, ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -246,11 +244,7 @@ def test_db_user_override(self) -> None: ) ) - limiter = Ratelimiter( - store=store, - clock=self.clock, - cfg=RatelimitSettings("", per_second=0.1, burst_count=1), - ) + limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) # Shouldn't raise for _ in range(20): @@ -260,11 +254,8 @@ def test_multiple_actions(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - key="", - per_second=0.1, - burst_count=3, - ), + rate_hz=0.1, + burst_count=3, ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -330,7 +321,8 @@ def test_rate_limit_burst_only_given_once(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings("", per_second=0.1, burst_count=3), + rate_hz=0.1, + burst_count=3, ) def consume_at(time: float) -> bool: @@ -354,11 +346,8 @@ def test_record_action_which_doesnt_fill_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - "", - per_second=0.1, - burst_count=3, - ), + rate_hz=0.1, + burst_count=3, ) # Observe two actions, leaving room in the bucket for one more. @@ -380,11 +369,8 @@ def test_record_action_which_fills_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - "", - per_second=0.1, - burst_count=3, - ), + rate_hz=0.1, + burst_count=3, ) # Observe three actions, filling up the bucket. @@ -412,11 +398,8 @@ def test_record_action_which_overfills_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - cfg=RatelimitSettings( - "", - per_second=0.1, - burst_count=3, - ), + rate_hz=0.1, + burst_count=3, ) # Observe four actions, exceeding the bucket. diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 75fb5fae6b92..3c635e3dcbdb 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -96,7 +96,7 @@ async def get_json( ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -168,7 +168,7 @@ async def get_json( ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -215,7 +215,7 @@ async def post_json_get_json( return RESPONSE # We assign to a method, which mypy doesn't like. - self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[method-assign] + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] MISSING_KEYS = [ # Known user, known device, missing algorithm. diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 6ac5fc1ae7c4..66753c60c4b1 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -13,13 +13,14 @@ # limitations under the License. import re from typing import Any, Generator -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.internet import defer from synapse.appservice import ApplicationService, Namespace from tests import unittest +from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: @@ -42,8 +43,8 @@ def setUp(self) -> None: ) self.store = Mock() - self.store.get_aliases_for_room = AsyncMock(return_value=[]) - self.store.get_local_users_in_room = AsyncMock(return_value=[]) + self.store.get_aliases_for_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match( @@ -126,10 +127,10 @@ def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, Non self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = AsyncMock( - return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = simple_async_mock( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = AsyncMock(return_value=[]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -181,10 +182,10 @@ def test_regex_alias_no_match( self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = AsyncMock( - return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = simple_async_mock( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = AsyncMock(return_value=[]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -204,10 +205,8 @@ def test_regex_multiple_matches( ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room = AsyncMock( - return_value=["#irc_barfoo:matrix.org"] - ) - self.store.get_local_users_in_room = AsyncMock(return_value=[]) + self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -236,10 +235,10 @@ def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, No def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_local_users_in_room = AsyncMock( - return_value=["@alice:here", "@irc_fo:here", "@bob:here"] + self.store.get_local_users_in_room = simple_async_mock( + ["@alice:here", "@irc_fo:here", "@bob:here"] ) - self.store.get_aliases_for_room = AsyncMock(return_value=[]) + self.store.get_aliases_for_room = simple_async_mock([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 445919417e63..e2a3bad065da 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Sequence, Tuple, cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from typing_extensions import TypeAlias @@ -37,6 +37,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -61,12 +62,10 @@ def test_single_service_up_txn_sent(self) -> None: txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = AsyncMock( - return_value=ApplicationServiceState.UP - ) - txn.send = AsyncMock(return_value=True) - txn.complete = AsyncMock(return_value=True) - self.store.create_appservice_txn = AsyncMock(return_value=txn) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(True) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -90,10 +89,10 @@ def test_single_service_down(self) -> None: events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = AsyncMock( - return_value=ApplicationServiceState.DOWN + self.store.get_appservice_state = simple_async_mock( + ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = AsyncMock(return_value=txn) + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -119,12 +118,10 @@ def test_single_service_up_txn_not_sent(self) -> None: txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = AsyncMock( - return_value=ApplicationServiceState.UP - ) - self.store.set_appservice_state = AsyncMock(return_value=True) - txn.send = AsyncMock(return_value=False) # fails to send - self.store.create_appservice_txn = AsyncMock(return_value=txn) + self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) + self.store.set_appservice_state = simple_async_mock(True) + txn.send = simple_async_mock(False) # fails to send + self.store.create_appservice_txn = simple_async_mock(txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -153,7 +150,7 @@ def setUp(self) -> None: self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = AsyncMock() + self.callback = simple_async_mock() self.recoverer = _Recoverer( clock=cast(Clock, self.clock), as_api=self.as_api, @@ -177,8 +174,8 @@ def take_txn( self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = AsyncMock(return_value=True) - txn.complete = AsyncMock(return_value=None) + txn.send = simple_async_mock(True) + txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) @@ -205,8 +202,8 @@ def take_txn( self.recoverer.recover() self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = AsyncMock(return_value=False) - txn.complete = AsyncMock(return_value=None) + txn.send = simple_async_mock(False) + txn.complete = simple_async_mock(None) self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) @@ -219,7 +216,7 @@ def take_txn( self.assertEqual(3, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - txn.send = AsyncMock(return_value=True) # successfully send the txn + txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEqual(1, txn.send.call_count) # new mock reset call count @@ -247,7 +244,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() - self.txn_ctrl.send = AsyncMock() + self.txn_ctrl.send = simple_async_mock() # Replace instantiated _TransactionController instances with our Mock self.scheduler.txn_ctrl = self.txn_ctrl diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index 0c27dd21e2b8..f12147eaa000 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -12,42 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.config.homeserver import HomeServerConfig -from synapse.config.ratelimiting import RatelimitSettings from tests.unittest import TestCase from tests.utils import default_config -class ParseRatelimitSettingsTestcase(TestCase): - def test_depth_1(self) -> None: - cfg = { - "a": { - "per_second": 5, - "burst_count": 10, - } - } - parsed = RatelimitSettings.parse(cfg, "a") - self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) - - def test_depth_2(self) -> None: - cfg = { - "a": { - "b": { - "per_second": 5, - "burst_count": 10, - }, - } - } - parsed = RatelimitSettings.parse(cfg, "a.b") - self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10)) - - def test_missing(self) -> None: - parsed = RatelimitSettings.parse( - {}, "a", defaults={"per_second": 5, "burst_count": 10} - ) - self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) - - class RatelimitConfigTestCase(TestCase): def test_parse_rc_federation(self) -> None: config_dict = default_config("test") diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f93ba5d4cf0c..2be341ac7b84 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -13,7 +13,7 @@ # limitations under the License. import time from typing import Any, Dict, List, Optional, cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import attr import canonicaljson @@ -45,6 +45,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import logcontext_clean, override_config @@ -290,7 +291,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: with a null `ts_valid_until_ms` """ mock_fetcher = Mock() - mock_fetcher.get_keys = AsyncMock(return_value={}) + mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) key1 = signedjson.key.generate_signing_key("1") r = self.hs.get_datastores().main.store_server_signature_keys( diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 0fcfe38efada..6fb1f1bd6e31 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Iterable, List, Optional, Set, Tuple, Union -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import attr @@ -30,6 +30,7 @@ from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config +from tests.test_utils import simple_async_mock from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -156,7 +157,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = AsyncMock(return_value={}) + self.fed_transport_client.send_transaction = simple_async_mock({}) hs = self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 73a2766bafcb..129d7cfd93f5 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock +from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin @@ -20,6 +20,7 @@ from synapse.types import JsonDict, UserID, create_requester from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -57,7 +58,7 @@ def test_complexity_simple(self) -> None: async def get_current_state_event_counts(room_id: str) -> int: return int(500 * 1.23) - store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] + store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( @@ -74,9 +75,9 @@ def test_join_too_large(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] - return_value=("", 1) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -105,9 +106,9 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] - return_value=("", 1) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -142,16 +143,16 @@ def test_join_too_large_once_joined(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[method-assign] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] - return_value=("", 1) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity async def get_current_state_event_counts(room_id: str) -> int: return 600 - self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] + self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] d = handler._remote_join( create_requester(u1), @@ -199,9 +200,9 @@ def test_join_too_large_no_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] - return_value=("", 1) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -229,9 +230,9 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] - handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] - return_value=("", 1) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 75ae740b435d..b290b020a274 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,6 +1,6 @@ from typing import Callable, Collection, List, Optional, Tuple from unittest import mock -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -19,7 +19,7 @@ from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination -from tests.test_utils import event_injection +from tests.test_utils import event_injection, make_awaitable from tests.unittest import FederatingHomeserverTestCase @@ -50,8 +50,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # This mock is crucial for destination_rooms to be populated. # TODO: this seems to no longer be the case---tests pass with this mock # commented out. - state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] - return_value={"test", "host2"} + state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] + return_value=make_awaitable({"test", "host2"}) ) # whenever send_transaction is called, record the pdu data @@ -436,7 +436,7 @@ def test_catch_up_on_synapse_startup(self) -> None: def wake_destination_track(destination: str) -> None: woken.add(destination) - self.federation_sender.wake_destination = wake_destination_track # type: ignore[method-assign] + self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] # We wait quite long so that all dests can be woken up, since there is a delay # between them. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7bd3d06859f6..9e104fd96aeb 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, FrozenSet, List, Optional, Set -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -29,6 +29,7 @@ from synapse.types import JsonDict, ReadReceipt from synapse.util import Clock +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase @@ -42,16 +43,15 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.federation_transport_client = Mock(spec=["send_transaction"]) - self.federation_transport_client.send_transaction = AsyncMock() hs = self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) - hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] - return_value={"test", "host2"} + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + return_value=make_awaitable({"test", "host2"}) ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] hs.get_storage_controllers().state.get_current_hosts_in_room ) @@ -64,7 +64,7 @@ def default_config(self) -> JsonDict: def test_send_receipts(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = {} + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -104,7 +104,7 @@ def test_send_receipts(self) -> None: def test_send_receipts_thread(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = {} + mock_send_transaction.return_value = make_awaitable({}) # Create receipts for: # @@ -180,7 +180,7 @@ def test_send_receipts_with_backoff(self) -> None: """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = {} + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -276,8 +276,6 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.federation_transport_client = Mock( spec=["send_transaction", "query_user_devices"] ) - self.federation_transport_client.send_transaction = AsyncMock() - self.federation_transport_client.query_user_devices = AsyncMock() return self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) @@ -319,13 +317,13 @@ async def get_current_hosts_in_room(room_id: str) -> Set[str]: self.record_transaction ) - async def record_transaction( + def record_transaction( self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None - ) -> JsonDict: + ) -> "defer.Deferred[JsonDict]": assert json_cb is not None data = json_cb() self.edus.extend(data["edus"]) - return {} + return defer.succeed({}) def test_send_device_updates(self) -> None: """Basic case: each device update should result in an EDU""" @@ -356,11 +354,15 @@ def test_dont_send_device_updates_for_remote_users(self) -> None: # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. - self.federation_transport_client.query_user_devices.return_value = { - "stream_id": "1", - "user_id": "@user2:host2", - "devices": [{"device_id": "D1"}], - } + self.federation_transport_client.query_user_devices.return_value = ( + make_awaitable( + { + "stream_id": "1", + "user_id": "@user2:host2", + "devices": [{"device_id": "D1"}], + } + ) + ) self.get_success( self.device_handler.device_list_updater.incoming_device_list_update( @@ -531,7 +533,7 @@ def test_unreachable_server(self) -> None: recovery """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = AssertionError("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -576,7 +578,7 @@ def test_prune_outbound_device_pokes1(self) -> None: This case tests the behaviour when the server has never been reachable. """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = AssertionError("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices u1 = self.register_user("user", "pass") @@ -634,7 +636,7 @@ def test_prune_outbound_device_pokes2(self) -> None: # now the server goes offline mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = AssertionError("fail") + mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 3f42f79f26db..70209ab09011 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -218,7 +218,7 @@ async def approve_all_signature_checking( ) -> EventBase: return pdu - homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[method-assign] + homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment] approve_all_signature_checking ) @@ -229,7 +229,7 @@ async def _check_event_auth( ) -> None: pass - homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 46d022092e82..9014e60577c7 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Dict, Iterable, List, Optional -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from parameterized import parameterized @@ -36,7 +36,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection +from tests.test_utils import event_injection, make_awaitable, simple_async_mock from tests.unittest import override_config from tests.utils import MockClock @@ -46,13 +46,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): def setUp(self) -> None: self.mock_store = Mock() - self.mock_as_api = AsyncMock() + self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) - self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) - self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) + self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) + self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) + self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( + None + ) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -67,25 +69,21 @@ def test_notify_interested_services(self) -> None: self._mkservice(is_interested_in_event=False), ] - self.mock_as_api.query_user.return_value = True + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id = AsyncMock(return_value=[]) + self.mock_store.get_user_by_id.return_value = make_awaitable([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_event_ids_stream = AsyncMock( - side_effect=[ - (0, {}), - (1, {event.event_id: 0}), - ] - ) - self.mock_store.get_events_as_list = AsyncMock( - side_effect=[ - [], - [event], - ] - ) + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {})), + make_awaitable((1, {event.event_id: 0})), + ] + self.mock_store.get_events_as_list.side_effect = [ + make_awaitable([]), + make_awaitable([event]), + ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -97,16 +95,14 @@ def test_query_user_exists_unknown_user(self) -> None: services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id = AsyncMock(return_value=None) + self.mock_store.get_user_by_id.return_value = make_awaitable(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = True - self.mock_store.get_all_new_event_ids_stream = AsyncMock( - side_effect=[ - (0, {event.event_id: 0}), - ] - ) - self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]]) + self.mock_as_api.query_user.return_value = make_awaitable(True) + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, {event.event_id: 0})), + ] + self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -116,15 +112,13 @@ def test_query_user_exists_known_user(self) -> None: services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id}) + self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = True - self.mock_store.get_all_new_event_ids_stream = AsyncMock( - side_effect=[ - (0, [event], {event.event_id: 0}), - ] - ) + self.mock_as_api.query_user.return_value = make_awaitable(True) + self.mock_store.get_all_new_event_ids_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), + ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -147,10 +141,10 @@ def test_query_room_alias_exists(self) -> None: self._mkservice_alias(is_room_alias_in_namespace=False), ] - self.mock_as_api.query_alias = AsyncMock(return_value=True) + self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias = AsyncMock( - return_value=Mock(room_id=room_id, servers=servers) + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( + Mock(room_id=room_id, servers=servers) ) result = self.successResultOf( @@ -183,7 +177,7 @@ def test_get_3pe_protocols_no_protocols(self) -> None: def test_get_3pe_protocols_protocol_no_response(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = None + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -195,10 +189,9 @@ def test_get_3pe_protocols_protocol_no_response(self) -> None: def test_get_3pe_protocols_select_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = { - "x-protocol-data": 42, - "instances": [], - } + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) @@ -212,10 +205,9 @@ def test_get_3pe_protocols_select_one_protocol(self) -> None: def test_get_3pe_protocols_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = { - "x-protocol-data": 42, - "instances": [], - } + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -230,10 +222,9 @@ def test_get_3pe_protocols_multiple_protocol(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["other-protocol"]) self.mock_store.get_app_services.return_value = [service_one, service_two] - self.mock_as_api.get_3pe_protocol.return_value = { - "x-protocol-data": 42, - "instances": [], - } + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -296,11 +287,13 @@ def test_notify_interested_services_ephemeral(self) -> None: interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579) + self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( + 579 + ) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as = AsyncMock( - return_value=([event], None) + self.event_source.sources.receipt.get_new_events_as.return_value = ( + make_awaitable(([event], None)) ) self.handler.notify_interested_services_ephemeral( @@ -324,11 +317,13 @@ def test_notify_interested_services_ephemeral_out_of_order(self) -> None: services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580) + self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( + 580 + ) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as = AsyncMock( - return_value=([event], None) + self.event_source.sources.receipt.get_new_events_as.return_value = ( + make_awaitable(([event], None)) ) self.handler.notify_interested_services_ephemeral( @@ -355,7 +350,9 @@ def _mkservice( A mock representing the ApplicationService. """ service = Mock() - service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event) + service.is_interested_in_event.return_value = make_awaitable( + is_interested_in_event + ) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols @@ -399,12 +396,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events - self.send_mock = AsyncMock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) @@ -897,12 +894,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # will be sent over the wire - self.put_json = AsyncMock() - hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign] + self.put_json = simple_async_mock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) @@ -1003,8 +1000,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track what's going out - self.send_mock = AsyncMock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method. + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. # Define an application service for the tests self._service_token = "VERYSECRET" diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 413ff8795bef..036dbbc45ba5 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import AsyncMock +from unittest.mock import Mock import pymacaroons @@ -25,6 +25,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): @@ -165,8 +166,8 @@ def test_mau_limits_disabled(self) -> None: def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( - return_value=self.large_number_of_users + self.hs.get_datastores().main.get_monthly_active_count = Mock( + return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -176,8 +177,8 @@ def test_mau_limits_exceeded_large(self) -> None: ResourceLimitError, ) - self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( - return_value=self.large_number_of_users + self.hs.get_datastores().main.get_monthly_active_count = Mock( + return_value=make_awaitable(self.large_number_of_users) ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) @@ -190,8 +191,8 @@ def test_mau_limits_parity(self) -> None: self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( - return_value=self.auth_blocking._max_mau_value + self.hs.get_datastores().main.get_monthly_active_count = Mock( + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) # If not in monthly active cohort @@ -207,8 +208,8 @@ def test_mau_limits_parity(self) -> None: self.assertIsNone(self.token_login(token)) # If in monthly active cohort - self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock( - return_value=self.clock.time_msec() + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( self.auth_handler.create_access_token_for_user_id( @@ -223,8 +224,8 @@ def test_mau_limits_parity(self) -> None: def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( - return_value=self.small_number_of_users + self.hs.get_datastores().main.get_monthly_active_count = Mock( + return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception self.get_success( @@ -233,8 +234,8 @@ def test_mau_limits_not_exceeded(self) -> None: ) ) - self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( - return_value=self.small_number_of_users + self.hs.get_datastores().main.get_monthly_active_count = Mock( + return_value=make_awaitable(self.small_number_of_users) ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8582b1cd1e9e..63aad0d10c2f 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.server import HomeServer from synapse.util import Clock +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -60,7 +61,7 @@ def test_map_cas_user_to_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -88,7 +89,7 @@ def test_map_cas_user_to_existing_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -128,7 +129,7 @@ def test_map_cas_user_to_invalid_localpart(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -159,7 +160,7 @@ def test_required_attributes(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 55a4f95ef32b..e1e58fa6e648 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -32,6 +32,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config user1 = "@boris:aaa" @@ -40,7 +41,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.AsyncMock() + self.appservice_api = mock.Mock() hs = self.setup_test_homeserver( "server", application_service_api=self.appservice_api, @@ -122,50 +123,50 @@ def test_get_devices_by_user(self) -> None: self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }.items(), - device_map["xyz"].items(), + }, + device_map["xyz"], ) - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }.items(), - device_map["fco"].items(), + }, + device_map["fco"], ) - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }.items(), - device_map["abc"].items(), + }, + device_map["abc"], ) def test_get_device(self) -> None: self._record_users() res = self.get_success(self.handler.get_device(user1, "abc")) - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }.items(), - res.items(), + }, + res, ) def test_delete_device(self) -> None: @@ -374,11 +375,13 @@ def test_on_federation_query_user_devices_appservice(self) -> None: ) # Setup a response. - self.appservice_api.query_keys.return_value = { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} + self.appservice_api.query_keys.return_value = make_awaitable( + { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} + } } - } + ) # Request all devices. res = self.get_success( diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 367d94eca3dd..90aec484c48c 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -27,13 +27,14 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = AsyncMock() + self.mock_federation = Mock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -72,10 +73,9 @@ def test_get_local_association(self) -> None: self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self) -> None: - self.mock_federation.make_query.return_value = { - "room_id": "!8765qwer:test", - "servers": ["test", "remote"], - } + self.mock_federation.make_query.return_value = make_awaitable( + {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} + ) result = self.get_success(self.handler.get_association(self.remote_room)) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c5556f284491..2eaffe511ee4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable +from typing import Iterable from unittest import mock from parameterized import parameterized @@ -31,12 +31,13 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.AsyncMock() + self.appservice_api = mock.Mock() return self.setup_test_homeserver( federation_client=mock.Mock(), application_service_api=self.appservice_api ) @@ -800,27 +801,29 @@ def test_query_devices_remote_no_sync(self) -> None: remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign] - return_value={ - "device_keys": {remote_user_id: {}}, - "master_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - }, - "self_signing_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key + self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] + return_value=make_awaitable( + { + "device_keys": {remote_user_id: {}}, + "master_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - } - }, - } + }, + "self_signing_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + } + }, + } + ) ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -871,29 +874,34 @@ def test_query_devices_remote_sync(self) -> None: # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. - self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"}) + self.store.get_rooms_for_user = mock.Mock( + return_value=make_awaitable({"some_room_id"}) + ) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign] - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { + self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] + return_value=make_awaitable( + { "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -979,20 +987,20 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> Non mock_get_rooms = mock.patch.object( self.store, "get_rooms_for_user", - new_callable=mock.AsyncMock, - return_value=["some_room_id"], + new_callable=mock.MagicMock, + return_value=make_awaitable(["some_room_id"]), ) mock_get_users = mock.patch.object( self.store, "get_users_server_still_shares_room_with", - new_callable=mock.AsyncMock, - return_value={remote_user_id}, + new_callable=mock.MagicMock, + return_value=make_awaitable({remote_user_id}), ) mock_request = mock.patch.object( self.hs.get_federation_client(), "query_user_devices", - new_callable=mock.AsyncMock, - return_value=response_body, + new_callable=mock.MagicMock, + return_value=make_awaitable(response_body), ) with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: @@ -1052,9 +1060,8 @@ def test_query_appservice(self) -> None: ) # Setup a response, but only for device 2. - self.appservice_api.claim_client_keys.return_value = ( - {local_user: {device_id_2: otk}}, - [(local_user, device_id_1, "alg1", 1)], + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) ) # we shouldn't have any unused fallback keys yet @@ -1120,10 +1127,9 @@ def test_query_appservice_with_fallback(self) -> None: ) # Setup a response. - response: Dict[str, Dict[str, Dict[str, JsonDict]]] = { - local_user: {device_id_1: {**as_otk, **as_fallback_key}} - } - self.appservice_api.claim_client_keys.return_value = (response, []) + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) + ) # Claim OTKs, which will ask the appservice and do nothing else. claim_res = self.get_success( @@ -1165,9 +1171,8 @@ def test_query_appservice_with_fallback(self) -> None: self.assertEqual(fallback_res, ["alg1"]) # The appservice will return only the OTK. - self.appservice_api.claim_client_keys.return_value = ( - {local_user: {device_id_1: as_otk}}, - [], + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_otk}}, []) ) # Claim OTKs, which should return the OTK from the appservice and the @@ -1229,9 +1234,8 @@ def test_query_appservice_with_fallback(self) -> None: self.assertEqual(fallback_res, ["alg1"]) # Finally, return only the fallback key from the appservice. - self.appservice_api.claim_client_keys.return_value = ( - {local_user: {device_id_1: as_fallback_key}}, - [], + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_fallback_key}}, []) ) # Claim OTKs, which will return only the fallback key from the database. @@ -1346,11 +1350,13 @@ def test_query_local_devices_appservice(self) -> None: ) # Setup a response. - self.appservice_api.query_keys.return_value = { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} + self.appservice_api.query_keys.return_value = make_awaitable( + { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} + } } - } + ) # Request all devices. res = self.get_success(self.handler.query_local_devices({local_user: None})) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 21d63ab1f297..5f11d5df11ad 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -14,7 +14,7 @@ import logging from typing import Collection, Optional, cast from unittest import TestCase -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor @@ -40,7 +40,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection +from tests.test_utils import event_injection, make_awaitable logger = logging.getLogger(__name__) @@ -370,15 +370,15 @@ def test_backfill_ignores_known_events(self) -> None: # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. - federation_client_backfill_mock = AsyncMock(return_value=[event]) - self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign] + federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign] + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] persist_events_and_notify_mock ) @@ -631,29 +631,33 @@ def test_failed_partial_join_is_clean(self) -> None: }, RoomVersions.V10, ) - mock_make_membership_event = AsyncMock( - return_value=( - "example.com", - membership_event, - RoomVersions.V10, + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + "example.com", + membership_event, + RoomVersions.V10, + ) ) ) - mock_send_join = AsyncMock( - return_value=SendJoinResult( - membership_event, - "example.com", - state=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - auth_chain=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - partial_state=True, - servers_in_room={"example.com"}, + mock_send_join = Mock( + return_value=make_awaitable( + SendJoinResult( + membership_event, + "example.com", + state=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + auth_chain=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + partial_state=True, + servers_in_room={"example.com"}, + ) ) ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 70e6a7e142f1..23f1b33b2fda 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -35,7 +35,7 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import event_injection +from tests.test_utils import event_injection, make_awaitable class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): @@ -50,10 +50,6 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) - self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock() - self.mock_federation_transport_client.get_room_state = mock.AsyncMock() - self.mock_federation_transport_client.get_event = mock.AsyncMock() - self.mock_federation_transport_client.backfill = mock.AsyncMock() return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client ) @@ -202,14 +198,20 @@ async def get_event( ) # we expect an outbound request to /state_ids, so stub that out - self.mock_federation_transport_client.get_room_state_ids.return_value = { - "pdu_ids": [e.event_id for e in state_at_prev_event], - "auth_chain_ids": [], - } + self.mock_federation_transport_client.get_room_state_ids.return_value = ( + make_awaitable( + { + "pdu_ids": [e.event_id for e in state_at_prev_event], + "auth_chain_ids": [], + } + ) + ) # we also expect an outbound request to /state self.mock_federation_transport_client.get_room_state.return_value = ( - StateRequestResponse(auth_events=[], state=state_at_prev_event) + make_awaitable( + StateRequestResponse(auth_events=[], state=state_at_prev_event) + ) ) # we have to bump the clock a bit, to keep the retry logic in @@ -271,23 +273,26 @@ def test_process_pulled_event_records_failed_backfill_attempts( room_version = self.get_success(main_store.get_room_version(room_id)) # We expect an outbound request to /state_ids, so stub that out - self.mock_federation_transport_client.get_room_state_ids.return_value = { - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - "pdu_ids": [], - "auth_chain_ids": [], - } - + self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( + { + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + "pdu_ids": [], + "auth_chain_ids": [], + } + ) # We also expect an outbound request to /state - self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - auth_events=[], - state=[], + self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( + StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], + ) ) pulled_event = make_event_from_dict( @@ -540,23 +545,25 @@ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( ) # We expect an outbound request to /backfill, so stub that out - self.mock_federation_transport_client.backfill.return_value = { - "origin": self.OTHER_SERVER_NAME, - "origin_server_ts": 123, - "pdus": [ - # This is one of the important aspects of this test: we include - # `pulled_event_without_signatures` so it fails the signature check - # when we filter down the backfill response down to events which - # have valid signatures in - # `_check_sigs_and_hash_for_pulled_events_and_fetch` - pulled_event_without_signatures.get_pdu_json(), - # Then later when we process this valid signature event, when we - # fetch the missing `prev_event`s, we want to make sure that we - # backoff and don't try and fetch `pulled_event_without_signatures` - # again since we know it just had an invalid signature. - pulled_event.get_pdu_json(), - ], - } + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } + ) # Keep track of the count and make sure we don't make any of these requests event_endpoint_requested_count = 0 @@ -724,13 +731,15 @@ def test_backfill_process_previously_failed_pull_attempt_event_in_the_background ) # We expect an outbound request to /backfill, so stub that out - self.mock_federation_transport_client.backfill.return_value = { - "origin": self.OTHER_SERVER_NAME, - "origin_server_ts": 123, - "pdus": [ - pulled_event.get_pdu_json(), - ], - } + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + pulled_event.get_pdu_json(), + ], + } + ) # The function under test: try to backfill and process the pulled event with LoggingContext("test"): diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 1c5897c84e49..9691d66b48a0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,11 +46,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._persist_event_storage_controller = persistence self.user_id = self.register_user("tester", "foobar") - device_id = "dev-1" - access_token = self.login("tester", "foobar", device_id=device_id) - self.room_id = self.helper.create_room_as(self.user_id, tok=access_token) + self.access_token = self.login("tester", "foobar") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) + + info = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token( + self.access_token, + ) + ) + assert info is not None + self.token_id = info.token_id - self.requester = create_requester(self.user_id, device_id=device_id) + self.requester = create_requester(self.user_id, access_token_id=self.token_id) def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: # Create a member event we can use as an auth_event diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 534cae7f893f..3baeb28e620f 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -39,7 +39,7 @@ from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import FakeResponse, get_awaitable_result +from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import mock_getRawHeaders @@ -148,7 +148,7 @@ def _assertParams(self) -> None: def test_inactive_token(self) -> None: """The handler should return a 403 where the token is inactive.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={"active": False}, @@ -167,7 +167,7 @@ def test_inactive_token(self) -> None: def test_active_no_scope(self) -> None: """The handler should return a 403 where no scope is given.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={"active": True}, @@ -186,7 +186,7 @@ def test_active_no_scope(self) -> None: def test_active_user_no_subject(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, @@ -205,7 +205,7 @@ def test_active_user_no_subject(self) -> None: def test_active_no_user_scope(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -228,7 +228,7 @@ def test_active_no_user_scope(self) -> None: def test_active_admin_not_user(self) -> None: """The handler should raise when the scope has admin right but not user.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -252,7 +252,7 @@ def test_active_admin_not_user(self) -> None: def test_active_admin(self) -> None: """The handler should return a requester with admin rights.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -282,7 +282,7 @@ def test_active_admin(self) -> None: def test_active_admin_highest_privilege(self) -> None: """The handler should resolve to the most permissive scope.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -314,7 +314,7 @@ def test_active_admin_highest_privilege(self) -> None: def test_active_user(self) -> None: """The handler should return a requester with normal user rights.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -345,7 +345,7 @@ def test_active_user_admin_impersonation(self) -> None: """The handler should return a requester with normal user rights and an user ID matching the one specified in query param `user_id`""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -379,7 +379,7 @@ def test_active_user_admin_impersonation(self) -> None: def test_active_user_with_device(self) -> None: """The handler should return a requester with normal user rights and a device ID.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -409,7 +409,7 @@ def test_active_user_with_device(self) -> None: def test_multiple_devices(self) -> None: """The handler should raise an error if multiple devices are found in the scope.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -434,7 +434,7 @@ def test_multiple_devices(self) -> None: def test_active_guest_not_allowed(self) -> None: """The handler should return an insufficient scope error.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -464,7 +464,7 @@ def test_active_guest_not_allowed(self) -> None: def test_active_guest_allowed(self) -> None: """The handler should return a requester with guest user rights and a device ID.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -500,19 +500,19 @@ def test_unavailable_introspection_endpoint(self) -> None: request.requestHeaders.getRawHeaders = mock_getRawHeaders() # The introspection endpoint is returning an error. - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse(code=500, body=b"Internal Server Error") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint request fails. - self.http_client.request = AsyncMock(side_effect=Exception()) + self.http_client.request = simple_async_mock(raises=Exception()) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint does not return a JSON object. - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload=["this is an array", "not an object"] ) @@ -521,7 +521,7 @@ def test_unavailable_introspection_endpoint(self) -> None: self.assertEqual(error.value.code, 503) # The introspection endpoint does not return valid JSON. - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse(code=200, body=b"this is not valid JSON") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) @@ -529,7 +529,7 @@ def test_unavailable_introspection_endpoint(self) -> None: def test_introspection_token_cache(self) -> None: access_token = "open_sesame" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={"active": "true", "scope": "guest", "jti": access_token}, @@ -560,7 +560,7 @@ def test_introspection_token_cache(self) -> None: # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a # token with a soon-to-expire `exp` field to the cache - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ @@ -641,7 +641,7 @@ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: def test_cross_signing(self) -> None: """Try uploading device keys with OAuth delegation enabled.""" - self.http_client.request = AsyncMock( + self.http_client.request = simple_async_mock( return_value=FakeResponse.json( code=200, payload={ diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e797aaae00dd..0a8bae54fbea 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple -from unittest.mock import ANY, AsyncMock, Mock, patch +from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons @@ -28,7 +28,7 @@ from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string -from tests.test_utils import FakeResponse, get_awaitable_result +from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config @@ -157,15 +157,15 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error # type: ignore[method-assign] + sso_handler.render_error = self.render_error # type: ignore[assignment] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 auth_handler = hs.get_auth_handler() # Mock the complete SSO login method. - self.complete_sso_login = AsyncMock() - auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[method-assign] + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] return hs diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 11ec8c7f116f..394006f5f314 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -16,7 +16,7 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Type, Union -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -32,6 +32,7 @@ from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable from tests.unittest import override_config # Login flows we expect to appear in the list after the normal ones. @@ -186,7 +187,7 @@ def password_only_auth_provider_login_test_body(self) -> None: self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password = AsyncMock(return_value=True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -208,13 +209,13 @@ def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" # log in twice, to get two devices - mock_password_provider.check_password = AsyncMock(return_value=True) + mock_password_provider.check_password.return_value = make_awaitable(True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password = AsyncMock(return_value=False) + mock_password_provider.check_password.return_value = make_awaitable(False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -228,7 +229,7 @@ def password_only_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password = AsyncMock(return_value=True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -242,7 +243,7 @@ def local_user_fallback_login_test_body(self) -> None: self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password = AsyncMock(return_value=False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) @@ -259,7 +260,7 @@ def local_user_fallback_ui_auth_test_body(self) -> None: self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password = AsyncMock(return_value=False) + mock_password_provider.check_password.return_value = make_awaitable(False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -302,7 +303,7 @@ def no_local_user_fallback_login_test_body(self) -> None: self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password = AsyncMock(return_value=False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -324,7 +325,7 @@ def no_local_user_fallback_ui_auth_test_body(self) -> None: self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password = AsyncMock(return_value=True) + mock_password_provider.check_password.return_value = make_awaitable(True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -341,7 +342,7 @@ def no_local_user_fallback_ui_auth_test_body(self) -> None: mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password = AsyncMock(return_value=False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -395,7 +396,9 @@ def custom_auth_provider_login_test_body(self) -> None: self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) + mock_password_provider.check_auth.return_value = make_awaitable( + ("@user:test", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:test", channel.json_body["user_id"]) @@ -444,7 +447,9 @@ def custom_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) + mock_password_provider.check_auth.return_value = make_awaitable( + ("@user:test", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -455,8 +460,8 @@ def custom_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth = AsyncMock( - return_value=("@localuser:test", None) + mock_password_provider.check_auth.return_value = make_awaitable( + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -473,10 +478,10 @@ def test_custom_auth_provider_callback(self) -> None: self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self) -> None: - callback = AsyncMock(return_value=None) + callback = Mock(return_value=make_awaitable(None)) - mock_password_provider.check_auth = AsyncMock( - return_value=("@user:test", callback) + mock_password_provider.check_auth.return_value = make_awaitable( + ("@user:test", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @@ -611,8 +616,8 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth = AsyncMock( - return_value=("@localuser:test", None) + mock_password_provider.check_auth.return_value = make_awaitable( + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @@ -830,11 +835,11 @@ def _test_3pid_allowed(self, username: str, registration: bool) -> None: username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign] - return_value=0 + self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] + return_value=make_awaitable(0), ) - m = AsyncMock(return_value=False) + m = Mock(return_value=make_awaitable(False)) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] self.register_user(username, "password") @@ -864,7 +869,7 @@ def _test_3pid_allowed(self, username: str, registration: bool) -> None: m.assert_called_once_with("email", "foo@test.com", registration) - m = AsyncMock(return_value=True) + m = Mock(return_value=make_awaitable(True)) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index a987267308ee..1aebcc16adc5 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -524,7 +524,6 @@ def default_config(self) -> JsonDict: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = f"@test:{self.hs.config.server.server_name}" - self.device_id = "dev-1" # Move the reactor to the initial time. self.reactor.advance(1000) @@ -609,10 +608,7 @@ def test_restored_presence_online_after_sync( self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2) self.get_success( presence_handler.user_syncing( - self.user_id, - self.device_id, - sync_state != PresenceState.OFFLINE, - sync_state, + self.user_id, sync_state != PresenceState.OFFLINE, sync_state ) ) @@ -636,7 +632,6 @@ def test_restored_presence_online_after_sync( class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): user_id = "@test:server" user_id_obj = UserID.from_string(user_id) - device_id = "dev-1" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() @@ -646,20 +641,13 @@ def test_external_process_timeout(self) -> None: """Test that if an external process doesn't update the records for a while we time out their syncing users presence. """ + process_id = "1" - # Create a worker and use it to handle /sync traffic instead. - # This is used to test that presence changes get replicated from workers - # to the main process correctly. - worker_to_sync_against = self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "synchrotron"} - ) - worker_presence_handler = worker_to_sync_against.get_presence_handler() - + # Notify handler that a user is now syncing. self.get_success( - worker_presence_handler.user_syncing( - self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + self.presence_handler.update_external_syncs_row( + process_id, self.user_id, True, self.clock.time_msec() + ) ) # Check that if we wait a while without telling the handler the user has @@ -713,7 +701,7 @@ def test_user_goes_offline_manually_with_no_status_msg(self) -> None: # Mark user as offline self.get_success( self.presence_handler.set_state( - self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE} + self.user_id_obj, {"presence": PresenceState.OFFLINE} ) ) @@ -745,7 +733,7 @@ def test_user_reset_online_with_no_status(self) -> None: # Mark user as online again self.get_success( self.presence_handler.set_state( - self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE} + self.user_id_obj, {"presence": PresenceState.ONLINE} ) ) @@ -774,7 +762,7 @@ def test_set_presence_from_syncing_not_set(self) -> None: self.get_success( self.presence_handler.user_syncing( - self.user_id, self.device_id, False, PresenceState.ONLINE + self.user_id, False, PresenceState.ONLINE ) ) @@ -791,9 +779,7 @@ def test_set_presence_from_syncing_is_set(self) -> None: self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg) self.get_success( - self.presence_handler.user_syncing( - self.user_id, self.device_id, True, PresenceState.ONLINE - ) + self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE) ) state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) @@ -807,9 +793,7 @@ def test_set_presence_from_syncing_keeps_status(self) -> None: self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg) self.get_success( - self.presence_handler.user_syncing( - self.user_id, self.device_id, True, PresenceState.ONLINE - ) + self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE) ) state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) @@ -836,7 +820,7 @@ def test_set_presence_from_syncing_keeps_busy( # This is used to test that presence changes get replicated from workers # to the main process correctly. worker_to_sync_against = self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "synchrotron"} + "synapse.app.generic_worker", {"worker_name": "presence_writer"} ) # Set presence to BUSY @@ -847,9 +831,8 @@ def test_set_presence_from_syncing_keeps_busy( # /presence/*. self.get_success( worker_to_sync_against.get_presence_handler().user_syncing( - self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + self.user_id, True, PresenceState.ONLINE + ) ) # Check against the main process that the user's presence did not change. @@ -857,21 +840,6 @@ def test_set_presence_from_syncing_keeps_busy( # we should still be busy self.assertEqual(state.state, PresenceState.BUSY) - # Advance such that the device would be discarded if it was not busy, - # then pump so _handle_timeouts function to called. - self.reactor.advance(IDLE_TIMER / 1000) - self.reactor.pump([5]) - - # The account should still be busy. - state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) - self.assertEqual(state.state, PresenceState.BUSY) - - # Ensure that a /presence call can set the user *off* busy. - self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) - - state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) - self.assertEqual(state.state, PresenceState.ONLINE) - def _set_presencestate_with_status_msg( self, state: str, status_msg: Optional[str] ) -> None: @@ -884,7 +852,6 @@ def _set_presencestate_with_status_msg( self.get_success( self.presence_handler.set_state( self.user_id_obj, - self.device_id, {"presence": state, "status_msg": status_msg}, ) ) @@ -1126,9 +1093,7 @@ def test_remote_joins(self) -> None: # Mark test2 as online, test will be offline with a last_active of 0 self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), - "dev-1", - {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -1175,9 +1140,7 @@ def test_remote_gets_presence_when_local_user_joins(self) -> None: # Mark test as online self.get_success( self.presence_handler.set_state( - UserID.from_string("@test:server"), - "dev-1", - {"presence": PresenceState.ONLINE}, + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} ) ) @@ -1185,9 +1148,7 @@ def test_remote_gets_presence_when_local_user_joins(self) -> None: # Note we don't join them to the room yet self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), - "dev-1", - {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f9b292b9ece1..ec2f5d30bea9 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from parameterized import parameterized @@ -26,6 +26,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class ProfileTestCase(unittest.HomeserverTestCase): @@ -34,7 +35,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = AsyncMock() + self.mock_federation = Mock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -134,7 +135,9 @@ def test_set_my_name_noauth(self) -> None: ) def test_get_other_name(self) -> None: - self.mock_federation.make_query.return_value = {"displayname": "Alice"} + self.mock_federation.make_query.return_value = make_awaitable( + {"displayname": "Alice"} + ) displayname = self.get_success(self.handler.get_displayname(self.alice)) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e9fbf32c7ce9..54eeec228e20 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Collection, List, Optional, Tuple -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -38,6 +38,7 @@ ) from synapse.util import Clock +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -202,22 +203,24 @@ def test_mau_limits_when_disabled(self) -> None: @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = AsyncMock( # type: ignore[method-assign] - return_value=self.hs.config.server.max_mau_value - 1 + self.store.count_monthly_users = Mock( # type: ignore[assignment] + return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self) -> None: - self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.lots_of_users) + ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) - self.store.get_monthly_active_count = AsyncMock( - return_value=self.hs.config.server.max_mau_value + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -226,13 +229,15 @@ def test_get_or_create_user_mau_blocked(self) -> None: @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self) -> None: - self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.lots_of_users) + ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) - self.store.get_monthly_active_count = AsyncMock( - return_value=self.hs.config.server.max_mau_value + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError @@ -287,7 +292,7 @@ def test_auto_create_auto_join_where_auto_create_is_false(self) -> None: @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: room_alias_str = "#room:test" - self.store.is_real_user = AsyncMock(return_value=False) + self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -299,8 +304,8 @@ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[method-assign] - self.store.is_real_user = AsyncMock(return_value=True) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_directory_handler() @@ -314,8 +319,8 @@ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> N def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[method-assign] - self.store.is_real_user = AsyncMock(return_value=True) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 3e28117e2c0f..41199ffa297f 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor @@ -16,6 +16,7 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request +from tests.test_utils import make_awaitable from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -153,21 +154,25 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: None, ) - mock_make_membership_event = AsyncMock( - return_value=( - self.OTHER_SERVER_NAME, - join_event, - self.hs.config.server.default_room_version, + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, + ) ) ) - mock_send_join = AsyncMock( - return_value=SendJoinResult( - join_event, - self.OTHER_SERVER_NAME, - state=[create_event], - auth_chain=[create_event], - partial_state=False, - servers_in_room=frozenset(), + mock_send_join = Mock( + return_value=make_awaitable( + SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=frozenset(), + ) ) ) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 00f4e181e81a..b5c772a7aedf 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, Optional, Set, Tuple -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import attr @@ -25,6 +25,7 @@ from synapse.types import JsonDict from synapse.util import Clock +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -133,7 +134,7 @@ def test_map_saml_response_to_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -163,7 +164,7 @@ def test_map_saml_response_to_existing_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -205,11 +206,11 @@ def test_map_saml_response_to_invalid_localpart(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -226,9 +227,9 @@ def test_map_saml_response_to_user_retries(self) -> None: # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -311,7 +312,7 @@ def test_attribute_requirements(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index a066745d70b8..8b6e4a40b620 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,40 +13,19 @@ # limitations under the License. -from typing import Callable, List, Tuple, Type, Union -from unittest.mock import patch +from typing import Callable, List, Tuple from zope.interface import implementer from twisted.internet import defer -from twisted.internet._sslverify import ClientTLSOptions -from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.address import IPv4Address from twisted.internet.defer import ensureDeferred -from twisted.internet.interfaces import IProtocolFactory -from twisted.internet.ssl import ContextFactory from twisted.mail import interfaces, smtp from tests.server import FakeTransport from tests.unittest import HomeserverTestCase, override_config -def TestingESMTPTLSClientFactory( - contextFactory: ContextFactory, - _connectWrapped: bool, - wrappedProtocol: IProtocolFactory, -) -> IProtocolFactory: - """We use this to pass through in testing without using TLS, but - saving the context information to check that it would have happened. - - Note that this is what the MemoryReactor does on connectSSL. - It only saves the contextFactory, but starts the connection with the - underlying Factory. - See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" - - wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] - return wrappedProtocol - - @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: def __init__(self) -> None: @@ -96,13 +75,7 @@ def connectionLost(self) -> None: pass -class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): - ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address - - def setUp(self) -> None: - super().setUp() - self.reactor.lookups["localhost"] = "127.0.0.1" - +class SendEmailHandlerTestCase(HomeserverTestCase): def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() @@ -116,7 +89,7 @@ def test_send_email(self) -> None: (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ 0 ] - self.assertEqual(host, self.reactor.lookups["localhost"]) + self.assertEqual(host, "localhost") self.assertEqual(port, 25) # wire it up to an SMTP server @@ -132,9 +105,7 @@ def test_send_email(self) -> None: FakeTransport( client_protocol, self.reactor, - peer_address=self.ip_class( - "TCP", self.reactor.lookups["localhost"], 1234 - ), + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), ) ) @@ -147,10 +118,6 @@ def test_send_email(self) -> None: self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) - @patch( - "synapse.handlers.send_email.TLSMemoryBIOFactory", - TestingESMTPTLSClientFactory, - ) @override_config( { "email": { @@ -168,23 +135,17 @@ def test_send_email_force_tls(self) -> None: ) ) # there should be an attempt to connect to localhost:465 - self.assertEqual(len(self.reactor.tcpClients), 1) + self.assertEqual(len(self.reactor.sslClients), 1) ( host, port, client_factory, + contextFactory, _timeout, _bindAddress, - ) = self.reactor.tcpClients[0] - self.assertEqual(host, self.reactor.lookups["localhost"]) + ) = self.reactor.sslClients[0] + self.assertEqual(host, "localhost") self.assertEqual(port, 465) - # We need to make sure that TLS is happenning - self.assertIsInstance( - client_factory._wrappedFactory._testingContextFactory, - ClientTLSOptions, - ) - # And since we use endpoints, they go through reactor.connectTCP - # which works differently to connectSSL on the testing reactor # wire it up to an SMTP server message_delivery = _DummyMessageDelivery() @@ -199,9 +160,7 @@ def test_send_email_force_tls(self) -> None: FakeTransport( client_protocol, self.reactor, - peer_address=self.ip_class( - "TCP", self.reactor.lookups["localhost"], 1234 - ), + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), ) ) @@ -213,11 +172,3 @@ def test_send_email_force_tls(self) -> None: user, msg = message_delivery.messages.pop() self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) - - -class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): - ip_class = IPv6Address - - def setUp(self) -> None: - super().setUp() - self.reactor.lookups["localhost"] = "::1" diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 948d04fc323f..9f035a02dc69 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import MagicMock, Mock, patch from twisted.test.proto_helpers import MemoryReactor @@ -29,6 +29,7 @@ import tests.unittest import tests.utils +from tests.test_utils import make_awaitable class SyncTestCase(tests.unittest.HomeserverTestCase): @@ -252,8 +253,8 @@ def test_ban_wins_race_with_join(self) -> None: mocked_get_prev_events = patch.object( self.hs.get_datastores().main, "get_prev_events_for_room", - new_callable=AsyncMock, - return_value=[last_room_creation_event_id], + new_callable=MagicMock, + return_value=make_awaitable([last_room_creation_event_id]), ) with mocked_get_prev_events: self.helper.join(room_id, eve, tok=eve_token) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2a295da3a0b7..5da1d95f0b22 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -15,7 +15,7 @@ import json from typing import Dict, List, Set -from unittest.mock import ANY, AsyncMock, Mock, call +from unittest.mock import ANY, Mock, call from netaddr import IPSet @@ -33,6 +33,7 @@ from tests import unittest from tests.server import ThreadedMemoryReactorClock +from tests.test_utils import make_awaitable from tests.unittest import override_config # Some local users to test with @@ -73,11 +74,11 @@ def make_homeserver( # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server = AsyncMock(return_value=True) + mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too - self.mock_federation_client = AsyncMock(spec=["put_json"]) - self.mock_federation_client.put_json.return_value = (200, "OK") + self.mock_federation_client = Mock(spec=["put_json"]) + self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) self.mock_federation_client.agent = MatrixFederationAgent( reactor, tls_client_options_factory=None, @@ -120,18 +121,20 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.datastore = hs.get_datastores().main - self.datastore.get_destination_retry_timings = AsyncMock(return_value=None) + self.datastore.get_destination_retry_timings = Mock( + return_value=make_awaitable(None) + ) - self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign] - return_value=(0, []) + self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] + return_value=make_awaitable((0, [])) ) - self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[method-assign] - return_value=None + self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) ) - self.datastore.get_received_txn_response = AsyncMock( # type: ignore[method-assign] - return_value=None + self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) ) self.room_members: List[UserID] = [] @@ -143,25 +146,25 @@ async def check_user_in_room(room_id: str, requester: Requester) -> None: raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = Mock( # type: ignore[method-assign] + hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] side_effect=check_user_in_room ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[method-assign] + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] side_effect=check_host_in_room ) async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[method-assign] + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[method-assign] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] side_effect=get_current_hosts_in_room ) @@ -170,25 +173,27 @@ async def get_users_in_room(room_id: str) -> Set[str]: self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[method-assign] - # we deliberately return a non-None stream pos to avoid - # doing an initial_sync - return_value=1 + self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] + side_effect=( + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync + lambda: make_awaitable(1) + ) ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] - self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] - return_value=0 + self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] + side_effect=lambda: 0 ) - self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] - return_value=([], 0) + self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] - return_value=None + self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = AsyncMock( # type: ignore[method-assign] - return_value=None + self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self) -> None: diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b5f15aa7d425..430209705e23 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Tuple -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch from urllib.parse import quote from twisted.test.proto_helpers import MemoryReactor @@ -30,7 +30,7 @@ from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables -from tests.test_utils import event_injection +from tests.test_utils import event_injection, make_awaitable from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -471,7 +471,7 @@ def test_handle_user_deactivated_regular_user(self) -> None: self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = AsyncMock(return_value=None) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 0d17f2fe5be4..6a0b5fc0bd56 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,8 +14,8 @@ import base64 import logging import os -from typing import Generator, List, Optional, cast -from unittest.mock import AsyncMock, patch +from typing import Any, Awaitable, Callable, Generator, List, Optional, cast +from unittest.mock import Mock, patch import treq from netaddr import IPSet @@ -41,7 +41,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent -from synapse.http.federation.srv_resolver import Server, SrvResolver +from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( WELL_KNOWN_MAX_SIZE, WellKnownResolver, @@ -68,11 +68,21 @@ logger = logging.getLogger(__name__) +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service( + result: List[Server], +) -> Callable[[Any], Awaitable[List[Server]]]: + async def resolve_service(_: Any) -> List[Server]: + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() - self.mock_resolver = AsyncMock(spec=SrvResolver) + self.mock_resolver = Mock() config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] @@ -626,7 +636,7 @@ def test_get_hostname_bad_cert(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar") @@ -712,7 +722,7 @@ def test_get_no_srv_no_well_known(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -766,7 +776,7 @@ def test_get_well_known(self) -> None: """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -830,7 +840,7 @@ def test_get_well_known_redirect(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -920,7 +930,7 @@ def test_get_invalid_well_known(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -976,7 +986,7 @@ def test_get_well_known_unsigned_cert(self) -> None: # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -1027,9 +1037,9 @@ def test_get_hostname_srv(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -1084,9 +1094,9 @@ def test_get_well_known_srv(self) -> None: self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.return_value = [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -1127,7 +1137,7 @@ def test_idna_servername(self) -> None: """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -1191,9 +1201,9 @@ def test_idna_srv_target(self) -> None: """test the behaviour when the target of a SRV record has idna chars""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [ - Server(host=b"xn--trget-3qa.com", port=8443) - ] # tĂ¢rget.com + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # tĂ¢rget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request( @@ -1397,10 +1407,12 @@ def test_srv_fallbacks(self) -> None: """Test that other SRV results are tried if the first one fails.""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.return_value = [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index c379853e20ef..fa27f1279a95 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -164,7 +164,7 @@ def test_with_request_context(self) -> None: # Call requestReceived to finish instantiating the object. request.content = BytesIO() # Partially skip some internal processing of SynapseRequest. - request._started_processing = Mock() # type: ignore[method-assign] + request._started_processing = Mock() # type: ignore[assignment] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 172fc3a736df..fe631d7ecbd8 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -33,6 +33,7 @@ from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -69,7 +70,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = AsyncMock(return_value={}) + self.fed_transport_client.send_transaction = simple_async_mock({}) return self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, @@ -233,7 +234,7 @@ def test_get_user_ip_and_agents__no_user_found(self) -> None: def test_sending_events_into_room(self) -> None: """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent - self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[method-assign] + self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment] spec=[], side_effect=self.event_creation_handler.create_and_send_nonmember_event, ) @@ -578,8 +579,10 @@ def test_update_room_membership_remote_join(self) -> None: """Test that the module API can join a remote room.""" # Necessary to fake a remote join. fake_stream_id = 1 - mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) - self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[method-assign] + mocked_remote_join = simple_async_mock( + return_value=("fake-event-id", fake_stream_id) + ) + self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] fake_remote_host = f"{self.module_api.server_name}-remote" # Given that the join is to be faked, we expect the relevant join event not to diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 7c23b77e0a11..829b9df83d4e 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional -from unittest.mock import AsyncMock, patch +from unittest.mock import patch from parameterized import parameterized @@ -28,6 +28,7 @@ from synapse.types import JsonDict, create_requester from synapse.util import Clock +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -190,7 +191,7 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: # Mock the method which calculates push rules -- we do this instead of # e.g. checking the results in the database because we want to ensure # that code isn't even running. - bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[method-assign] + bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -381,6 +382,7 @@ def test_room_mentions(self) -> None: ) ) + @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}}) def test_suppress_edits(self) -> None: """Under the default push rules, event edits should not generate notifications.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index af25815fa56e..f7c6417a09fd 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -58,7 +58,7 @@ def patch__eq__(cls: object) -> Callable[[], None]: def unpatch() -> None: if eq is not None: - cls.__eq__ = eq # type: ignore[method-assign] + cls.__eq__ = eq # type: ignore[assignment] return unpatch diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 9b28cd474fbf..a324b4d31dde 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from netaddr import IPSet @@ -26,6 +26,7 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import get_clock +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -61,7 +62,7 @@ def test_send_event_single_sender(self) -> None: new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json = AsyncMock(return_value={}) + mock_client.put_json.return_value = make_awaitable({}) mock_client.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -92,7 +93,7 @@ def test_send_event_sharded(self) -> None: new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json = AsyncMock(return_value={}) + mock_client1.put_json.return_value = make_awaitable({}) mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -107,7 +108,7 @@ def test_send_event_sharded(self) -> None: ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json = AsyncMock(return_value={}) + mock_client2.put_json.return_value = make_awaitable({}) mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -161,7 +162,7 @@ def test_send_typing_sharded(self) -> None: new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json = AsyncMock(return_value={}) + mock_client1.put_json.return_value = make_awaitable({}) mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -176,7 +177,7 @@ def test_send_typing_sharded(self) -> None: ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json = AsyncMock(return_value={}) + mock_client2.put_json.return_value = make_awaitable({}) mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 2f6bd0d74faa..feb81844aee9 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,7 +18,7 @@ import urllib.parse from binascii import unhexlify from typing import List, Optional -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch from parameterized import parameterized, parameterized_class @@ -45,7 +45,7 @@ from tests import unittest from tests.server import FakeSite, make_request -from tests.test_utils import SMALL_PNG +from tests.test_utils import SMALL_PNG, make_awaitable from tests.unittest import override_config @@ -71,8 +71,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs.config.registration.registration_shared_secret = "shared" - self.hs.get_media_repository = Mock() # type: ignore[method-assign] - self.hs.get_deactivate_account_handler = Mock() # type: ignore[method-assign] + self.hs.get_media_repository = Mock() # type: ignore[assignment] + self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment] return self.hs @@ -419,8 +419,8 @@ def test_register_mau_limit_reached(self) -> None: store = self.hs.get_datastores().main # Set monthly active users to the limit - store.get_monthly_active_count = AsyncMock( - return_value=self.hs.config.server.max_mau_value + store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1834,8 +1834,8 @@ def test_create_user_mau_limit_reached_active_admin(self) -> None: ) # Set monthly active users to the limit - self.store.get_monthly_active_count = AsyncMock( - return_value=self.hs.config.server.max_mau_value + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1871,8 +1871,8 @@ def test_create_user_mau_limit_reached_passive_admin(self) -> None: handler = self.hs.get_registration_handler() # Set monthly active users to the limit - self.store.get_monthly_active_count = AsyncMock( - return_value=self.hs.config.server.max_mau_value + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 4c69d224b81a..6c04e6c56cc2 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -50,7 +50,7 @@ async def check_username( ) handler = self.hs.get_registration_handler() - handler.check_username = check_username # type: ignore[method-assign] + handler.check_username = check_username # type: ignore[assignment] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e9f495e20671..ac19f3c6daef 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1346,7 +1346,7 @@ async def post_json( return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[method-assign] + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py index 481db9a687c3..d5b0640e7aec 100644 --- a/tests/rest/client/test_account_data.py +++ b/tests/rest/client/test_account_data.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock +from unittest.mock import Mock from synapse.rest import admin from synapse.rest.client import account_data, login, room from tests import unittest +from tests.test_utils import make_awaitable class AccountDataTestCase(unittest.HomeserverTestCase): @@ -31,7 +32,7 @@ def test_on_account_data_updated_callback(self) -> None: """Tests that the on_account_data_updated module callback is called correctly when a user's account data changes. """ - mocked_callback = AsyncMock(return_value=None) + mocked_callback = Mock(return_value=make_awaitable(None)) self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( mocked_callback ) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 141e0f57a33b..54df2a252c01 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -45,7 +45,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() # type: ignore[method-assign] + hs.get_federation_handler = Mock() # type: ignore[assignment] return hs diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 90a8df147c7c..a2d5d340be35 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -65,14 +65,14 @@ def test_add_filter_for_other_user(self) -> None: def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] + self.hs.is_mine = lambda target_user: False # type: ignore[assignment] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine # type: ignore[method-assign] + self.hs.is_mine = _is_mine # type: ignore[assignment] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a2a65895647f..ffbc13bb8df3 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -169,8 +169,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "account": {"per_second": 10000, "burst_count": 10000}, - }, - "experimental_features": {"msc4041_enabled": True}, + } } ) def test_POST_ratelimiting_per_address(self) -> None: @@ -190,15 +189,12 @@ def test_POST_ratelimiting_per_address(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertLess(retry_after_ms, 6000) - assert retry_header - self.assertLessEqual(int(retry_header[0]), 6) + self.assertTrue(retry_after_ms < 6000) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -221,8 +217,7 @@ def test_POST_ratelimiting_per_address(self) -> None: # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, - }, - "experimental_features": {"msc4041_enabled": True}, + } } ) def test_POST_ratelimiting_per_account(self) -> None: @@ -239,15 +234,12 @@ def test_POST_ratelimiting_per_account(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertLess(retry_after_ms, 6000) - assert retry_header - self.assertLessEqual(int(retry_header[0]), 6) + self.assertTrue(retry_after_ms < 6000) self.reactor.advance(retry_after_ms / 1000.0) @@ -270,8 +262,7 @@ def test_POST_ratelimiting_per_account(self) -> None: # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - }, - "experimental_features": {"msc4041_enabled": True}, + } } ) def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: @@ -288,15 +279,12 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertLess(retry_after_ms, 6000) - assert retry_header - self.assertLessEqual(int(retry_header[0]), 6) + self.assertTrue(retry_after_ms < 6000) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -581,9 +569,8 @@ def test_spam_checker_deny(self) -> None: body, ) self.assertEqual(channel.code, 403, channel.result) - self.assertLessEqual( - {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}.items(), - channel.json_body.items(), + self.assertDictContainsSubset( + {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body ) diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index 41ceb3db51a4..700f6587a007 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.server import HomeServer from synapse.util import Clock +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase @@ -44,7 +45,7 @@ def prepare( def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = AsyncMock(return_value={}) + fed_transport_client.send_transaction = simple_async_mock({}) return self.setup_test_homeserver( federation_transport_client=fed_transport_client, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 66b387cea37e..e12098102b96 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -23,6 +23,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -35,7 +36,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.presence_handler = Mock(spec=PresenceHandler) - self.presence_handler.set_state = AsyncMock(return_value=None) + self.presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index c33393dc284b..b228dba8613d 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -75,7 +75,7 @@ def test_POST_appservice_registration_valid(self) -> None: self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertLessEqual(det_data.items(), channel.json_body.items()) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" @@ -136,7 +136,7 @@ def test_POST_user_valid(self) -> None: "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertLessEqual(det_data.items(), channel.json_body.items()) + self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: @@ -157,7 +157,7 @@ def test_POST_guest_registration(self) -> None: det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEqual(channel.code, 200, msg=channel.result) - self.assertLessEqual(det_data.items(), channel.json_body.items()) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False @@ -267,7 +267,7 @@ def test_POST_registration_requires_token(self) -> None: "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertLessEqual(det_data.items(), channel.json_body.items()) + self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 res = self.get_success( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 61773fb28c32..9bfe913e451e 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest.mock import AsyncMock, patch +from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -28,6 +28,7 @@ from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event from tests.unittest import override_config @@ -263,8 +264,7 @@ def test_ignore_invalid_room(self) -> None: # Disable the validation to pretend this came over federation. with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new_callable=AsyncMock, - return_value=None, + new=lambda self, event: make_awaitable(None), ): # Generate a various relations from a different room. self.get_success( @@ -570,7 +570,7 @@ def test_edit_reply(self) -> None: ) self.assertEqual(200, channel.code, channel.json_body) event_result = channel.json_body - self.assertLessEqual(original_body.items(), event_result["content"].items()) + self.assertDictContainsSubset(original_body, event_result["content"]) # also check /context, which returns the *edited* event channel = self.make_request( @@ -587,14 +587,14 @@ def test_edit_reply(self) -> None: (context_result, "/context"), ): # The reference metadata should still be intact. - self.assertLessEqual( + self.assertDictContainsSubset( { "m.relates_to": { "event_id": self.parent_id, "rel_type": "m.reference", } - }.items(), - result_event_dict["content"].items(), + }, + result_event_dict["content"], desc, ) @@ -1300,8 +1300,7 @@ def test_nested_thread(self) -> None: # not an event the Client-Server API will allow.. with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new_callable=AsyncMock, - return_value=None, + new=lambda self, event: make_awaitable(None), ): # Create a sub-thread off the thread, which is not allowed. self._send_relation( @@ -1372,11 +1371,9 @@ def test_thread_edit_latest_event(self) -> None: latest_event_in_thread = thread_summary["latest_event"] # The latest event in the thread should have the edit appear under the # bundled aggregations. - self.assertLessEqual( - {"event_id": edit_event_id, "sender": "@alice:test"}.items(), - latest_event_in_thread["unsigned"]["m.relations"][ - RelationTypes.REPLACE - ].items(), + self.assertDictContainsSubset( + {"event_id": edit_event_id, "sender": "@alice:test"}, + latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE], ) def test_aggregation_get_event_for_annotation(self) -> None: @@ -1639,9 +1636,9 @@ def test_redact_relation_thread(self) -> None: ################################################## self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertLessEqual( - {"count": 3, "current_user_participated": True}.items(), - relations[RelationTypes.THREAD].items(), + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], ) # The latest event is the last sent event. self.assertEqual( @@ -1660,9 +1657,9 @@ def test_redact_relation_thread(self) -> None: # The thread should still exist, but the latest event should be updated. self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertLessEqual( - {"count": 2, "current_user_participated": True}.items(), - relations[RelationTypes.THREAD].items(), + self.assertDictContainsSubset( + {"count": 2, "current_user_participated": True}, + relations[RelationTypes.THREAD], ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1679,9 +1676,9 @@ def test_redact_relation_thread(self) -> None: # Nothing should have changed (except the thread count). self.assertEqual(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertLessEqual( - {"count": 1, "current_user_participated": True}.items(), - relations[RelationTypes.THREAD].items(), + self.assertDictContainsSubset( + {"count": 1, "current_user_participated": True}, + relations[RelationTypes.THREAD], ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1776,12 +1773,12 @@ def test_redact_parent_thread(self) -> None: event_ids = self._get_related_events() relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) - self.assertLessEqual( + self.assertDictContainsSubset( { "count": 1, "current_user_participated": True, - }.items(), - relations[RelationTypes.THREAD].items(), + }, + relations[RelationTypes.THREAD], ) self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 47c1d38ad7dd..88e579dc393f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -20,7 +20,7 @@ import json from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from unittest.mock import AsyncMock, Mock, call, patch +from unittest.mock import Mock, call, patch from urllib import parse as urlparse from parameterized import param, parameterized @@ -52,6 +52,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event from tests.unittest import override_config @@ -68,15 +69,15 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: "red", ) - self.hs.get_federation_handler = Mock() # type: ignore[method-assign] - self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock( - return_value=None + self.hs.get_federation_handler = Mock() # type: ignore[assignment] + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) ) async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[method-assign] + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] return self.hs @@ -2374,7 +2375,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=AsyncMock()) + return self.setup_test_homeserver(federation_client=Mock()) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") @@ -2384,7 +2385,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined] + self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -2412,7 +2413,7 @@ def test_fallback(self) -> None: # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), - {}, + make_awaitable({}), ) search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -3412,17 +3413,17 @@ def test_threepid_invite_spamcheck_deprecated(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] - self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] - return_value=None, + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = AsyncMock(return_value=True, spec=lambda *x: None) + mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( mock ) @@ -3450,7 +3451,7 @@ def test_threepid_invite_spamcheck_deprecated(self) -> None: # Now change the return value of the callback to deny any invite and test that # we can't send the invite. - mock.return_value = False + mock.return_value = make_awaitable(False) channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3476,18 +3477,18 @@ def test_threepid_invite_spamcheck(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] - self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] - return_value=None, + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = AsyncMock( - return_value=synapse.module_api.NOT_SPAM, + mock = Mock( + return_value=make_awaitable(synapse.module_api.NOT_SPAM), spec=lambda *x: None, ) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( @@ -3518,7 +3519,7 @@ def test_threepid_invite_spamcheck(self) -> None: # Now change the return value of the callback to deny any invite and test that # we can't send the invite. We pick an arbitrary error code to be able to check # that the same code has been returned - mock.return_value = Codes.CONSENT_NOT_GIVEN + mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3537,7 +3538,7 @@ def test_threepid_invite_spamcheck(self) -> None: make_invite_mock.assert_called_once() # Run variant with `Tuple[Codes, dict]`. - mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"}) + mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 9aecf88e4160..8d2cdf875150 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -84,7 +84,7 @@ def test_invite(self) -> None: def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( # type: ignore[method-assign] + identity_handler.lookup_3pid = Mock( # type: ignore[assignment] side_effect=AssertionError("This should not get called") ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 57eb713b150a..e5ba5a970639 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -13,7 +13,7 @@ # limitations under the License. import threading from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -33,6 +33,7 @@ from synapse.util.frozenutils import unfreeze from tests import unittest +from tests.test_utils import make_awaitable if TYPE_CHECKING: from synapse.module_api import ModuleApi @@ -117,7 +118,7 @@ async def approve_all_signature_checking( async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: pass - hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] return hs @@ -476,7 +477,7 @@ async def test_fn( def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" - on_new_event = AsyncMock(return_value=None) + on_new_event = Mock(make_awaitable(None)) self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append( on_new_event ) @@ -579,7 +580,7 @@ def test_on_profile_update(self) -> None: avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = AsyncMock(return_value=None) + m = Mock(return_value=make_awaitable(None)) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -640,7 +641,7 @@ def test_on_profile_update_admin(self) -> None: avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = AsyncMock(return_value=None) + m = Mock(return_value=make_awaitable(None)) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -681,7 +682,7 @@ def test_on_user_deactivation_status_changed(self) -> None: correctly when processing a user's deactivation. """ # Register a mocked callback. - deactivation_mock = AsyncMock(return_value=None) + deactivation_mock = Mock(return_value=make_awaitable(None)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append( deactivation_mock, @@ -689,7 +690,7 @@ def test_on_user_deactivation_status_changed(self) -> None: # Also register a mocked callback for profile updates, to check that the # deactivation code calls it in a way that let modules know the user is being # deactivated. - profile_mock = AsyncMock(return_value=None) + profile_mock = Mock(return_value=make_awaitable(None)) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( profile_mock, ) @@ -739,7 +740,7 @@ def test_on_user_deactivation_status_changed_admin(self) -> None: well as a reactivation. """ # Register a mock callback. - m = AsyncMock(return_value=None) + m = Mock(return_value=make_awaitable(None)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) @@ -793,7 +794,7 @@ def test_check_can_deactivate_user(self) -> None: correctly when processing a user's deactivation. """ # Register a mocked callback. - deactivation_mock = AsyncMock(return_value=False) + deactivation_mock = Mock(return_value=make_awaitable(False)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -839,7 +840,7 @@ def test_check_can_deactivate_user_admin(self) -> None: correctly when processing a user's deactivation triggered by a server admin. """ # Register a mocked callback. - deactivation_mock = AsyncMock(return_value=False) + deactivation_mock = Mock(return_value=make_awaitable(False)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -878,7 +879,7 @@ def test_check_can_shutdown_room(self) -> None: correctly when processing an admin's shutdown room request. """ # Register a mocked callback. - shutdown_mock = AsyncMock(return_value=False) + shutdown_mock = Mock(return_value=make_awaitable(False)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_shutdown_room_callbacks.append( shutdown_mock, @@ -914,7 +915,7 @@ def test_on_threepid_bind(self) -> None: associating a 3PID to an account. """ # Register a mocked callback. - threepid_bind_mock = AsyncMock(return_value=None) + threepid_bind_mock = Mock(return_value=make_awaitable(None)) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) @@ -956,9 +957,11 @@ def test_on_add_and_remove_user_third_party_identifier(self) -> None: just before associating and removing a 3PID to/from an account. """ # Pretend to be a Synapse module and register both callbacks as mocks. - on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None) - on_remove_user_third_party_identifier_callback_mock = AsyncMock( - return_value=None + on_add_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) ) self.hs.get_module_api().register_third_party_rules_callbacks( on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock, @@ -1018,8 +1021,8 @@ def test_on_remove_user_third_party_identifier_is_called_on_deactivate( when a user is deactivated and their third-party ID associations are deleted. """ # Pretend to be a Synapse module and register both callbacks as mocks. - on_remove_user_third_party_identifier_callback_mock = AsyncMock( - return_value=None + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) ) self.hs.get_module_api().register_third_party_rules_callbacks( on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 951a3cbc43e9..d8dc56261ac1 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -14,7 +14,7 @@ from http import HTTPStatus from typing import Any, Generator, Tuple, cast -from unittest.mock import AsyncMock, Mock, call +from unittest.mock import Mock, call from twisted.internet import defer, reactor as _reactor @@ -24,6 +24,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import MockClock reactor = cast(ISynapseReactor, _reactor) @@ -52,7 +53,7 @@ def setUp(self) -> None: def test_executes_given_function( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = AsyncMock(return_value=self.mock_http_response) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" ) @@ -63,7 +64,7 @@ def test_executes_given_function( def test_deduplicates_based_on_key( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = AsyncMock(return_value=self.mock_http_response) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute_request( self.mock_request, @@ -167,7 +168,7 @@ def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": @defer.inlineCallbacks def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: - cb = AsyncMock(return_value=self.mock_http_response) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" ) diff --git a/tests/server.py b/tests/server.py index 08633fe640f4..ff03d2886476 100644 --- a/tests/server.py +++ b/tests/server.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import hashlib -import ipaddress import json import logging import os @@ -46,7 +45,7 @@ from typing_extensions import ParamSpec from zope.interface import implementer -from twisted.internet import address, tcp, threads, udp +from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError @@ -568,8 +567,6 @@ def connectTCP( conn = super().connectTCP( host, port, factory, timeout=timeout, bindAddress=None ) - if self.lookups and host in self.lookups: - validate_connector(conn, self.lookups[host]) callback = self._tcp_callbacks.get((host, port)) if callback: @@ -602,55 +599,6 @@ def advance(self, amount: float) -> None: super().advance(0) -def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: - """Try to validate the obtained connector as it would happen when - synapse is running and the conection will be established. - - This method will raise a useful exception when necessary, else it will - just do nothing. - - This is in order to help catch quirks related to reactor.connectTCP, - since when called directly, the connector's destination will be of type - IPv4Address, with the hostname as the literal host that was given (which - could be an IPv6-only host or an IPv6 literal). - - But when called from reactor.connectTCP *through* e.g. an Endpoint, the - connector's destination will contain the specific IP address with the - correct network stack class. - - Note that testing code paths that use connectTCP directly should not be - affected by this check, unless they specifically add a test with a - matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of - type ThreadedMemoryReactorClock. - For an example of implementing such tests, see test/handlers/send_email.py. - """ - destination = connector.getDestination() - - # We use address.IPv{4,6}Address to check what the reactor thinks it is - # is sending but check for validity with ipaddress.IPv{4,6}Address - # because they fail with IPs on the wrong network stack. - cls_mapping = { - address.IPv4Address: ipaddress.IPv4Address, - address.IPv6Address: ipaddress.IPv6Address, - } - - cls = cls_mapping.get(destination.__class__) - - if cls is not None: - try: - cls(expected_ip) - except Exception as exc: - raise ValueError( - "Invalid IP type and resolution for %s. Expected %s to be %s" - % (destination, expected_ip, cls.__name__) - ) from exc - else: - raise ValueError( - "Unknown address type %s for %s" - % (destination.__class__.__name__, destination) - ) - - class ThreadPool: """ Threadless thread pool. @@ -722,7 +670,7 @@ def runInteraction( **kwargs, ) - pool.runWithConnection = runWithConnection # type: ignore[method-assign] + pool.runWithConnection = runWithConnection # type: ignore[assignment] pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 17f428bfc5e5..d2bfa53eda49 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -29,6 +29,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -68,22 +69,24 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: assert isinstance(rlsn, ResourceLimitsServerNotices) self._rlsn = rlsn - self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000) - self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[method-assign] - return_value=Mock() + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(1000) + ) + self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment] + return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = ( - AsyncMock(return_value="!something:localhost") + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( + return_value=make_awaitable("!something:localhost") ) - self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock( - return_value="!something:localhost" + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( + return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign] - self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign] + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self) -> None: @@ -100,14 +103,14 @@ def test_maybe_send_server_notice_to_user_flag_off(self) -> None: def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] - return_value={"123": mock_event} + self._rlsn._store.get_events = Mock( # type: ignore[assignment] + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,16 +125,16 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> No """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None, + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] - return_value={"123": mock_event} + self._rlsn._store.get_events = Mock( # type: ignore[assignment] + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -142,8 +145,8 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None, + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -155,8 +158,8 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -168,10 +171,12 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) + ) + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(None) ) - self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -184,8 +189,8 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None, + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -199,8 +204,8 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None, + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -218,22 +223,22 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] - return_value=None, + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) - self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[method-assign] - return_value=(True, []) + self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] + return_value=make_awaitable((True, [])) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] - return_value={"123": mock_event} + self._rlsn._store.get_events = Mock( # type: ignore[assignment] + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -279,9 +284,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self) -> None: - self.store.get_monthly_active_count = AsyncMock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) - self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -320,7 +327,7 @@ def test_no_invite_without_notice(self) -> None: hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. """ - m = AsyncMock(return_value=None) + m = Mock(return_value=make_awaitable(None)) self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m user_id = self.register_user("user", "password") diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 650b4941bab6..f541f1d6be1e 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -132,7 +132,6 @@ def test_timeout_lock(self) -> None: # We simulate the process getting stuck by cancelling the looping call # that keeps the lock active. - assert lock._looping_call lock._looping_call.stop() # Wait for the lock to timeout. @@ -404,7 +403,6 @@ def test_timeout_lock(self) -> None: # We simulate the process getting stuck by cancelling the looping call # that keeps the lock active. - assert lock._looping_call lock._looping_call.stop() # Wait for the lock to timeout. diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cbce26a725c8..71302facd14d 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -15,7 +15,7 @@ import os import tempfile from typing import List, cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import yaml @@ -35,6 +35,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): @@ -338,7 +339,7 @@ def test_get_oldest_unsent_txn(self) -> None: # we aren't testing store._base stuff here, so mock this out # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[method-assign] + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) self.get_success(self._insert_txn(service.id, 10, events)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index abf7d0564d81..a4a823a25242 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from unittest.mock import AsyncMock, Mock + +from unittest.mock import Mock import yaml @@ -32,6 +32,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable, simple_async_mock from tests.unittest import override_config @@ -330,28 +331,6 @@ async def update_short(progress: JsonDict, count: int) -> int: self.update_handler.side_effect = update_short self.get_success(self.updates.do_next_background_update(False)) - def test_failed_update_logs_exception_details(self) -> None: - needle = "RUH ROH RAGGY" - - def failing_update(progress: JsonDict, count: int) -> int: - raise Exception(needle) - - self.update_handler.side_effect = failing_update - self.update_handler.reset_mock() - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - values={"update_name": "test_update", "progress_json": "{}"}, - ) - ) - - with self.assertLogs(level=logging.ERROR) as logs: - # Expect a back-to-back RuntimeError to be raised - self.get_failure(self.updates.run_background_updates(False), RuntimeError) - - self.assertTrue(any(needle in log for log in logs.output), logs.output) - class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -369,8 +348,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock out the AsyncContextManager class MockCM: - __aenter__ = AsyncMock(return_value=None) - __aexit__ = AsyncMock(return_value=None) + __aenter__ = simple_async_mock(return_value=None) + __aexit__ = simple_async_mock(return_value=None) self._update_ctx_manager = MockCM @@ -384,9 +363,9 @@ class MockCM: # Register the callbacks with more mocks self.hs.get_module_api().register_background_update_controller_callbacks( on_update=self._on_update, - min_batch_size=AsyncMock(return_value=self._default_batch_size), - default_batch_size=AsyncMock( - return_value=self._default_batch_size, + min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), + default_batch_size=Mock( + return_value=make_awaitable(self._default_batch_size), ), ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 6b9692c48625..209d68b40ba9 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Any, Dict -from unittest.mock import AsyncMock +from unittest.mock import Mock from parameterized import parameterized @@ -30,6 +30,7 @@ from tests import unittest from tests.server import make_request +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -65,15 +66,15 @@ def test_insert_new_client_ip(self) -> None: ) r = result[(user_id, device_id)] - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, - }.items(), - r.items(), + }, + r, ) def test_insert_new_client_ip_none_device_id(self) -> None: @@ -442,7 +443,9 @@ def test_adding_monthly_active_user_when_full(self) -> None: lots_of_users = 100 user_id = "@user:server" - self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) + self.store.get_monthly_active_count = Mock( + return_value=make_awaitable(lots_of_users) + ) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id" @@ -526,15 +529,15 @@ def test_devices_last_seen_bg_update(self) -> None: ) r = result[(user_id, device_id)] - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user_id, "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, - }.items(), - r.items(), + }, + r, ) # Register the background update to run again. @@ -561,15 +564,15 @@ def test_devices_last_seen_bg_update(self) -> None: ) r = result[(user_id, device_id)] - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }.items(), - r.items(), + }, + r, ) def test_old_user_ips_pruned(self) -> None: @@ -640,15 +643,15 @@ def test_old_user_ips_pruned(self) -> None: ) r = result2[(user_id, device_id)] - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }.items(), - r.items(), + }, + r, ) def test_invalid_user_agents_are_ignored(self) -> None: @@ -777,13 +780,13 @@ def _runtest( self.store.get_last_client_ip_by_device(self.user_id, device_id) ) r = result[(self.user_id, device_id)] - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": self.user_id, "device_id": device_id, "ip": expected_ip, "user_agent": "Mozzila pizza", "last_seen": 123456100, - }.items(), - r.items(), + }, + r, ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 58ab41cf2670..f03807c8f9d4 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -58,13 +58,13 @@ def test_store_new_device(self) -> None: res = self.get_success(self.store.get_device("user_id", "device_id")) assert res is not None - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": "user_id", "device_id": "device_id", "display_name": "display_name", - }.items(), - res.items(), + }, + res, ) def test_get_devices_by_user(self) -> None: @@ -80,21 +80,21 @@ def test_get_devices_by_user(self) -> None: res = self.get_success(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": "user_id", "device_id": "device1", "display_name": "display_name 1", - }.items(), - res["device1"].items(), + }, + res["device1"], ) - self.assertLessEqual( + self.assertDictContainsSubset( { "user_id": "user_id", "device_id": "device2", "display_name": "display_name 2", - }.items(), - res["device2"].items(), + }, + res["device2"], ) def test_count_devices_by_users(self) -> None: diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 2033377b5247..5fde3b9c7879 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ def test_key_without_device_name(self) -> None: self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertLessEqual(json.items(), dev.items()) + self.assertDictContainsSubset(json, dev) def test_reupload_key(self) -> None: now = 1470174257070 @@ -71,12 +71,8 @@ def test_get_key_with_device_name(self) -> None: self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertLessEqual( - { - "key": "value", - "unsigned": {"device_display_name": "display_name"}, - }.items(), - dev.items(), + self.assertDictContainsSubset( + {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) def test_multiple_devices(self) -> None: diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 49366440ce10..282773837907 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List -from unittest.mock import AsyncMock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -21,6 +21,7 @@ from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -252,7 +253,7 @@ def test_populate_monthly_users_is_guest(self) -> None: ) self.get_success(d) - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -260,22 +261,24 @@ def test_populate_monthly_users_is_guest(self) -> None: self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] - self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(None) + ) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] - self.store.user_last_seen_monthly_active = AsyncMock( - return_value=self.hs.get_clock().time_msec() + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] + self.store.user_last_seen_monthly_active = Mock( + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) d = self.store.populate_monthly_active_users("user_id") @@ -356,7 +359,7 @@ def test_track_monthly_users_without_cap(self) -> None: @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self) -> None: - self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1e27f2c275a1..71ec74eadc91 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -44,13 +44,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_get_room(self) -> None: res = self.get_success(self.store.get_room(self.room.to_string())) assert res is not None - self.assertLessEqual( + self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "is_public": True, - }.items(), - res.items(), + }, + res, ) def test_get_room_unknown_room(self) -> None: @@ -59,13 +59,13 @@ def test_get_room_unknown_room(self) -> None: def test_get_room_with_stats(self) -> None: res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) assert res is not None - self.assertLessEqual( + self.assertDictContainsSubset( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "public": True, - }.items(), - res.items(), + }, + res, ) def test_get_room_with_stats_unknown_room(self) -> None: diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 29be8cdbd0e8..0e3fc2a77f05 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -22,6 +22,7 @@ PartialStateEventsTracker, ) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -123,17 +124,16 @@ def test_cancellation(self) -> None: class PartialCurrentStateTrackerTestCase(TestCase): def setUp(self) -> None: self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) - self.mock_store.is_partial_state_room = mock.AsyncMock() self.tracker = PartialCurrentStateTracker(self.mock_store) def test_does_not_block_for_full_state_rooms(self) -> None: - self.mock_store.is_partial_state_room.return_value = False + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_blocks_for_partial_room_state(self) -> None: - self.mock_store.is_partial_state_room.return_value = True + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d = ensureDeferred(self.tracker.await_full_state("room_id")) @@ -156,7 +156,7 @@ async def is_partial_state_room(room_id: str) -> bool: self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_cancellation(self) -> None: - self.mock_store.is_partial_state_room.return_value = True + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) d1 = ensureDeferred(self.tracker.await_full_state("room_id")) self.assertNoResult(d1) diff --git a/tests/test_federation.py b/tests/test_federation.py index f8ade6da3852..6d15ac759785 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Collection, List, Optional, Union -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -31,6 +31,7 @@ from synapse.util.retryutils import NotRetryingDestination from tests import unittest +from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -80,7 +81,7 @@ async def _check_event_auth( ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] self.client = self.hs.get_federation_client() async def _check_sigs_and_hash_for_pulled_events_and_fetch( @@ -190,12 +191,12 @@ def query_user_devices( # Register the mock on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.hs.get_datastores().main - store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) + store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -240,24 +241,27 @@ def test_cross_signing_keys_retry(self) -> None: # Register mock device list retrieval on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { + federation_client.query_user_devices = Mock( # type: ignore[assignment] + return_value=make_awaitable( + { "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) # Resync the device list. diff --git a/tests/test_state.py b/tests/test_state.py index 9c8679cc1dc9..eded38c7669d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -714,7 +714,7 @@ def test_resolve_state_conflict( store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[method-assign] + self.dummy_store.get_events = store.get_events # type: ignore[assignment] context: EventContext context = yield self._get_context( @@ -773,7 +773,7 @@ def test_standard_depth_conflict( store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[method-assign] + self.dummy_store.get_events = store.get_events # type: ignore[assignment] context: EventContext context = yield self._get_context( diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 64a49488c654..52424aa08713 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -85,9 +85,7 @@ def test_ui_auth(self) -> None: } } self.assertIsInstance(channel.json_body["params"], dict) - self.assertLessEqual( - channel.json_body["params"].items(), expected_params.items() - ) + self.assertDictContainsSubset(channel.json_body["params"], expected_params) # We have to complete the dummy auth stage before completing the terms stage request_data = { diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index fa731426cda5..c8cc841d9540 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -18,8 +18,10 @@ import json import sys import warnings +from asyncio import Future from binascii import unhexlify -from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar +from unittest.mock import Mock import attr import zope.interface @@ -55,12 +57,27 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") +def make_awaitable(result: TV) -> Awaitable[TV]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future: Future[TV] = Future() + future.set_result(result) + return future + + def setup_awaitable_errors() -> Callable[[], None]: """ Convert warnings from a non-awaited coroutines into errors. """ warnings.simplefilter("error", RuntimeWarning) + # unraisablehook was added in Python 3.8. + if not hasattr(sys, "unraisablehook"): + return lambda: None + # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook @@ -83,6 +100,18 @@ def cleanup() -> None: return cleanup +def simple_async_mock( + return_value: Optional[TV] = None, raises: Optional[Exception] = None +) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + # Type ignore: it does not fully implement IResponse, but is good enough for tests @zope.interface.implementer(IResponse) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/tests/unittest.py b/tests/unittest.py index 5d3640d8ac24..b0721e060c40 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase): servlets: List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. hijack_auth: Whether to hijack auth to return the user specified - in user_id. + in user_id. """ hijack_auth: ClassVar[bool] = True @@ -395,9 +395,9 @@ async def get_requester(*args: Any, **kwargs: Any) -> Requester: ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] - self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 05983ed434b1..91cac9822af4 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -60,9 +60,11 @@ def check_called_first(res: int) -> int: observer1.addBoth(check_called_first) # store the results - results: List[Optional[int]] = [None, None] + results: List[Optional[ObservableDeferred[int]]] = [None, None] - def check_val(res: int, idx: int) -> int: + def check_val( + res: ObservableDeferred[int], idx: int + ) -> ObservableDeferred[int]: results[idx] = res return res @@ -91,14 +93,14 @@ def check_called_first(res: int) -> int: observer1.addBoth(check_called_first) # store the results - results: List[Optional[Failure]] = [None, None] + results: List[Optional[ObservableDeferred[str]]] = [None, None] - def check_failure(res: Failure, idx: int) -> None: + def check_val(res: ObservableDeferred[str], idx: int) -> None: results[idx] = res return None - observer1.addErrback(check_failure, 0) - observer2.addErrback(check_failure, 1) + observer1.addErrback(check_val, 0) + observer2.addErrback(check_val, 1) try: raise Exception("gah!") diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 8665aeb50c0f..3a97559bf04b 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -22,11 +22,10 @@ from synapse.util import Clock from synapse.util.task_scheduler import TaskScheduler -from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.unittest import HomeserverTestCase, override_config +from tests import unittest -class TestTaskScheduler(HomeserverTestCase): +class TestTaskScheduler(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.task_scheduler = hs.get_task_scheduler() self.task_scheduler.register_action(self._test_task, "_test_task") @@ -35,7 +34,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.task_scheduler.register_action(self._resumable_task, "_resumable_task") async def _test_task( - self, task: ScheduledTask + self, task: ScheduledTask, first_launch: bool ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # This test task will copy the parameters to the result result = None @@ -78,7 +77,7 @@ def test_schedule_task(self) -> None: self.assertIsNone(task) async def _sleeping_task( - self, task: ScheduledTask + self, task: ScheduledTask, first_launch: bool ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # Sleep for a second await deferLater(self.reactor, 1, lambda: None) @@ -86,18 +85,24 @@ async def _sleeping_task( def test_schedule_lot_of_tasks(self) -> None: """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior.""" + timestamp = self.clock.time_msec() + 30 * 1000 task_ids = [] for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1): task_ids.append( self.get_success( self.task_scheduler.schedule_task( "_sleeping_task", + timestamp=timestamp, params={"val": i}, ) ) ) - # This is to give the time to the active tasks to finish + # The timestamp being 30s after now the task should been executed + # after the first scheduling loop is run + self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + + # This is to give the time to the sleeping tasks to finish self.reactor.advance(1) # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one @@ -115,11 +120,10 @@ def test_schedule_lot_of_tasks(self) -> None: ) scheduled_tasks = [ - t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE + t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED ] self.assertEquals(len(scheduled_tasks), 1) - # We need to wait for the next run of the scheduler loop self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) self.reactor.advance(1) @@ -134,7 +138,7 @@ def test_schedule_lot_of_tasks(self) -> None: ) async def _raising_task( - self, task: ScheduledTask + self, task: ScheduledTask, first_launch: bool ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: raise Exception("raising") @@ -142,13 +146,15 @@ def test_schedule_raising_task(self) -> None: """Schedule a task raising an exception and check it runs to failure and report exception content.""" task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task")) + self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.FAILED) self.assertEqual(task.error, "raising") async def _resumable_task( - self, task: ScheduledTask + self, task: ScheduledTask, first_launch: bool ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: if task.result and "in_progress" in task.result: return TaskStatus.COMPLETE, {"success": True}, None @@ -163,6 +169,8 @@ def test_schedule_resumable_task(self) -> None: """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart.""" task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task")) + self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.ACTIVE) @@ -176,33 +184,3 @@ def test_schedule_resumable_task(self) -> None: self.assertEqual(task.status, TaskStatus.COMPLETE) assert task.result is not None self.assertTrue(task.result.get("success")) - - -class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.task_scheduler = hs.get_task_scheduler() - self.task_scheduler.register_action(self._test_task, "_test_task") - - async def _test_task( - self, task: ScheduledTask - ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: - return (TaskStatus.COMPLETE, None, None) - - @override_config({"run_background_tasks_on": "worker1"}) - def test_schedule_task(self) -> None: - """Check that a task scheduled to run now is launch right away on the background worker.""" - bg_worker_hs = self.make_worker_hs( - "synapse.app.generic_worker", - extra_config={"worker_name": "worker1"}, - ) - bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task") - - task_id = self.get_success( - self.task_scheduler.schedule_task( - "_test_task", - ) - ) - - task = self.get_success(self.task_scheduler.get_task(task_id)) - assert task is not None - self.assertEqual(task.status, TaskStatus.COMPLETE)