diff --git a/.ci/scripts/calculate_jobs.py b/.ci/scripts/calculate_jobs.py index 50e11e6504ff..661887e20985 100755 --- a/.ci/scripts/calculate_jobs.py +++ b/.ci/scripts/calculate_jobs.py @@ -47,10 +47,9 @@ def set_output(key: str, value: str): "database": "sqlite", "extras": "all", } - for version in ("3.9", "3.10", "3.11") + for version in ("3.9", "3.10", "3.11", "3.12.0-rc.1") ) - trial_postgres_tests = [ { "python-version": "3.8", diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index ec6391cf8fd4..7b839f59c1d9 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -57,8 +57,8 @@ jobs: # `pip install matrix-synapse[all]` as closely as possible. - run: poetry update --no-dev - run: poetry run pip list > after.txt && (diff -u before.txt after.txt || true) - - name: Remove warn_unused_ignores from mypy config - run: sed '/warn_unused_ignores = True/d' -i mypy.ini + - name: Remove unhelpful options from mypy config + run: sed -e '/warn_unused_ignores = True/d' -e '/warn_redundant_casts = True/d' -i mypy.ini - run: poetry run mypy trial: needs: check_repo diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 67ccc03f6e2d..7d629a4ed097 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -54,8 +54,8 @@ jobs: poetry remove twisted poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }} poetry install --no-interaction --extras "all test" - - name: Remove warn_unused_ignores from mypy config - run: sed '/warn_unused_ignores = True/d' -i mypy.ini + - name: Remove unhelpful options from mypy config + run: sed -e '/warn_unused_ignores = True/d' -e '/warn_redundant_casts = True/d' -i mypy.ini - run: poetry run mypy trial: diff --git a/Cargo.lock b/Cargo.lock index 61c0f1bd0402..4d60f8dcb62a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.72" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" [[package]] name = "arc-swap" @@ -291,9 +291,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" dependencies = [ "aho-corasick", "memchr", @@ -303,9 +303,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" dependencies = [ "aho-corasick", "memchr", @@ -314,9 +314,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "ryu" @@ -332,18 +332,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.184" 
+version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c911f4b04d7385c9035407a4eff5903bf4fe270fa046fda448b69e797f4fff0" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.184" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1df27f5b29406ada06609b2e2f77fb34f6dbb104a457a671cc31dbed237e09e" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", @@ -352,9 +352,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.104" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076066c5f1078eac5b722a31827a8832fe108bed65dfa75e233c89f8206e976c" +checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" dependencies = [ "itoa", "ryu", diff --git a/changelog.d/15816.feature b/changelog.d/15816.feature new file mode 100644 index 000000000000..9248dd6792cc --- /dev/null +++ b/changelog.d/15816.feature @@ -0,0 +1 @@ +Add configuration setting for CAS protocol version. Contributed by AurĂ©lien Grimpard. diff --git a/changelog.d/16008.doc b/changelog.d/16008.doc new file mode 100644 index 000000000000..1142224951c4 --- /dev/null +++ b/changelog.d/16008.doc @@ -0,0 +1 @@ +Update links to the matrix.org blog. diff --git a/changelog.d/16099.misc b/changelog.d/16099.misc new file mode 100644 index 000000000000..d0e281136668 --- /dev/null +++ b/changelog.d/16099.misc @@ -0,0 +1 @@ +Prepare unit tests for Python 3.12. diff --git a/changelog.d/16113.feature b/changelog.d/16113.feature new file mode 100644 index 000000000000..69fdaaebacc1 --- /dev/null +++ b/changelog.d/16113.feature @@ -0,0 +1 @@ +Suppress notifications from message edits per [MSC3958](https://github.com/matrix-org/matrix-spec-proposals/pull/3958). diff --git a/changelog.d/16121.misc b/changelog.d/16121.misc new file mode 100644 index 000000000000..f325d2a31dbd --- /dev/null +++ b/changelog.d/16121.misc @@ -0,0 +1 @@ +Attempt to fix the twisted trunk job. diff --git a/changelog.d/16135.misc b/changelog.d/16135.misc new file mode 100644 index 000000000000..cba8733d0201 --- /dev/null +++ b/changelog.d/16135.misc @@ -0,0 +1 @@ +Describe which rate limiter was hit in logs. diff --git a/changelog.d/16136.feature b/changelog.d/16136.feature new file mode 100644 index 000000000000..4ad98a88c309 --- /dev/null +++ b/changelog.d/16136.feature @@ -0,0 +1 @@ +Return a `Retry-After` with `M_LIMIT_EXCEEDED` error responses. diff --git a/changelog.d/16155.bugfix b/changelog.d/16155.bugfix new file mode 100644 index 000000000000..8b2dc0400672 --- /dev/null +++ b/changelog.d/16155.bugfix @@ -0,0 +1 @@ +Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch). diff --git a/changelog.d/16168.doc b/changelog.d/16168.doc new file mode 100644 index 000000000000..7dadb047bef9 --- /dev/null +++ b/changelog.d/16168.doc @@ -0,0 +1 @@ +Document which admin APIs are disabled when experimental [MSC3861](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) support is enabled. diff --git a/changelog.d/16170.misc b/changelog.d/16170.misc new file mode 100644 index 000000000000..c950b5436705 --- /dev/null +++ b/changelog.d/16170.misc @@ -0,0 +1 @@ +Simplify presence code when using workers. 
diff --git a/changelog.d/16171.misc b/changelog.d/16171.misc new file mode 100644 index 000000000000..4d709cb56e19 --- /dev/null +++ b/changelog.d/16171.misc @@ -0,0 +1 @@ +Track per-device information in the presence code. diff --git a/changelog.d/16172.misc b/changelog.d/16172.misc new file mode 100644 index 000000000000..4d709cb56e19 --- /dev/null +++ b/changelog.d/16172.misc @@ -0,0 +1 @@ +Track per-device information in the presence code. diff --git a/changelog.d/16175.misc b/changelog.d/16175.misc new file mode 100644 index 000000000000..308fbc225923 --- /dev/null +++ b/changelog.d/16175.misc @@ -0,0 +1 @@ +Stop using the `event_txn_id` table. diff --git a/changelog.d/16178.doc b/changelog.d/16178.doc new file mode 100644 index 000000000000..ea21e19240bd --- /dev/null +++ b/changelog.d/16178.doc @@ -0,0 +1 @@ +Document `exclude_rooms_from_sync` configuration option. diff --git a/changelog.d/16179.misc b/changelog.d/16179.misc new file mode 100644 index 000000000000..8d04954ab97a --- /dev/null +++ b/changelog.d/16179.misc @@ -0,0 +1 @@ +Use `AsyncMock` instead of custom code. diff --git a/changelog.d/16180.misc b/changelog.d/16180.misc new file mode 100644 index 000000000000..8d04954ab97a --- /dev/null +++ b/changelog.d/16180.misc @@ -0,0 +1 @@ +Use `AsyncMock` instead of custom code. diff --git a/changelog.d/16183.misc b/changelog.d/16183.misc new file mode 100644 index 000000000000..305d5baa6e03 --- /dev/null +++ b/changelog.d/16183.misc @@ -0,0 +1 @@ +Improve error reporting of invalid data passed to `/_matrix/key/v2/query`. diff --git a/changelog.d/16184.misc b/changelog.d/16184.misc new file mode 100644 index 000000000000..3c0baddfe1c6 --- /dev/null +++ b/changelog.d/16184.misc @@ -0,0 +1 @@ +Task scheduler: add a replication notification for new tasks so that they are launched as soon as possible. diff --git a/changelog.d/16185.bugfix b/changelog.d/16185.bugfix new file mode 100644 index 000000000000..e62c9c7a0d8b --- /dev/null +++ b/changelog.d/16185.bugfix @@ -0,0 +1 @@ +Fix a spec compliance issue where requests to the `/publicRooms` federation API would specify `include_all_networks` as a string. diff --git a/changelog.d/16186.misc b/changelog.d/16186.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/16186.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16187.misc b/changelog.d/16187.misc new file mode 100644 index 000000000000..989147274a70 --- /dev/null +++ b/changelog.d/16187.misc @@ -0,0 +1 @@ +Bump black version to 23.7.0. diff --git a/changelog.d/16188.misc b/changelog.d/16188.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/16188.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16201.misc b/changelog.d/16201.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/16201.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16205.bugfix b/changelog.d/16205.bugfix new file mode 100644 index 000000000000..97ac92a14889 --- /dev/null +++ b/changelog.d/16205.bugfix @@ -0,0 +1 @@ +Fix an inaccurate error message when attempting to ban or unban a user with the same or higher power level, by splitting the conditional statements. Contributed by @leviosacz. \ No newline at end of file diff --git a/changelog.d/16210.bugfix b/changelog.d/16210.bugfix new file mode 100644 index 000000000000..39c35a1fe144 --- /dev/null +++ b/changelog.d/16210.bugfix @@ -0,0 +1 @@ +Fix a rare bug that broke looping calls, which could lead to e.g. linearly increasing memory usage.
Introduced in v1.90.0. diff --git a/changelog.d/16211.bugfix b/changelog.d/16211.bugfix new file mode 100644 index 000000000000..ab1816386c1c --- /dev/null +++ b/changelog.d/16211.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where uploading images would fail if we could not generate thumbnails for them. diff --git a/changelog.d/16212.misc b/changelog.d/16212.misc new file mode 100644 index 000000000000..19cf9b102d37 --- /dev/null +++ b/changelog.d/16212.misc @@ -0,0 +1 @@ +Log the details of background update failures. diff --git a/changelog.d/16213.misc b/changelog.d/16213.misc new file mode 100644 index 000000000000..8c14f5fd51ad --- /dev/null +++ b/changelog.d/16213.misc @@ -0,0 +1 @@ +Fix the latest-deps CI job. diff --git a/changelog.d/16220.bugfix b/changelog.d/16220.bugfix new file mode 100644 index 000000000000..dcfac6bda110 --- /dev/null +++ b/changelog.d/16220.bugfix @@ -0,0 +1 @@ +Fix a performance regression introduced in Synapse 1.91.0 where event persistence would cause excessive CPU usage over time. diff --git a/changelog.d/16241.misc b/changelog.d/16241.misc new file mode 100644 index 000000000000..0fc5f34c5ce1 --- /dev/null +++ b/changelog.d/16241.misc @@ -0,0 +1 @@ +Cache device resync requests over replication. diff --git a/docs/admin_api/account_validity.md b/docs/admin_api/account_validity.md index 87d8f7150e8c..dfa69e515bfc 100644 --- a/docs/admin_api/account_validity.md +++ b/docs/admin_api/account_validity.md @@ -1,5 +1,7 @@ # Account validity API +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + This API allows a server administrator to manage the validity of an account. To use it, you must enable the account validity feature (under `account_validity`) in Synapse's configuration. diff --git a/docs/admin_api/register_api.md b/docs/admin_api/register_api.md index dd2830f3a18a..e9a235ada5e2 100644 --- a/docs/admin_api/register_api.md +++ b/docs/admin_api/register_api.md @@ -1,5 +1,7 @@ # Shared-Secret Registration +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + This API allows for the creation of users in an administrative and non-interactive way. This is generally used for bootstrapping a Synapse instance with administrator accounts. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 99abfea3a0fb..8032e05497ad 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -218,7 +218,7 @@ The following parameters should be set in the URL: - `name` - Is optional and filters to only return users with user ID localparts **or** displaynames that contain this value. - `guests` - string representing a bool - Is optional and if `false` will **exclude** guest users. - Defaults to `true` to include guest users. + Defaults to `true` to include guest users. This parameter is not supported when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) - `admins` - Optional flag to filter admins. If `true`, only admins are queried. If `false`, admins are excluded from the query. When the flag is absent (the default), **both** admins and non-admins are included in the search results. - `deactivated` - string representing a bool - Is optional and if `true` will **include** deactivated users. @@ -390,6 +390,8 @@ The following actions are **NOT** performed. The list may be incomplete. 
## Reset password +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + Changes the password of another user. This will automatically log the user out of all their devices. The api is: @@ -413,6 +415,8 @@ The parameter `logout_devices` is optional and defaults to `true`. ## Get whether a user is a server administrator or not +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + The api is: ``` @@ -430,6 +434,8 @@ A response body like the following is returned: ## Change whether a user is a server administrator or not +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + Note that you cannot demote yourself. The api is: @@ -723,6 +729,8 @@ delete largest/smallest or newest/oldest files first. ## Login as a user +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + Get an access token that can be used to authenticate as that user. Useful for when admins wish to do actions on behalf of a user. diff --git a/docs/development/releases.md b/docs/development/releases.md index c9a8c6994597..6e83c81e27eb 100644 --- a/docs/development/releases.md +++ b/docs/development/releases.md @@ -12,7 +12,7 @@ Note that this schedule might be modified depending on the availability of the Synapse team, e.g. releases may be skipped to avoid holidays. Release announcements can be found in the -[release category of the Matrix blog](https://matrix.org/blog/category/releases). +[release category of the Matrix blog](https://matrix.org/category/releases). ## Bugfix releases @@ -34,4 +34,4 @@ be held to be released together. In some cases, a pre-disclosure of a security release will be issued as a notice to Synapse operators that there is an upcoming security release. These can be -found in the [security category of the Matrix blog](https://matrix.org/blog/category/security). +found in the [security category of the Matrix blog](https://matrix.org/category/security). diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md index c5130859d426..ba95bcf03801 100644 --- a/docs/usage/administration/admin_api/registration_tokens.md +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -1,5 +1,7 @@ # Registration Tokens +**Note:** This API is disabled when MSC3861 is enabled. [See #15582](https://github.com/matrix-org/synapse/pull/15582) + This API allows you to manage tokens which can be used to authenticate registration requests, as proposed in [MSC3231](https://github.com/matrix-org/matrix-doc/blob/main/proposals/3231-token-authenticated-registration.md) diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 743c51d76adf..0b1725816e3d 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3420,6 +3420,7 @@ Has the following sub-options: to style the login flow according to the identity provider in question. See the [spec](https://spec.matrix.org/latest/) for possible options here. * `server_url`: The URL of the CAS authorization endpoint. +* `protocol_version`: The CAS protocol version. Defaults to none; version 3 is required if you want to use `required_attributes`.
* `displayname_attribute`: The attribute of the CAS response to use as the display name. If no name is given here, no displayname will be set. * `required_attributes`: It is possible to configure Synapse to only allow logins if CAS attributes @@ -3433,6 +3434,7 @@ Example configuration: cas_config: enabled: true server_url: "https://cas-server.com" + protocol_version: 3 displayname_attribute: name required_attributes: userGroup: "staff" @@ -3865,6 +3867,19 @@ Example configuration: ```yaml forget_rooms_on_leave: false ``` +--- +### `exclude_rooms_from_sync` +A list of rooms to exclude from sync responses. This is useful for server +administrators wishing to group users into a room without these users being able +to see it from their client. + +By default, no room is excluded. + +Example configuration: +```yaml +exclude_rooms_from_sync: + - "!foo:example.com" +``` --- ## Opentracing diff --git a/mypy.ini b/mypy.ini index 311a951aa8de..fb5f44c939d8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -87,18 +87,9 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True -[mypy-service_identity.*] -ignore_missing_imports = True - [mypy-srvlookup.*] ignore_missing_imports = True # https://github.com/twisted/treq/pull/366 [mypy-treq.*] ignore_missing_imports = True - -[mypy-incremental.*] -ignore_missing_imports = True - -[mypy-setuptools_rust.*] -ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index e62c10da9f76..0688d5d92e3c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "alabaster" @@ -148,36 +148,33 @@ lxml = ["lxml"] [[package]] name = "black" -version = "23.3.0" +version = "23.7.0" description = "The uncompromising code formatter."
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, - {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, - {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, - {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, - {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, - {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, - {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, - {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, - {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, - {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, - {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, - {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, - {file = "black-23.3.0-py3-none-any.whl", hash = 
"sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, - {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, ] [package.dependencies] @@ -544,13 +541,13 @@ files = [ [[package]] name = "elementpath" -version = "4.1.0" +version = "4.1.5" description = "XPath 1.0/2.0/3.0/3.1 parsers and 
selectors for ElementTree and lxml" optional = true python-versions = ">=3.7" files = [ - {file = "elementpath-4.1.0-py3-none-any.whl", hash = "sha256:2b1b524223d70fd6dd63a36b9bc32e4919c96a272c2d1454094c4d85086bc6f8"}, - {file = "elementpath-4.1.0.tar.gz", hash = "sha256:dbd7eba3cf0b3b4934f627ba24851a3e0798ef2bc9104555a4cd831f2e6e8e14"}, + {file = "elementpath-4.1.5-py3-none-any.whl", hash = "sha256:2ac1a2fb31eb22bbbf817f8cf6752f844513216263f0e3892c8e79782fe4bb55"}, + {file = "elementpath-4.1.5.tar.gz", hash = "sha256:c2d6dc524b29ef751ecfc416b0627668119d8812441c555d7471da41d4bacb8d"}, ] [package.extras] @@ -1448,43 +1445,43 @@ files = [ [[package]] name = "mypy" -version = "1.0.1" +version = "1.4.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.7" files = [ - {file = "mypy-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:71a808334d3f41ef011faa5a5cd8153606df5fc0b56de5b2e89566c8093a0c9a"}, - {file = "mypy-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:920169f0184215eef19294fa86ea49ffd4635dedfdea2b57e45cb4ee85d5ccaf"}, - {file = "mypy-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27a0f74a298769d9fdc8498fcb4f2beb86f0564bcdb1a37b58cbbe78e55cf8c0"}, - {file = "mypy-1.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:65b122a993d9c81ea0bfde7689b3365318a88bde952e4dfa1b3a8b4ac05d168b"}, - {file = "mypy-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5deb252fd42a77add936b463033a59b8e48eb2eaec2976d76b6878d031933fe4"}, - {file = "mypy-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2013226d17f20468f34feddd6aae4635a55f79626549099354ce641bc7d40262"}, - {file = "mypy-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:48525aec92b47baed9b3380371ab8ab6e63a5aab317347dfe9e55e02aaad22e8"}, - {file = "mypy-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c96b8a0c019fe29040d520d9257d8c8f122a7343a8307bf8d6d4a43f5c5bfcc8"}, - {file = "mypy-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:448de661536d270ce04f2d7dddaa49b2fdba6e3bd8a83212164d4174ff43aa65"}, - {file = "mypy-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:d42a98e76070a365a1d1c220fcac8aa4ada12ae0db679cb4d910fabefc88b994"}, - {file = "mypy-1.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64f48c6176e243ad015e995de05af7f22bbe370dbb5b32bd6988438ec873919"}, - {file = "mypy-1.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fdd63e4f50e3538617887e9aee91855368d9fc1dea30da743837b0df7373bc4"}, - {file = "mypy-1.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:dbeb24514c4acbc78d205f85dd0e800f34062efcc1f4a4857c57e4b4b8712bff"}, - {file = "mypy-1.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a2948c40a7dd46c1c33765718936669dc1f628f134013b02ff5ac6c7ef6942bf"}, - {file = "mypy-1.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5bc8d6bd3b274dd3846597855d96d38d947aedba18776aa998a8d46fabdaed76"}, - {file = "mypy-1.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:17455cda53eeee0a4adb6371a21dd3dbf465897de82843751cf822605d152c8c"}, - {file = "mypy-1.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e831662208055b006eef68392a768ff83596035ffd6d846786578ba1714ba8f6"}, - {file = "mypy-1.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e60d0b09f62ae97a94605c3f73fd952395286cf3e3b9e7b97f60b01ddfbbda88"}, - {file = "mypy-1.0.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:0af4f0e20706aadf4e6f8f8dc5ab739089146b83fd53cb4a7e0e850ef3de0bb6"}, - {file = "mypy-1.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24189f23dc66f83b839bd1cce2dfc356020dfc9a8bae03978477b15be61b062e"}, - {file = "mypy-1.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93a85495fb13dc484251b4c1fd7a5ac370cd0d812bbfc3b39c1bafefe95275d5"}, - {file = "mypy-1.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f546ac34093c6ce33f6278f7c88f0f147a4849386d3bf3ae193702f4fe31407"}, - {file = "mypy-1.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c6c2ccb7af7154673c591189c3687b013122c5a891bb5651eca3db8e6c6c55bd"}, - {file = "mypy-1.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b5a824b58c7c822c51bc66308e759243c32631896743f030daf449fe3677f3"}, - {file = "mypy-1.0.1-py3-none-any.whl", hash = "sha256:eda5c8b9949ed411ff752b9a01adda31afe7eae1e53e946dbdf9db23865e66c4"}, - {file = "mypy-1.0.1.tar.gz", hash = "sha256:28cea5a6392bb43d266782983b5a4216c25544cd7d80be681a155ddcdafd152d"}, -] - -[package.dependencies] -mypy-extensions = ">=0.4.3" + {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, + {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, + {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, + {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, + {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, + {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, + {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, + {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, + {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, + {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, + {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, + {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, + {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, + {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, + {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, + {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, + {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, + {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, + {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, + {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, + {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, + {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, + {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] @@ -1505,17 +1502,17 @@ files = [ [[package]] name = "mypy-zope" -version = "0.9.1" +version = "1.0.0" description = "Plugin for mypy to support zope interfaces" optional = false python-versions = "*" files = [ - {file = "mypy-zope-0.9.1.tar.gz", hash = "sha256:4c87dbc71fec35f6533746ecdf9d400cd9281338d71c16b5676bb5ed00a97ca2"}, - {file = "mypy_zope-0.9.1-py3-none-any.whl", hash = "sha256:733d4399affe9e61e332ce9c4049418d6775c39b473e4b9f409d51c207c1b71a"}, + {file = "mypy-zope-1.0.0.tar.gz", hash = "sha256:be815c2fcb5333aa87e8ec682029ad3214142fe2a05ea383f9ff2d77c98008b7"}, + {file = "mypy_zope-1.0.0-py3-none-any.whl", hash = "sha256:9732e9b2198f2aec3343b38a51905ff49d44dc9e39e8e8bc6fc490b232388209"}, ] [package.dependencies] -mypy = ">=1.0.0,<1.1.0" +mypy = ">=1.0.0,<1.5.0" "zope.interface" = "*" "zope.schema" = "*" @@ -1610,13 +1607,13 @@ files = [ [[package]] name = "phonenumbers" -version = "8.13.18" +version = "8.13.19" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." 
optional = false python-versions = "*" files = [ - {file = "phonenumbers-8.13.18-py2.py3-none-any.whl", hash = "sha256:3d802739a22592e4127139349937753dee9b6a20bdd5d56847cd885bdc766b1f"}, - {file = "phonenumbers-8.13.18.tar.gz", hash = "sha256:b360c756252805d44b447b5bca6d250cf6bd6c69b6f0f4258f3bfe5ab81bef69"}, + {file = "phonenumbers-8.13.19-py2.py3-none-any.whl", hash = "sha256:ba542f20f6dc83be8f127f240f9b5b7e7c1dec42aceff1879400d4dc0c781d81"}, + {file = "phonenumbers-8.13.19.tar.gz", hash = "sha256:38180247697240ccedd74dec4bfbdbc22bb108b9c5f991f270ca3e41395e6f96"}, ] [[package]] @@ -1744,24 +1741,22 @@ twisted = ["twisted"] [[package]] name = "psycopg2" -version = "2.9.6" +version = "2.9.7" description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = true python-versions = ">=3.6" files = [ - {file = "psycopg2-2.9.6-cp310-cp310-win32.whl", hash = "sha256:f7a7a5ee78ba7dc74265ba69e010ae89dae635eea0e97b055fb641a01a31d2b1"}, - {file = "psycopg2-2.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:f75001a1cbbe523e00b0ef896a5a1ada2da93ccd752b7636db5a99bc57c44494"}, - {file = "psycopg2-2.9.6-cp311-cp311-win32.whl", hash = "sha256:53f4ad0a3988f983e9b49a5d9765d663bbe84f508ed655affdb810af9d0972ad"}, - {file = "psycopg2-2.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:b81fcb9ecfc584f661b71c889edeae70bae30d3ef74fa0ca388ecda50b1222b7"}, - {file = "psycopg2-2.9.6-cp36-cp36m-win32.whl", hash = "sha256:11aca705ec888e4f4cea97289a0bf0f22a067a32614f6ef64fcf7b8bfbc53744"}, - {file = "psycopg2-2.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:36c941a767341d11549c0fbdbb2bf5be2eda4caf87f65dfcd7d146828bd27f39"}, - {file = "psycopg2-2.9.6-cp37-cp37m-win32.whl", hash = "sha256:869776630c04f335d4124f120b7fb377fe44b0a7645ab3c34b4ba42516951889"}, - {file = "psycopg2-2.9.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a8ad4a47f42aa6aec8d061fdae21eaed8d864d4bb0f0cade5ad32ca16fcd6258"}, - {file = "psycopg2-2.9.6-cp38-cp38-win32.whl", hash = "sha256:2362ee4d07ac85ff0ad93e22c693d0f37ff63e28f0615a16b6635a645f4b9214"}, - {file = "psycopg2-2.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:d24ead3716a7d093b90b27b3d73459fe8cd90fd7065cf43b3c40966221d8c394"}, - {file = "psycopg2-2.9.6-cp39-cp39-win32.whl", hash = "sha256:1861a53a6a0fd248e42ea37c957d36950da00266378746588eab4f4b5649e95f"}, - {file = "psycopg2-2.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:ded2faa2e6dfb430af7713d87ab4abbfc764d8d7fb73eafe96a24155f906ebf5"}, - {file = "psycopg2-2.9.6.tar.gz", hash = "sha256:f15158418fd826831b28585e2ab48ed8df2d0d98f502a2b4fe619e7d5ca29011"}, + {file = "psycopg2-2.9.7-cp310-cp310-win32.whl", hash = "sha256:1a6a2d609bce44f78af4556bea0c62a5e7f05c23e5ea9c599e07678995609084"}, + {file = "psycopg2-2.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:b22ed9c66da2589a664e0f1ca2465c29b75aaab36fa209d4fb916025fb9119e5"}, + {file = "psycopg2-2.9.7-cp311-cp311-win32.whl", hash = "sha256:44d93a0109dfdf22fe399b419bcd7fa589d86895d3931b01fb321d74dadc68f1"}, + {file = "psycopg2-2.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:91e81a8333a0037babfc9fe6d11e997a9d4dac0f38c43074886b0d9dead94fe9"}, + {file = "psycopg2-2.9.7-cp37-cp37m-win32.whl", hash = "sha256:d1210fcf99aae6f728812d1d2240afc1dc44b9e6cba526a06fb8134f969957c2"}, + {file = "psycopg2-2.9.7-cp37-cp37m-win_amd64.whl", hash = "sha256:e9b04cbef584310a1ac0f0d55bb623ca3244c87c51187645432e342de9ae81a8"}, + {file = "psycopg2-2.9.7-cp38-cp38-win32.whl", hash = "sha256:d5c5297e2fbc8068d4255f1e606bfc9291f06f91ec31b2a0d4c536210ac5c0a2"}, + {file = "psycopg2-2.9.7-cp38-cp38-win_amd64.whl", hash = 
"sha256:8275abf628c6dc7ec834ea63f6f3846bf33518907a2b9b693d41fd063767a866"}, + {file = "psycopg2-2.9.7-cp39-cp39-win32.whl", hash = "sha256:c7949770cafbd2f12cecc97dea410c514368908a103acf519f2a346134caa4d5"}, + {file = "psycopg2-2.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:b6bd7d9d3a7a63faae6edf365f0ed0e9b0a1aaf1da3ca146e6b043fb3eb5d723"}, + {file = "psycopg2-2.9.7.tar.gz", hash = "sha256:f00cc35bd7119f1fed17b85bd1007855194dde2cbd8de01ab8ebb17487440ad8"}, ] [[package]] @@ -2082,6 +2077,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2089,8 +2085,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = 
"PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2107,6 +2110,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2114,6 +2118,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2329,28 +2334,28 @@ files = [ [[package]] name = "ruff" -version = "0.0.277" +version = "0.0.286" description = "An extremely fast Python linter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.277-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:3250b24333ef419b7a232080d9724ccc4d2da1dbbe4ce85c4caa2290d83200f8"}, - {file = "ruff-0.0.277-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3e60605e07482183ba1c1b7237eca827bd6cbd3535fe8a4ede28cbe2a323cb97"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7baa97c3d7186e5ed4d5d4f6834d759a27e56cf7d5874b98c507335f0ad5aadb"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:74e4b206cb24f2e98a615f87dbe0bde18105217cbcc8eb785bb05a644855ba50"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:479864a3ccd8a6a20a37a6e7577bdc2406868ee80b1e65605478ad3b8eb2ba0b"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:468bfb0a7567443cec3d03cf408d6f562b52f30c3c29df19927f1e0e13a40cd7"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f32ec416c24542ca2f9cc8c8b65b84560530d338aaf247a4a78e74b99cd476b4"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:14a7b2f00f149c5a295f188a643ac25226ff8a4d08f7a62b1d4b0a1dc9f9b85c"}, - {file = "ruff-0.0.277-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9879f59f763cc5628aa01c31ad256a0f4dc61a29355c7315b83c2a5aac932b5"}, - {file = "ruff-0.0.277-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f612e0a14b3d145d90eb6ead990064e22f6f27281d847237560b4e10bf2251f3"}, - {file = "ruff-0.0.277-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:323b674c98078be9aaded5b8b51c0d9c424486566fb6ec18439b496ce79e5998"}, - {file = "ruff-0.0.277-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3a43fbe026ca1a2a8c45aa0d600a0116bec4dfa6f8bf0c3b871ecda51ef2b5dd"}, - {file = "ruff-0.0.277-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:734165ea8feb81b0d53e3bf523adc2413fdb76f1264cde99555161dd5a725522"}, - {file = "ruff-0.0.277-py3-none-win32.whl", hash = "sha256:88d0f2afb2e0c26ac1120e7061ddda2a566196ec4007bd66d558f13b374b9efc"}, - {file = "ruff-0.0.277-py3-none-win_amd64.whl", hash = "sha256:6fe81732f788894a00f6ade1fe69e996cc9e485b7c35b0f53fb00284397284b2"}, - {file = "ruff-0.0.277-py3-none-win_arm64.whl", hash = "sha256:2d4444c60f2e705c14cd802b55cd2b561d25bf4311702c463a002392d3116b22"}, - {file = "ruff-0.0.277.tar.gz", hash = "sha256:2dab13cdedbf3af6d4427c07f47143746b6b95d9e4a254ac369a0edb9280a0d2"}, + {file = "ruff-0.0.286-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:8e22cb557e7395893490e7f9cfea1073d19a5b1dd337f44fd81359b2767da4e9"}, + {file = "ruff-0.0.286-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:68ed8c99c883ae79a9133cb1a86d7130feee0397fdf5ba385abf2d53e178d3fa"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8301f0bb4ec1a5b29cfaf15b83565136c47abefb771603241af9d6038f8981e8"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:acc4598f810bbc465ce0ed84417ac687e392c993a84c7eaf3abf97638701c1ec"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88c8e358b445eb66d47164fa38541cfcc267847d1e7a92dd186dddb1a0a9a17f"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = 
"sha256:0433683d0c5dbcf6162a4beb2356e820a593243f1fa714072fec15e2e4f4c939"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddb61a0c4454cbe4623f4a07fef03c5ae921fe04fede8d15c6e36703c0a73b07"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:47549c7c0be24c8ae9f2bce6f1c49fbafea83bca80142d118306f08ec7414041"}, + {file = "ruff-0.0.286-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:559aa793149ac23dc4310f94f2c83209eedb16908a0343663be19bec42233d25"}, + {file = "ruff-0.0.286-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d73cfb1c3352e7aa0ce6fb2321f36fa1d4a2c48d2ceac694cb03611ddf0e4db6"}, + {file = "ruff-0.0.286-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:3dad93b1f973c6d1db4b6a5da8690c5625a3fa32bdf38e543a6936e634b83dc3"}, + {file = "ruff-0.0.286-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26afc0851f4fc3738afcf30f5f8b8612a31ac3455cb76e611deea80f5c0bf3ce"}, + {file = "ruff-0.0.286-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9b6b116d1c4000de1b9bf027131dbc3b8a70507788f794c6b09509d28952c512"}, + {file = "ruff-0.0.286-py3-none-win32.whl", hash = "sha256:556e965ac07c1e8c1c2d759ac512e526ecff62c00fde1a046acb088d3cbc1a6c"}, + {file = "ruff-0.0.286-py3-none-win_amd64.whl", hash = "sha256:5d295c758961376c84aaa92d16e643d110be32add7465e197bfdaec5a431a107"}, + {file = "ruff-0.0.286-py3-none-win_arm64.whl", hash = "sha256:1d6142d53ab7f164204b3133d053c4958d4d11ec3a39abf23a40b13b0784e3f0"}, + {file = "ruff-0.0.286.tar.gz", hash = "sha256:f1e9d169cce81a384a26ee5bb8c919fe9ae88255f39a1a69fd1ebab233a85ed2"}, ] [[package]] @@ -2385,13 +2390,13 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "1.29.2" +version = "1.30.0" description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = "*" files = [ - {file = "sentry-sdk-1.29.2.tar.gz", hash = "sha256:a99ee105384788c3f228726a88baf515fe7b5f1d2d0f215a03d194369f158df7"}, - {file = "sentry_sdk-1.29.2-py2.py3-none-any.whl", hash = "sha256:3e17215d8006612e2df02b0e73115eb8376c37e3f586d8436fa41644e605074d"}, + {file = "sentry-sdk-1.30.0.tar.gz", hash = "sha256:7dc873b87e1faf4d00614afd1058bfa1522942f33daef8a59f90de8ed75cd10c"}, + {file = "sentry_sdk-1.30.0-py2.py3-none-any.whl", hash = "sha256:2e53ad63f96bb9da6570ba2e755c267e529edcf58580a2c0d2a11ef26e1e678b"}, ] [package.dependencies] @@ -2414,6 +2419,7 @@ httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] +opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] @@ -2467,18 +2473,19 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "setuptools-rust" -version = "1.6.0" +version = "1.7.0" description = "Setuptools Rust extension plugin" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-rust-1.6.0.tar.gz", hash = 
"sha256:c86e734deac330597998bfbc08da45187e6b27837e23bd91eadb320732392262"}, - {file = "setuptools_rust-1.6.0-py3-none-any.whl", hash = "sha256:e28ae09fb7167c44ab34434eb49279307d611547cb56cb9789955cdb54a1aed9"}, + {file = "setuptools-rust-1.7.0.tar.gz", hash = "sha256:c7100999948235a38ae7e555fe199aa66c253dc384b125f5d85473bf81eae3a3"}, + {file = "setuptools_rust-1.7.0-py3-none-any.whl", hash = "sha256:071099885949132a2180d16abf907b60837e74b4085047ba7e9c0f5b365310c1"}, ] [package.dependencies] semantic-version = ">=2.8.2,<3" setuptools = ">=62.4" +tomli = {version = ">=1.2.1", markers = "python_version < \"3.11\""} typing-extensions = ">=3.7.4.3" [[package]] @@ -3002,13 +3009,13 @@ files = [ [[package]] name = "types-psycopg2" -version = "2.9.21.10" +version = "2.9.21.11" description = "Typing stubs for psycopg2" optional = false python-versions = "*" files = [ - {file = "types-psycopg2-2.9.21.10.tar.gz", hash = "sha256:c2600892312ae1c34e12f145749795d93dc4eac3ef7dbf8a9c1bfd45385e80d7"}, - {file = "types_psycopg2-2.9.21.10-py3-none-any.whl", hash = "sha256:918224a0731a3650832e46633e720703b5beef7693a064e777d9748654fcf5e5"}, + {file = "types-psycopg2-2.9.21.11.tar.gz", hash = "sha256:d5077eacf90e61db8c0b8eea2fdc9d4a97d7aaa16865fb4bd7034a7571520b4d"}, + {file = "types_psycopg2-2.9.21.11-py3-none-any.whl", hash = "sha256:7a323d7744bc8a882fb5a6f63448e903fc70d3dc0d6da9ec1f9c6c4dc10a7102"}, ] [[package]] @@ -3027,13 +3034,13 @@ cryptography = ">=35.0.0" [[package]] name = "types-pyyaml" -version = "6.0.12.10" +version = "6.0.12.11" description = "Typing stubs for PyYAML" optional = false python-versions = "*" files = [ - {file = "types-PyYAML-6.0.12.10.tar.gz", hash = "sha256:ebab3d0700b946553724ae6ca636ea932c1b0868701d4af121630e78d695fc97"}, - {file = "types_PyYAML-6.0.12.10-py3-none-any.whl", hash = "sha256:662fa444963eff9b68120d70cda1af5a5f2aa57900003c2006d7626450eaae5f"}, + {file = "types-PyYAML-6.0.12.11.tar.gz", hash = "sha256:7d340b19ca28cddfdba438ee638cd4084bde213e501a3978738543e27094775b"}, + {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, ] [[package]] @@ -3207,22 +3214,22 @@ files = [ [[package]] name = "xmlschema" -version = "2.2.2" +version = "2.4.0" description = "An XML Schema validator and decoder" optional = true python-versions = ">=3.7" files = [ - {file = "xmlschema-2.2.2-py3-none-any.whl", hash = "sha256:557f3632b54b6ff10576736bba62e43db84eb60f6465a83818576cd9ffcc1799"}, - {file = "xmlschema-2.2.2.tar.gz", hash = "sha256:0caa96668807b4b51c42a0fe2b6610752bc59f069615df3e34dcfffb962973fd"}, + {file = "xmlschema-2.4.0-py3-none-any.whl", hash = "sha256:dc87be0caaa61f42649899189aab2fd8e0d567f2cf548433ba7b79278d231a4a"}, + {file = "xmlschema-2.4.0.tar.gz", hash = "sha256:d74cd0c10866ac609e1ef94a5a69b018ad16e39077bc6393408b40c6babee793"}, ] [package.dependencies] -elementpath = ">=4.0.0,<5.0.0" +elementpath = ">=4.1.5,<5.0.0" [package.extras] -codegen = ["elementpath (>=4.0.0,<5.0.0)", "jinja2"] -dev = ["Sphinx", "coverage", "elementpath (>=4.0.0,<5.0.0)", "flake8", "jinja2", "lxml", "lxml-stubs", "memory-profiler", "mypy", "sphinx-rtd-theme", "tox"] -docs = ["Sphinx", "elementpath (>=4.0.0,<5.0.0)", "jinja2", "sphinx-rtd-theme"] +codegen = ["elementpath (>=4.1.5,<5.0.0)", "jinja2"] +dev = ["Sphinx", "coverage", "elementpath (>=4.1.5,<5.0.0)", "flake8", "jinja2", "lxml", "lxml-stubs", "memory-profiler", "mypy", "sphinx-rtd-theme", "tox"] +docs = ["Sphinx", "elementpath (>=4.1.5,<5.0.0)", "jinja2", 
"sphinx-rtd-theme"] [[package]] name = "zipp" @@ -3343,4 +3350,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "0a8c6605e7e1d0ac7188a5d02b47a029bfb0f917458b87cb40755911442383d8" +content-hash = "4a3a82becd89b91e76e2bc2f8ba72123f665c517d9b841d9a34cd01b83a1adc3" diff --git a/pyproject.toml b/pyproject.toml index 2a4ff1ea01c8..c1f95e945847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py37', 'py38', 'py39', 'py310'] +target-version = ['py38', 'py39', 'py310', 'py311'] # black ignores everything in .gitignore by default, see # https://black.readthedocs.io/en/stable/usage_and_configuration/file_collection_and_discovery.html#gitignore # Use `extend-exclude` if you want to exclude something in addition to this. @@ -306,10 +306,13 @@ all = [ ] [tool.poetry.dev-dependencies] -# We pin black so that our tests don't start failing on new releases. +# We pin development dependencies in poetry.lock so that our tests don't start +# failing on new releases. Keeping lower bounds loose here means that dependabot +# can bump versions without having to update the content-hash in the lockfile. +# This helps prevents merge conflicts when running a batch of dependabot updates. isort = ">=5.10.1" -black = ">=22.3.0" -ruff = "0.0.277" +black = ">=22.7.0" +ruff = "0.0.286" # Typechecking lxml-stubs = ">=0.4.0" diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 6e1eab2a3b29..14071105a05b 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -197,7 +197,6 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, - false, ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 00baceda91fa..59fd27665aee 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -228,7 +228,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ // We don't want to notify on edits *unless* the edit directly mentions a // user, which is handled above. 
PushRule { - rule_id: Cow::Borrowed("global/override/.org.matrix.msc3958.suppress_edits"), + rule_id: Cow::Borrowed("global/override/.m.rule.suppress_edits"), priority_class: 5, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventPropertyIs( EventPropertyIsCondition { diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 48e670478bf7..5b9bf9b26ae1 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -564,7 +564,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 829fb79d0e5b..8e91f506cc42 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -527,7 +527,6 @@ pub struct FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, - msc3958_suppress_edits_enabled: bool, } #[pymethods] @@ -539,7 +538,6 @@ impl FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, - msc3958_suppress_edits_enabled: bool, ) -> Self { Self { push_rules, @@ -547,7 +545,6 @@ impl FilteredPushRules { msc1767_enabled, msc3381_polls_enabled, msc3664_enabled, - msc3958_suppress_edits_enabled, } } @@ -584,12 +581,6 @@ impl FilteredPushRules { return false; } - if !self.msc3958_suppress_edits_enabled - && rule.rule_id == "global/override/.org.matrix.msc3958.suppress_edits" - { - return false; - } - true }) .map(|r| { diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index d573a37b9aff..1f432d4ecfbf 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -46,7 +46,6 @@ class FilteredPushRules: msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, - msc3958_suppress_edits_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... diff --git a/synapse/__init__.py b/synapse/__init__.py index 2f9c22a83352..4a9bbc4d57b7 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -21,9 +21,14 @@ import sys from typing import Any, Dict +from PIL import ImageFile + from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool +# Allow truncated JPEG images to be thumbnailed. +ImageFile.LOAD_TRUNCATED_IMAGES = True + # Check that we're not running on an unsupported Python version. 
# # Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 49242800b858..ab2b29cf1b49 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -482,7 +482,10 @@ def r( do_backward[0] = False if forward_rows or backward_rows: - headers = [column[0] for column in txn.description] + assert txn.description is not None + headers: Optional[List[str]] = [ + column[0] for column in txn.description + ] else: headers = None @@ -544,6 +547,7 @@ async def handle_search_table( def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() + assert txn.description is not None headers = [column[0] for column in txn.description] return headers, rows @@ -919,7 +923,8 @@ async def _setup_sent_transactions(self) -> Tuple[int, int, int]: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - headers: List[str] = [column[0] for column in txn.description] + assert txn.description is not None + headers = [column[0] for column in txn.description] ts_ind = headers.index("ts") diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 7ffd72c42cd4..fdb2955be82b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -16,6 +16,7 @@ """Contains exceptions and error codes.""" import logging +import math import typing from enum import Enum from http import HTTPStatus @@ -210,6 +211,11 @@ def __init__( def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) + @property + def debug_context(self) -> Optional[str]: + """Override this to add debugging context that shouldn't be sent to clients.""" + return None + class InvalidAPICallError(SynapseError): """You called an existing API endpoint, but fed that endpoint @@ -503,19 +509,31 @@ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled.""" + include_retry_after_header = False + def __init__( self, + limiter_name: str, code: int = 429, - msg: str = "Too Many Requests", retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super().__init__(code, msg, errcode) + headers = ( + {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + if self.include_retry_after_header and retry_after_ms is not None + else None + ) + super().__init__(code, "Too Many Requests", errcode, headers=headers) self.retry_after_ms = retry_after_ms + self.limiter_name = limiter_name def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) + @property + def debug_context(self) -> Optional[str]: + return self.limiter_name + class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 511790c7c5e4..887b214d64a3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -61,12 +61,16 @@ class Ratelimiter: """ def __init__( - self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + self, + store: DataStore, + clock: Clock, + cfg: RatelimitSettings, ): self.clock = clock - self.rate_hz = rate_hz - self.burst_count = 
burst_count + self.rate_hz = cfg.per_second + self.burst_count = cfg.burst_count self.store = store + self._limiter_name = cfg.key # An ordered dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -305,7 +309,8 @@ async def ratelimit( if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + limiter_name=self._limiter_name, + retry_after_ms=int(1000 * (time_allowed - time_now_s)), ) @@ -322,7 +327,9 @@ def __init__( # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, + clock=self.clock, + cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), ) self._rc_message = rc_message @@ -332,8 +339,7 @@ def __init__( self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=rc_admin_redaction.per_second, - burst_count=rc_admin_redaction.burst_count, + cfg=rc_admin_redaction, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1d268a1817cd..69a831812759 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -186,9 +186,9 @@ def parse_size(value: Union[str, int]) -> int: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: + elif isinstance(value, str): sizes = {"K": 1024, "M": 1024 * 1024} size = 1 suffix = value[-1] @@ -218,9 +218,9 @@ def parse_duration(value: Union[str, int]) -> int: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. 
""" - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: + elif isinstance(value, str): second = 1000 minute = 60 * second hour = 60 * minute diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 919f81a9b716..a70dfbf41f93 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -34,7 +34,7 @@ class AppServiceConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.app_service_config_files = config.get("app_service_config_files", []) if not isinstance(self.app_service_config_files, list) or not all( - type(x) is str for x in self.app_service_config_files + isinstance(x, str) for x in self.app_service_config_files ): raise ConfigError( "Expected '%s' to be a list of AS config files:" diff --git a/synapse/config/cas.py b/synapse/config/cas.py index c4e63e74118c..6e2d9addbf4c 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -18,7 +18,7 @@ from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict -from ._base import Config +from ._base import Config, ConfigError from ._util import validate_config @@ -41,6 +41,16 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: public_baseurl = self.root.server.public_baseurl self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_protocol_version = cas_config.get("protocol_version") + if ( + self.cas_protocol_version is not None + and self.cas_protocol_version not in [1, 2, 3] + ): + raise ConfigError( + "Unsupported CAS protocol version %s (only versions 1, 2, 3 are supported)" + % (self.cas_protocol_version,), + ("cas_config", "protocol_version"), + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") required_attributes = cas_config.get("required_attributes") or {} self.cas_required_attributes = _parsed_required_attributes_def( @@ -54,6 +64,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: else: self.cas_server_url = None self.cas_service_url = None + self.cas_protocol_version = None self.cas_displayname_attribute = None self.cas_required_attributes = [] diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 277ea4675b29..cabe0d4397cd 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -18,6 +18,7 @@ import attr import attr.validators +from synapse.api.errors import LimitExceededError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError from synapse.config._base import Config, RootConfig @@ -383,11 +384,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3959: Do not generate notifications for edits. - self.msc3958_supress_edit_notifs = experimental.get( - "msc3958_supress_edit_notifs", False - ) - # MSC3967: Do not require UIA when first uploading cross signing keys self.msc3967_enabled = experimental.get("msc3967_enabled", False) @@ -411,3 +407,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling + # + # This is a bit hacky, but the most reasonable way to *alway* include the + # headers. 
+ LimitExceededError.include_retry_after_header = experimental.get( + "msc4041_enabled", False + ) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index a5514e70a21d..4efbaeac0d7f 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import attr @@ -21,16 +21,47 @@ from ._base import Config +@attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitSettings: - def __init__( - self, - config: Dict[str, float], + key: str + per_second: float + burst_count: int + + @classmethod + def parse( + cls, + config: Dict[str, Any], + key: str, defaults: Optional[Dict[str, float]] = None, - ): + ) -> "RatelimitSettings": + """Parse config[key] as a new-style rate limiter config. + + The key may refer to a nested dictionary using a full stop (.) to separate + each nested key. For example, use the key "a.b.c" to parse the following: + + a: + b: + c: + per_second: 10 + burst_count: 200 + + If this lookup fails, we'll fall back to the defaults. + """ defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} - self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = int(config.get("burst_count", defaults["burst_count"])) + rl_config = config + for part in key.split("."): + rl_config = rl_config.get(part, {}) + + # By this point we should have hit the rate limiter parameters. + # We don't actually check this though! + rl_config = cast(Dict[str, float], rl_config) + + return cls( + key=key, + per_second=rl_config.get("per_second", defaults["per_second"]), + burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), + ) @attr.s(auto_attribs=True) @@ -49,15 +80,14 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: - self.rc_message = RatelimitSettings( - config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} + self.rc_message = RatelimitSettings.parse( + config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} ) else: - self.rc_message = RatelimitSettings( - { - "per_second": config.get("rc_messages_per_second", 0.2), - "burst_count": config.get("rc_message_burst_count", 10.0), - } + self.rc_message = RatelimitSettings( + key="rc_messages", + per_second=config.get("rc_messages_per_second", 0.2), + burst_count=config.get("rc_message_burst_count", 10.0), ) # Load the new-style federation config, if it exists. Otherwise, fall @@ -79,51 +109,59 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: } ) - self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) + self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) - self.rc_registration_token_validity = RatelimitSettings( - config.get("rc_registration_token_validity", {}), + self.rc_registration_token_validity = RatelimitSettings.parse( + config, + "rc_registration_token_validity", defaults={"per_second": 0.1, "burst_count": 5}, ) # It is reasonable to login with a bunch of devices at once (i.e. when # setting up an account), but it is *not* valid to continually be # logging into new devices.
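The traversal at the heart of `RatelimitSettings.parse` above is easy to miss in the diff: the key is split on full stops and walked through the nested config dict, bottoming out at `{}` (and therefore at the supplied defaults) as soon as any segment is missing. A minimal runnable sketch of that lookup follows; `lookup` and the sample config are invented for illustration, and, like the real code, nothing verifies that the walk actually lands on a dict of limiter parameters.

```python
from typing import Any, Dict


def lookup(config: Dict[str, Any], key: str) -> Dict[str, Any]:
    node: Dict[str, Any] = config
    for part in key.split("."):
        # A missing segment yields {}, so every later .get() falls through
        # to the defaults, mirroring RatelimitSettings.parse.
        node = node.get(part, {})
    return node


config = {"rc_login": {"address": {"per_second": 0.003, "burst_count": 5}}}
assert lookup(config, "rc_login.address") == {"per_second": 0.003, "burst_count": 5}
assert lookup(config, "rc_login.failed_attempts") == {}  # defaults will apply
```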
- rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings( - rc_login_config.get("address", {}), + self.rc_login_address = RatelimitSettings.parse( + config, + "rc_login.address", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_account = RatelimitSettings( - rc_login_config.get("account", {}), + self.rc_login_account = RatelimitSettings.parse( + config, + "rc_login.account", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_failed_attempts = RatelimitSettings( - rc_login_config.get("failed_attempts", {}) + self.rc_login_failed_attempts = RatelimitSettings.parse( + config, + "rc_login.failed_attempts", + {}, ) self.federation_rr_transactions_per_room_per_second = config.get( "federation_rr_transactions_per_room_per_second", 50 ) - rc_admin_redaction = config.get("rc_admin_redaction") self.rc_admin_redaction = None - if rc_admin_redaction: - self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) + if "rc_admin_redaction" in config: + self.rc_admin_redaction = RatelimitSettings.parse( + config, "rc_admin_redaction", {} + ) - self.rc_joins_local = RatelimitSettings( - config.get("rc_joins", {}).get("local", {}), + self.rc_joins_local = RatelimitSettings.parse( + config, + "rc_joins.local", defaults={"per_second": 0.1, "burst_count": 10}, ) - self.rc_joins_remote = RatelimitSettings( - config.get("rc_joins", {}).get("remote", {}), + self.rc_joins_remote = RatelimitSettings.parse( + config, + "rc_joins.remote", defaults={"per_second": 0.01, "burst_count": 10}, ) # Track the rate of joins to a given room. If there are too many, temporarily # prevent local joins and remote joins via this server. - self.rc_joins_per_room = RatelimitSettings( - config.get("rc_joins_per_room", {}), + self.rc_joins_per_room = RatelimitSettings.parse( + config, + "rc_joins_per_room", defaults={"per_second": 1, "burst_count": 10}, ) @@ -132,31 +170,37 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # * For requests received over federation this is keyed by the origin. # # Note that this isn't exposed in the configuration as it is obscure. 
- self.rc_key_requests = RatelimitSettings( - config.get("rc_key_requests", {}), + self.rc_key_requests = RatelimitSettings.parse( + config, + "rc_key_requests", defaults={"per_second": 20, "burst_count": 100}, ) - self.rc_3pid_validation = RatelimitSettings( - config.get("rc_3pid_validation") or {}, + self.rc_3pid_validation = RatelimitSettings.parse( + config, + "rc_3pid_validation", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_room = RatelimitSettings( - config.get("rc_invites", {}).get("per_room", {}), + self.rc_invites_per_room = RatelimitSettings.parse( + config, + "rc_invites.per_room", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_invites_per_user = RatelimitSettings( - config.get("rc_invites", {}).get("per_user", {}), + self.rc_invites_per_user = RatelimitSettings.parse( + config, + "rc_invites.per_user", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_issuer = RatelimitSettings( - config.get("rc_invites", {}).get("per_issuer", {}), + self.rc_invites_per_issuer = RatelimitSettings.parse( + config, + "rc_invites.per_issuer", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_third_party_invite = RatelimitSettings( - config.get("rc_third_party_invite", {}), + self.rc_third_party_invite = RatelimitSettings.parse( + config, + "rc_third_party_invite", defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 3a260a492bea..2ac9f8b309cf 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -669,12 +669,18 @@ def _is_membership_change_allowed( errcode=Codes.INSUFFICIENT_POWER, ) elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: + if user_level < ban_level: raise UnstableSpecAuthError( 403, "You don't have permission to ban", errcode=Codes.INSUFFICIENT_POWER, ) + elif user_level <= target_level: + raise UnstableSpecAuthError( + 403, + "You don't have permission to ban this user", + errcode=Codes.INSUFFICIENT_POWER, + ) elif room_version.knock_join_rule and Membership.KNOCK == membership: if join_rule != JoinRules.KNOCK and ( not room_version.knock_restricted_join_rule @@ -846,11 +852,11 @@ def _check_power_levels( "kick", "invite", }: - if type(v) is not int: + if type(v) is not int: # noqa: E721 raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - type(v) is int for v in v.values() + type(v) is int for v in v.values() # noqa: E721 ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 52acb219556f..53af423a5a98 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -702,7 +702,7 @@ def _copy_power_level_value_as_integer( :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. 
""" - if type(old_value) is int: + if type(old_value) is int: # noqa: E721 power_levels[key] = old_value return @@ -730,7 +730,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if type(value) is int: + if type(value) is int: # noqa: E721 if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 9278f1a1aa65..34625dd7a185 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -151,7 +151,7 @@ def _validate_retention(self, event: EventBase) -> None: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if type(min_lifetime) is not int: + if type(min_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -159,7 +159,7 @@ def _validate_retention(self, event: EventBase) -> None: ) if max_lifetime is not None: - if type(max_lifetime) is not int: + if type(max_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 31e0260b8312..d4e7dd45a9b8 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if type(depth) is not int: + if type(depth) is not int: # noqa: E721 raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 89bd597409c6..607013f121bf 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1891,7 +1891,7 @@ def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": ) origin_server_ts = d.get("origin_server_ts") - if type(origin_server_ts) is not int: + if type(origin_server_ts) is not int: # noqa: E721 raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 0b17f713ea94..5ce3f345cbeb 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -475,13 +475,11 @@ async def get_public_rooms( See synapse.federation.federation_client.FederationClient.get_public_rooms for more information. 
""" + path = _create_v1_path("/publicRooms") + if search_filter: # this uses MSC2197 (Search Filtering over Federation) - path = _create_v1_path("/publicRooms") - - data: Dict[str, Any] = { - "include_all_networks": "true" if include_all_networks else "false" - } + data: Dict[str, Any] = {"include_all_networks": include_all_networks} if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -505,17 +503,15 @@ async def get_public_rooms( ) raise else: - path = _create_v1_path("/publicRooms") - args: Dict[str, Union[str, Iterable[str]]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: - args["third_party_instance_id"] = (third_party_instance_id,) + args["third_party_instance_id"] = third_party_instance_id if limit: - args["limit"] = [str(limit)] + args["limit"] = str(limit) if since_token: - args["since"] = [since_token] + args["since"] = since_token try: response = await self.client.get_json( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 59ecafa6a094..2b0c50513095 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -218,19 +218,17 @@ def __init__(self, hs: "HomeServer"): self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) # The number of seconds to keep a UI auth session active. self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout - # Ratelimitier for failed /login attempts + # Ratelimiter for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) self._clock = self.hs.get_clock() diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5c71637038b6..a85054545356 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -67,6 +67,7 @@ def __init__(self, hs: "HomeServer"): self._cas_server_url = hs.config.cas.cas_server_url self._cas_service_url = hs.config.cas.cas_service_url + self._cas_protocol_version = hs.config.cas.cas_protocol_version self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute self._cas_required_attributes = hs.config.cas.cas_required_attributes @@ -121,7 +122,10 @@ async def _validate_ticket( Returns: The parsed CAS response. 
""" - uri = self._cas_server_url + "/proxyValidate" + if self._cas_protocol_version == 3: + uri = self._cas_server_url + "/p3/proxyValidate" + else: + uri = self._cas_server_url + "/proxyValidate" args = { "ticket": ticket, "service": self._build_service_param(service_args), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 17ff8821d974..798c7039f9b4 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -90,8 +90,7 @@ def __init__(self, hs: "HomeServer"): self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, - burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, + cfg=hs.config.ratelimiting.rc_key_requests, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 33359f6ed748..d12803bf0f31 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -67,6 +67,7 @@ async def get_stream( context = await presence_handler.user_syncing( requester.user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3031384d25bb..472879c964cc 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -66,14 +66,12 @@ def __init__(self, hs: "HomeServer"): self._3pid_validation_ratelimiter_ip = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) self._3pid_validation_ratelimiter_address = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) async def ratelimit_request_token_requests( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a74db1dccffa..d6be18cdefff 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -379,7 +379,7 @@ def maybe_schedule_expiry(self, event: EventBase) -> None: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is not int or event.is_state(): + if type(expiry_ts) is not int or event.is_state(): # noqa: E721 return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -908,19 +908,6 @@ async def get_event_id_from_transaction( if existing_event_id: return existing_event_id - # Some requsters don't have device IDs (appservice, guests, and access - # tokens minted with the admin API), fallback to checking the access token - # ID, which should be close enough. - if requester.access_token_id: - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_token_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) - ) - return existing_event_id async def get_event_from_transaction( @@ -1474,23 +1461,23 @@ async def handle_new_client_event( # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - event, context = events_and_context[0] + # + # Note: mypy gets confused if we inline dl and check with twisted#11770. + # Some kind of bug in mypy's deduction? 
+ deferreds = ( + run_in_background( + self._persist_events, + requester=requester, + events_and_context=events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, + ), + run_in_background( + self.cache_joined_hosts_for_events, events_and_context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ) result, _ = await make_deferred_yieldable( - gather_results( - ( - run_in_background( - self._persist_events, - requester=requester, - events_and_context=events_and_context, - ratelimit=ratelimit, - extra_users=extra_users, - ), - run_in_background( - self.cache_joined_hosts_for_events, events_and_context - ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ), - consumeErrors=True, - ) + gather_results(deferreds, consumeErrors=True) ).addErrback(unwrapFirstError) return result @@ -1921,7 +1908,10 @@ async def persist_and_notify_client_events( # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_as_background_process( - "bump_presence_active_time", self._bump_active_time, requester.user + "bump_presence_active_time", + self._bump_active_time, + requester.user, + requester.device_id, ) async def _notify() -> None: @@ -1958,10 +1948,10 @@ async def _maybe_kick_guest_users( logger.info("maybe_kick_guest_users %r", current_state) await self.hs.get_room_member_handler().kick_guest_users(current_state) - async def _bump_active_time(self, user: UserID) -> None: + async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None: try: presence = self.hs.get_presence_handler() - await presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user, device_id) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index e8e9db4b91a6..2f841863ae74 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -23,6 +23,7 @@ """ import abc import contextlib +import itertools import logging from bisect import bisect from contextlib import contextmanager @@ -151,15 +152,13 @@ def __init__(self, hs: "HomeServer"): self._federation_queue = PresenceFederationQueue(hs, self) - self._busy_presence_enabled = hs.config.experimental.msc3026_enabled - self.VALID_PRESENCE: Tuple[str, ...] = ( PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, ) - if self._busy_presence_enabled: + if hs.config.experimental.msc3026_enabled: self.VALID_PRESENCE += (PresenceState.BUSY,) active_presence = self.store.take_presence_startup_info() @@ -167,7 +166,11 @@ def __init__(self, hs: "HomeServer"): @abc.abstractmethod async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -178,6 +181,7 @@ async def user_syncing( Args: user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. 
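From here on, the presence changes repeat one structural move: every piece of sync bookkeeping that was keyed by user ID becomes keyed by a `(user ID, device ID)` tuple, so replication is told about the first in-flight sync starting and the last one ending per device rather than per user. A toy model of that counting logic, under invented names (`SyncTracker` is not a Synapse class):

```python
from typing import Dict, Optional, Tuple


class SyncTracker:
    """Counts in-flight syncs per (user_id, device_id), as the diff below does."""

    def __init__(self) -> None:
        self._counts: Dict[Tuple[str, Optional[str]], int] = {}

    def start_sync(self, user_id: str, device_id: Optional[str]) -> bool:
        """True if this is the device's first in-flight sync (send USER_SYNC)."""
        key = (user_id, device_id)
        self._counts[key] = self._counts.get(key, 0) + 1
        return self._counts[key] == 1

    def end_sync(self, user_id: str, device_id: Optional[str]) -> bool:
        """True if the device now has no in-flight syncs (mark as going offline)."""
        key = (user_id, device_id)
        if key not in self._counts:
            return False  # e.g. the map was cleared during shutdown
        self._counts[key] -= 1
        return self._counts[key] == 0


tracker = SyncTracker()
assert tracker.start_sync("@alice:example.org", "PHONE")   # first sync for PHONE
assert tracker.start_sync("@alice:example.org", "LAPTOP")  # counted per device
assert tracker.end_sync("@alice:example.org", "PHONE")     # PHONE goes offline
```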
@@ -185,15 +189,17 @@ async def user_syncing( """ @abc.abstractmethod - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: - """Get an iterable of syncing users on this worker, to send to the presence handler + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get an iterable of syncing users and devices on this worker, to send to the presence handler This is called when a replication connection is established. It should return - a list of user ids, which are then sent as USER_SYNC commands to inform the - process handling presence about those users. + a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands + to inform the process handling presence about those users/devices. Returns: - An iterable of user_id strings. + An iterable of tuples of user ID and device ID. """ async def get_state(self, target_user: UserID) -> UserPresenceState: @@ -254,28 +260,39 @@ async def current_state_for_user(self, user_id: str) -> UserPresenceState: async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ @abc.abstractmethod - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ async def update_external_syncs_row( # noqa: B027 (no-op by design) - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -286,6 +303,7 @@ async def update_external_syncs_row( # noqa: B027 (no-op by design) syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -381,7 +399,9 @@ async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: # We set force_notify=True here so that this presence update is guaranteed to # increment the presence stream ID (which resending the current user's presence # otherwise would not do). 
- await self.set_state(UserID.from_string(user_id), state, force_notify=True) + await self.set_state( + UserID.from_string(user_id), None, state, force_notify=True + ) async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: raise NotImplementedError( @@ -414,16 +434,18 @@ def __init__(self, hs: "HomeServer"): hs.config.worker.writers.presence, ) - # The number of ongoing syncs on this process, by user id. + # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing but - # we haven't notified the presence writer of that yet - self.users_going_offline: Dict[str, int] = {} + # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped + # syncing but we haven't notified the presence writer of that yet + self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -446,42 +468,54 @@ async def _on_shutdown(self) -> None: ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: + def send_user_sync( + self, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, + ) -> None: if self._presence_enabled: self.hs.get_replication_command_handler().send_user_sync( - self.instance_id, user_id, is_syncing, last_sync_ms + self.instance_id, user_id, device_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id: str) -> None: + def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. """ - going_offline = self.users_going_offline.pop(user_id, None) + going_offline = self._user_devices_going_offline.pop((user_id, device_id), None) if not going_offline: # Safe to skip because we haven't yet told the presence writer they # were offline - self.send_user_sync(user_id, True, self.clock.time_msec()) + self.send_user_sync(user_id, device_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id: str) -> None: + def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing notification to the presence writer """ - self.users_going_offline[user_id] = self.clock.time_msec() + self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec() def send_stop_syncing(self) -> None: """Check if there are any users who have stopped syncing a while ago and haven't come back yet. If there are poke the presence writer about them. 
""" now = self.clock.time_msec() - for user_id, last_sync_ms in list(self.users_going_offline.items()): + for (user_id, device_id), last_sync_ms in list( + self._user_devices_going_offline.items() + ): if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: - self.users_going_offline.pop(user_id, None) - self.send_user_sync(user_id, False, last_sync_ms) + self._user_devices_going_offline.pop((user_id, device_id), None) + self.send_user_sync(user_id, device_id, False, last_sync_ms) async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Record that a user is syncing. @@ -491,36 +525,32 @@ async def user_syncing( if not affect_presence or not self._presence_enabled: return _NullContextManager() - prev_state = await self.current_state_for_user(user_id) - if prev_state.state != PresenceState.BUSY: - # We set state here but pass ignore_status_msg = True as we don't want to - # cause the status message to be cleared. - # Note that this causes last_active_ts to be incremented which is not - # what the spec wants: see comment in the BasePresenceHandler version - # of this function. - await self.set_state( - UserID.from_string(user_id), - {"presence": presence_state}, - ignore_status_msg=True, - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) - curr_sync = self._user_to_num_current_syncs.get(user_id, 0) - self._user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self._user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If this is the first in-flight sync, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1: + self.mark_as_coming_online(user_id, device_id) def _end() -> None: # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if user_id in self._user_to_num_current_syncs: - self._user_to_num_current_syncs[user_id] -= 1 + if (user_id, device_id) in self._user_device_to_num_current_syncs: + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 - # If we went from one in flight sync to non, notify replication - if self._user_to_num_current_syncs[user_id] == 0: - self.mark_as_going_offline(user_id) + # If there are no more in-flight syncs, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0: + self.mark_as_going_offline(user_id, device_id) @contextlib.contextmanager def _user_syncing() -> Generator[None, None, None]: @@ -587,28 +617,34 @@ async def process_replication_rows( # If this is a federation sender, notify about presence updates. 
await self.maybe_send_presence_to_interested_destinations(state_to_notify) - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: return [ - user_id - for user_id, count in self._user_to_num_current_syncs.items() + user_id_device_id + for user_id_device_id, count in self._user_device_to_num_current_syncs.items() if count > 0 ] async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ presence = state["presence"] @@ -625,12 +661,15 @@ async def set_state( await self._set_state_client( instance_name=self._presence_writer_instance, user_id=user_id, + device_id=device_id, state=state, - ignore_status_msg=ignore_status_msg, force_notify=force_notify, + is_sync=is_sync, ) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -641,7 +680,9 @@ async def bump_presence_active_time(self, user: UserID) -> None: # Proxy request to instance that writes presence user_id = user.to_string() await self._bump_active_client( - instance_name=self._presence_writer_instance, user_id=user_id + instance_name=self._presence_writer_instance, + user_id=user_id, + device_id=device_id, ) @@ -703,17 +744,23 @@ def __init__(self, hs: "HomeServer"): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} # Keeps track of the number of *ongoing* syncs on other processes. + # # While any sync is ongoing on another process the user will never # go offline. + # # Each process has a unique identifier and an update frequency. If # no update is received from that process within the update period then # we assume that all the sync requests on that process have stopped. - # Stored as a dict from process_id to set of user_id, and a dict of - # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + # Stored as a dict from process_id to set of (user_id, device_id), and + # a dict of process_id to millisecond timestamp last updated. 
+ self.external_process_to_current_syncs: Dict[ + str, Set[Tuple[str, Optional[str]]] + ] = {} self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -918,7 +965,10 @@ async def _handle_timeouts(self) -> None: # that were syncing on that process to see if they need to be timed # out. users_to_check.update( - self.external_process_to_current_syncs.pop(process_id, ()) + user_id + for user_id, device_id in self.external_process_to_current_syncs.pop( + process_id, () + ) ) self.external_process_last_updated_ms.pop(process_id) @@ -931,11 +981,15 @@ async def _handle_timeouts(self) -> None: syncing_user_ids = { user_id - for user_id, count in self.user_to_num_current_syncs.items() + for (user_id, _), count in self._user_device_to_num_current_syncs.items() if count } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) + syncing_user_ids.update( + user_id + for user_id, _ in itertools.chain( + *self.external_process_to_current_syncs.values() + ) + ) changes = handle_timeouts( states, @@ -946,7 +1000,9 @@ async def _handle_timeouts(self) -> None: return await self._update_states(changes) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -969,6 +1025,7 @@ async def bump_presence_active_time(self, user: UserID) -> None: async def user_syncing( self, user_id: str, + device_id: Optional[str], affect_presence: bool = True, presence_state: str = PresenceState.ONLINE, ) -> ContextManager[None]: @@ -980,7 +1037,8 @@ async def user_syncing( when users disconnect/reconnect. Args: - user_id + user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -989,52 +1047,21 @@ async def user_syncing( if not affect_presence or not self._presence_enabled: return _NullContextManager() - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 - - prev_state = await self.current_state_for_user(user_id) - - # If they're busy then they don't stop being busy just by syncing, - # so just update the last sync time. - if prev_state.state != PresenceState.BUSY: - # XXX: We set_state separately here and just update the last_active_ts above - # This keeps the logic as similar as possible between the worker and single - # process modes. Using set_state will actually cause last_active_ts to be - # updated always, which is not what the spec calls for, but synapse has done - # this for... forever, I think. - await self.set_state( - UserID.from_string(user_id), - {"presence": presence_state}, - ignore_status_msg=True, - ) - # Retrieve the new state for the logic below. This should come from the - # in-memory cache. - prev_state = await self.current_state_for_user(user_id) + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - # To keep the single process behaviour consistent with worker mode, run the - # same logic as `update_external_syncs_row`, even though it looks weird. 
- if prev_state.state == PresenceState.OFFLINE: - await self._update_states( - [ - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - ) - ] - ) - # otherwise, set the new presence state & update the last sync time, - # but don't update last_active_ts as this isn't an indication that - # they've been active (even though it's probably been updated by - # set_state above) - else: - await self._update_states( - [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())] - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) async def _end() -> None: try: - self.user_to_num_current_syncs[user_id] -= 1 + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 prev_state = await self.current_state_for_user(user_id) await self._update_states( @@ -1056,12 +1083,19 @@ def _user_syncing() -> Generator[None, None, None]: return _user_syncing() - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: # since we are the process handling presence, there is nothing to do here. return [] async def update_external_syncs_row( - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -1070,6 +1104,7 @@ async def update_external_syncs_row( syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -1080,31 +1115,27 @@ async def update_external_syncs_row( process_id, set() ) - updates = [] - if is_syncing and user_id not in process_presence: - if prev_state.state == PresenceState.OFFLINE: - updates.append( - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - ) - ) - else: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) - ) - process_presence.add(user_id) - elif user_id in process_presence: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + # USER_SYNC is sent when a user's device starts or stops syncing on + # a remote process. (But only for the initial and last sync for that + # device.) + # + # When a device *starts* syncing it also calls set_state(...) which + # will update the state, last_active_ts, and last_user_sync_ts. + # Simply ensure the user & device is tracked as syncing in this case. + # + # When a device *stops* syncing, update the last_user_sync_ts and mark + # them as no longer syncing. Note this doesn't quite match the + # monolith behaviour, which updates last_user_sync_ts at the end of + # every sync, not just the last in-flight sync.
+ if is_syncing and (user_id, device_id) not in process_presence: + process_presence.add((user_id, device_id)) + elif not is_syncing and (user_id, device_id) in process_presence: + new_state = prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec ) + await self._update_states([new_state]) - if not is_syncing: - process_presence.discard(user_id) - - if updates: - await self._update_states(updates) + process_presence.discard((user_id, device_id)) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() @@ -1118,7 +1149,9 @@ async def update_external_syncs_clear(self, process_id: str) -> None: process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users( + {user_id for user_id, device_id in process_presence} + ) time_now_ms = self.clock.time_msec() await self._update_states( @@ -1203,18 +1236,22 @@ async def incoming_presence(self, origin: str, content: JsonDict) -> None: async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1227,18 +1264,27 @@ async def set_state( return user_id = target_user.to_string() + now = self.clock.time_msec() prev_state = await self.current_state_for_user(user_id) + # Syncs do not override a previous presence of busy. + # + # TODO: This is a hack for lack of multi-device support. Unfortunately + # removing this requires coordination with clients. + if prev_state.state == PresenceState.BUSY and is_sync: + presence = PresenceState.BUSY + new_fields = {"state": presence} - if not ignore_status_msg: - new_fields["status_msg"] = status_msg + if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: + new_fields["last_active_ts"] = now - if presence == PresenceState.ONLINE or ( - presence == PresenceState.BUSY and self._busy_presence_enabled - ): - new_fields["last_active_ts"] = self.clock.time_msec() + if is_sync: + new_fields["last_user_sync_ts"] = now + else: + # Syncs do not override the status message. 
+ new_fields["status_msg"] = status_msg await self._update_states( [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1d8d4a72e7a2..de0f04e3fe48 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -112,8 +112,7 @@ def __init__(self, hs: "HomeServer"): self._join_rate_limiter_local = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, - burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + cfg=hs.config.ratelimiting.rc_joins_local, ) # Tracks joins from local users to rooms this server isn't a member of. # I.e. joins this server makes by requesting /make_join /send_join from @@ -121,8 +120,7 @@ def __init__(self, hs: "HomeServer"): self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, - burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + cfg=hs.config.ratelimiting.rc_joins_remote, ) # TODO: find a better place to keep this Ratelimiter. # It needs to be @@ -135,8 +133,7 @@ def __init__(self, hs: "HomeServer"): self._join_rate_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_joins_per_room, ) # Ratelimiter for invites, keyed by room (across all issuers, all @@ -144,8 +141,7 @@ def __init__(self, hs: "HomeServer"): self._invites_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_room, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -153,8 +149,7 @@ def __init__(self, hs: "HomeServer"): self._invites_per_recipient_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_user, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -162,15 +157,13 @@ def __init__(self, hs: "HomeServer"): self._invites_per_issuer_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_issuer, ) self._third_party_invite_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, - burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, + cfg=hs.config.ratelimiting.rc_third_party_invite, ) self.request_ratelimiter = hs.get_request_ratelimiter() diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index dad3e23470fb..dd559b4c450f 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -35,6 +35,7 @@ UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache @@ -94,7 
+95,9 @@ def __init__(self, hs: "HomeServer"): self._server_name = hs.hostname self._federation_client = hs.get_federation_client() self._ratelimiter = Ratelimiter( - store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + store=self._store, + clock=hs.get_clock(), + cfg=RatelimitSettings("", per_second=5, burst_count=10), ) # If a user tries to fetch the same page multiple times in quick succession, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 804cc6e81e00..05e21509deac 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -23,9 +23,11 @@ import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.endpoints import HostnameEndpoint +from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory +from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable from synapse.types import ISynapseReactor @@ -97,6 +99,7 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: **kwargs, ) + factory: IProtocolFactory if _is_old_twisted: # before twisted 21.2, we have to override the ESMTPSender protocol to disable # TLS @@ -110,22 +113,13 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: factory = build_sender_factory(hostname=smtphost if enable_tls else None) if force_tls: - reactor.connectSSL( - smtphost, - smtpport, - factory, - optionsForClientTLS(smtphost), - timeout=30, - bindAddress=None, - ) - else: - reactor.connectTCP( - smtphost, - smtpport, - factory, - timeout=30, - bindAddress=None, - ) + factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) + + endpoint = HostnameEndpoint( + reactor, smtphost, smtpport, timeout=30, bindAddress=None + ) + + await make_deferred_yieldable(endpoint.connect(factory)) await make_deferred_yieldable(d) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 583c03447c17..11342ccac8a3 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -243,7 +243,7 @@ def _validate(v: Any) -> bool: return ( isinstance(v, list) and len(v) == 2 - and type(v[0]) == int + and type(v[0]) == int # noqa: E721 and isinstance(v[1], dict) ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 5109cec983c9..3bbf91298e3d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -115,7 +115,13 @@ def return_json_error( if exc.headers is not None: for header, value in exc.headers.items(): request.setHeader(header, value) - logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + error_ctx = exc.debug_context + if error_ctx: + logger.info( + "%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx + ) + else: + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): error_code = HTTP_STATUS_REQUEST_CANCELLED error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index b78d6e17c93c..98c6038ff23f 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -44,6 +44,7 @@ "processName", "relativeCreated", "stack_info", + "taskName", "thread", "threadName", } diff --git a/synapse/logging/context.py 
b/synapse/logging/context.py index f62bea968fe4..64c6ae451208 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -809,23 +809,24 @@ def run_in_background( # type: ignore[misc] # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain # value. Convert it to a `Deferred`. + d: "defer.Deferred[R]" if isinstance(res, typing.Coroutine): # Wrap the coroutine in a `Deferred`. - res = defer.ensureDeferred(res) + d = defer.ensureDeferred(res) elif isinstance(res, defer.Deferred): - pass + d = res elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. - res = defer.ensureDeferred(_unwrap_awaitable(res)) + d = defer.ensureDeferred(_unwrap_awaitable(res)) else: # `res` is a plain value. Wrap it in a `Deferred`. - res = defer.succeed(res) + d = defer.succeed(res) - if res.called and not res.paused: + if d.called and not d.paused: # The function should have maintained the logcontext, so we can # optimise out the messing about - return res + return d # The function may have reset the context before returning, so # we need to restore it now. @@ -843,8 +844,8 @@ def run_in_background( # type: ignore[misc] # which is supposed to have a single entry and exit point. But # by spawning off another deferred, we are effectively # adding a new exit point.) - res.addBoth(_set_context_cb, ctx) - return res + d.addBoth(_set_context_cb, ctx) + return d T = TypeVar("T") @@ -877,7 +878,7 @@ def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T] ResultT = TypeVar("ResultT") -def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: +def _set_context_cb(result: ResultT, context: LoggingContextOrSentinel) -> ResultT: """A callback function which just sets the logging context""" set_current_context(context) return result diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index be910128aa4e..5c3045e197e9 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -910,10 +910,10 @@ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> async def _wrapper( *args: P.args, **kwargs: P.kwargs ) -> Any: # Return type is RInner - with wrapping_logic(func, *args, **kwargs): - # type-ignore: func() returns R, but mypy doesn't know that R is - # Awaitable here. - return await func(*args, **kwargs) # type: ignore[misc] + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. + with wrapping_logic(func, *args, **kwargs): # type: ignore[arg-type] + return await func(*args, **kwargs) else: # The other case here handles sync functions including those decorated with @@ -980,8 +980,7 @@ def trace_with_opname( See the module's doc string for usage examples. 
""" - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -1024,8 +1023,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 4b750c700b89..1b7b014f9ac2 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -214,7 +214,10 @@ async def create_content( user_id=auth_user, ) - await self._generate_thumbnails(None, media_id, media_id, media_type) + try: + await self._generate_thumbnails(None, media_id, media_id, media_type) + except Exception as e: + logger.info("Failed to generate thumbnails: %s", e) return MXCUri(self.server_name, media_id) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py index 5ad9eec80b97..2ce842c98d4a 100644 --- a/synapse/media/oembed.py +++ b/synapse/media/oembed.py @@ -204,7 +204,7 @@ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if type(val) is int: + if type(val) is int: # noqa: E721 open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index 2bfa58ceee5d..d8979813b335 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -78,7 +78,7 @@ def __init__(self, input_path: str): image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int + assert type(image_orientation) is int # noqa: E721 self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 990c079c815b..554634579ed0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -379,7 +379,7 @@ async def _action_for_event_by_user( keys = list(notification_levels.keys()) for key in keys: level = notification_levels.get(key, SENTINEL) - if level is not SENTINEL and type(level) is not int: + if level is not SENTINEL and type(level) is not int: # noqa: E721 try: notification_levels[key] = int(level) except (TypeError, ValueError): @@ -472,7 +472,11 @@ async def _action_for_event_by_user( def _is_simple_value(value: Any) -> bool: - return isinstance(value, (bool, str)) or type(value) is int or value is None + return ( + isinstance(value, (bool, str)) + or type(value) is int # noqa: E721 + or value is None + ) def _flatten_dict( diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 73f3de364205..209833d28753 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): NAME = "multi_user_device_resync" 
PATH_ARGS = () - CACHE = False + CACHE = True def __init__(self, hs: "HomeServer"): super().__init__(hs) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index db16aac9c206..6c9e79fb07c9 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -51,14 +51,14 @@ def __init__(self, hs: "HomeServer"): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] - return {} + async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override] + return {"device_id": device_id} async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( - UserID.from_string(user_id) + UserID.from_string(user_id), content.get("device_id") ) return (200, {}) @@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, - "ignore_status_msg": false, - "force_notify": false + "force_notify": false, + "is_sync": false } 200 OK @@ -95,14 +95,16 @@ def __init__(self, hs: "HomeServer"): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> JsonDict: return { + "device_id": device_id, "state": state, - "ignore_status_msg": ignore_status_msg, "force_notify": force_notify, + "is_sync": is_sync, } async def _handle_request( # type: ignore[override] @@ -110,9 +112,10 @@ async def _handle_request( # type: ignore[override] ) -> Tuple[int, JsonDict]: await self._presence_handler.set_state( UserID.from_string(user_id), + content.get("device_id"), content["state"], - content["ignore_status_msg"], content["force_notify"], + content.get("is_sync", False), ) return (200, {}) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 10f5c98ff8a9..e616b5e1c8ad 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -267,27 +267,38 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" def __init__( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ): self.instance_id = instance_id self.user_id = user_id + self.device_id = device_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": - instance_id, user_id, state, last_sync_ms = line.split(" ", 3) + device_id: Optional[str] + instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4) + + if device_id == "None": + device_id = None if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(instance_id, user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms)) def to_line(self) -> str: return " ".join( ( self.instance_id, self.user_id, + str(self.device_id), "start" if self.is_syncing else "end", str(self.last_sync_ms), ) @@ -452,6 
+463,17 @@ def to_line(self) -> str: return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) +class NewActiveTaskCommand(_SimpleCommand): + """Sent to inform the instance handling background tasks that a new active task is available to run. + + Format:: + + NEW_ACTIVE_TASK "<task_id>" + """ + + NAME = "NEW_ACTIVE_TASK" + + _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -466,6 +488,7 @@ def to_line(self) -> str: RemoteServerUpCommand, ClearUserSyncsCommand, LockReleasedCommand, + NewActiveTaskCommand, ) # Map of command name to command type. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 38adcbe1d0e8..d9045d7b73f5 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -40,6 +40,7 @@ Command, FederationAckCommand, LockReleasedCommand, + NewActiveTaskCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -238,6 +239,10 @@ def __init__(self, hs: "HomeServer"): if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + self._task_scheduler = None + if hs.config.worker.run_background_tasks: + self._task_scheduler = hs.get_task_scheduler() + if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. @@ -423,7 +428,11 @@ def on_USER_SYNC( if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, + cmd.user_id, + cmd.device_id, + cmd.is_syncing, + cmd.last_sync_ms, ) else: return None @@ -663,6 +672,15 @@ def on_LOCK_RELEASED( cmd.instance_name, cmd.lock_name, cmd.lock_key ) + async def on_NEW_ACTIVE_TASK( + self, conn: IReplicationConnection, cmd: NewActiveTaskCommand + ) -> None: + """Called when we get a new NEW_ACTIVE_TASK command.""" + if self._task_scheduler: + task = await self._task_scheduler.get_task(cmd.data) + if task: + await self._task_scheduler._launch_task(task) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -685,9 +703,9 @@ def new_connection(self, connection: IReplicationConnection) -> None: ) now = self._clock.time_msec() - for user_id in currently_syncing: + for user_id, device_id in currently_syncing: connection.send_command( - UserSyncCommand(self._instance_id, user_id, True, now) + UserSyncCommand(self._instance_id, user_id, device_id, True, now) ) def lost_connection(self, connection: IReplicationConnection) -> None: @@ -739,11 +757,16 @@ def send_federation_ack(self, token: int) -> None: self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms) ) def send_user_ip( @@ -776,6 +799,10 @@ def on_lock_released( if instance_name == self._instance_name: self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) + def send_new_active_task(self, task_id: str) -> None: + """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" +
self.send_command(NewActiveTaskCommand(task_id)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 55e752fda85a..94170715fb77 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -157,7 +157,7 @@ async def on_POST( logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if type(ts) is not int: + if type(ts) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 95e751288b03..ffce92d45ee1 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if type(length) is not int: + if type(length) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,7 +163,8 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) + uses_allowed is None + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -172,13 +173,16 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) expiry_time = body.get("expiry_time", None) - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -283,7 +287,7 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -294,13 +298,16 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi if "expiry_time" in body: expiry_time = body["expiry_time"] - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 240e6254b0bd..625a47ec1a5a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1172,14 +1172,17 @@ async def on_POST( messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if type(messages_per_second) is not int or messages_per_second < 0: 
+ if ( + type(messages_per_second) is not int # noqa: E721 + or messages_per_second < 0 + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if type(burst_count) is not int or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d724c6892067..7be327e26f08 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -120,14 +120,12 @@ def __init__(self, hs: "HomeServer"): self._address_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_address, ) self._account_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_account, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index b1629f94a5f8..d189a923b5bf 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -66,15 +67,18 @@ def __init__(self, hs: "HomeServer"): self.token_timeout = hs.config.auth.login_via_existing_token_timeout self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth - # Ratelimit aggressively to a maxmimum of 1 request per minute. + # Ratelimit aggressively to a maximum of 1 request per minute. # # This endpoint can be used to spawn additional sessions and could be # abused by a malicious client to create many sessions. 
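The `# noqa: E721` markers that recur throughout these validation checks are deliberate: the code wants `type(x) is int`, not `isinstance(x, int)`, because `bool` is a subclass of `int` and `isinstance` would silently accept `True`/`False` where the API requires a real integer. A short, dependency-free illustration:

```python
# Why `type(value) is int` (suppressing flake8's E721) instead of isinstance:
# bool subclasses int, so isinstance would accept True/False as "integers".
def is_strict_int(value: object) -> bool:
    return type(value) is int  # noqa: E721


assert is_strict_int(5)
assert not is_strict_int(True)   # bool rejected despite being an int subclass
assert not is_strict_int(5.0)    # floats rejected
assert isinstance(True, int)     # the pitfall isinstance() would allow
```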
self._ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=1 / 60, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=1 / 60, + burst_count=1, + ), ) @interactive_auth_handler diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 8e193330f8bc..d578faa96984 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -97,7 +97,7 @@ async def on_PUT( raise SynapseError(400, "Unable to parse state") if self._use_presence: - await self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, requester.device_id, state) return 200, {} diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 4f96e51eeb93..1707e519723a 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -52,7 +52,9 @@ async def on_POST( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 316e7b99821e..869a37445950 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -94,7 +94,9 @@ async def on_POST( Codes.INVALID_PARAM, ) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) if receipt_type == ReceiptTypes.FULLY_READ: await self.read_marker_handler.received_client_read_marker( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 77e3b91b7999..132623462adc 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -376,8 +376,7 @@ def __init__(self, hs: "HomeServer"): self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, - burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + cfg=hs.config.ratelimiting.rc_registration_token_validity, ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index ac1a63ca2745..ee93e459f6bf 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -55,7 +55,7 @@ async def on_POST( "Param 'reason' must be a string", Codes.BAD_JSON, ) - if type(body.get("score", 0)) is not int: + if type(body.get("score", 0)) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index dc498001e450..553938ce9d13 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1229,7 +1229,9 @@ async def on_PUT( content = parse_json_object_from_request(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) # Limit timeout to stop people from setting silly typing timeouts. 
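All of the `Ratelimiter` constructor changes in this patch follow one mechanical pattern: the paired `rate_hz`/`burst_count` keyword arguments are replaced by a single `cfg` settings object, so hard-coded limits are spelled the same way as config-driven ones. A hedged sketch of the shape of that pattern, using simplified stand-in classes rather than Synapse's real `Ratelimiter`/`RatelimitSettings`:

```python
# Simplified stand-ins illustrating the cfg-object construction pattern.
from dataclasses import dataclass


@dataclass(frozen=True)
class RatelimitSettings:
    key: str           # config path the setting came from ("" if hard-coded)
    per_second: float  # steady-state allowed rate
    burst_count: int   # short-term burst allowance


class Ratelimiter:
    def __init__(self, cfg: RatelimitSettings) -> None:
        # Both knobs now travel together in one object.
        self.rate_hz = cfg.per_second
        self.burst_count = cfg.burst_count


# A hard-coded "1 request per minute" limit, expressed like a config-driven one:
limiter = Ratelimiter(cfg=RatelimitSettings(key="", per_second=1 / 60, burst_count=1))
assert limiter.burst_count == 1
```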
timeout = min(content.get("timeout", 30000), 120000) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d7854ed4fd9d..42bdd3bb108b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -205,6 +205,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: context = await self.presence_handler.user_syncing( user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=set_presence, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 981fd1f58a68..0aaa838d0478 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -16,6 +16,7 @@ import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple +from pydantic import Extra, StrictInt, StrictStr from signedjson.sign import sign_json from twisted.web.server import Request @@ -24,9 +25,10 @@ from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, + parse_and_validate_json_object_from_request, parse_integer, - parse_json_object_from_request, ) +from synapse.rest.models import RequestBodyModel from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder @@ -38,6 +40,13 @@ logger = logging.getLogger(__name__) +class _KeyQueryCriteriaDataModel(RequestBodyModel): + class Config: + extra = Extra.allow + + minimum_valid_until_ts: Optional[StrictInt] + + class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported @@ -96,6 +105,9 @@ class RemoteKey(RestServlet): CATEGORY = "Federation requests" + class PostBody(RequestBodyModel): + server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] + def __init__(self, hs: "HomeServer"): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main @@ -137,24 +149,29 @@ async def on_GET( ) minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") - arguments = {} - if minimum_valid_until_ts is not None: - arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = { + server: { + key_id: _KeyQueryCriteriaDataModel( + minimum_valid_until_ts=minimum_valid_until_ts + ) + } + } else: query = {server: {}} return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) + content = parse_and_validate_json_object_from_request(request, self.PostBody) - query = content["server_keys"] + query = content.server_keys return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, query: JsonDict, query_remote_on_cache_miss: bool = False + self, + query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], + query_remote_on_cache_miss: bool = False, ) -> JsonDict: logger.info("Handling query for keys %r", query) @@ -196,8 +213,10 @@ async def query_keys( else: ts_added_ms = key_result.added_ts ts_valid_until_ms = key_result.valid_until_ts - req_key = query.get(server_name, {}).get(key_id, {}) - req_valid_until = req_key.get("minimum_valid_until_ts") + req_key = query.get(server_name, {}).get( + key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) + ) + req_valid_until = req_key.minimum_valid_until_ts if req_valid_until is not None: if 
ts_valid_until_ms < req_valid_until: logger.debug( diff --git a/synapse/server.py b/synapse/server.py index 7cdd3ea3c2e1..71ead524d684 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -408,8 +408,7 @@ def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( store=self.get_datastores().main, clock=self.get_clock(), - rate_hz=self.config.ratelimiting.rc_registration.per_second, - burst_count=self.config.ratelimiting.rc_registration.burst_count, + cfg=self.config.ratelimiting.rc_registration, ) @cache_in_self @@ -914,6 +913,7 @@ def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager: """Usage metrics shared between phone home stats and the prometheus exporter.""" return CommonUsageMetricsManager(self) + @cache_in_self def get_worker_locks_handler(self) -> WorkerLocksHandler: return WorkerLocksHandler(self) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index ddca0af1da39..7619f405fa09 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -405,14 +405,14 @@ async def run_background_updates(self, sleep: bool) -> None: try: result = await self.do_next_background_update(sleep) back_to_back_failures = 0 - except Exception: + except Exception as e: + logger.exception("Error doing update: %s", e) back_to_back_failures += 1 if back_to_back_failures >= 5: self._aborted = True raise RuntimeError( "5 back-to-back background update failures; aborting." ) - logger.exception("Error doing update") else: if result: logger.info( diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a1c8fb0f46a4..55ac313f33b0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ Iterator, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -358,7 +359,21 @@ def rowcount(self) -> int: return self.txn.rowcount @property - def description(self) -> Any: + def description( + self, + ) -> Optional[ + Sequence[ + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c1353b18c1cd..0c1ed752406f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -978,26 +978,12 @@ def _persist_transaction_ids_txn( """Persist the mapping from transaction IDs to event IDs (if defined).""" inserted_ts = self._clock.time_msec() - to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = [] to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = [] for event, _ in events_and_contexts: txn_id = getattr(event.internal_metadata, "txn_id", None) - token_id = getattr(event.internal_metadata, "token_id", None) device_id = getattr(event.internal_metadata, "device_id", None) if txn_id is not None: - if token_id is not None: - to_insert_token_id.append( - ( - event.event_id, - event.room_id, - event.sender, - token_id, - txn_id, - inserted_ts, - ) - ) - if device_id is not None: to_insert_device_id.append( ( @@ -1010,26 +996,7 @@ def _persist_transaction_ids_txn( ) ) - # Synapse usually relies on the device_id to scope transactions for events, - # except for users without device IDs (appservice, guests, and access - # tokens minted with the admin API) which use the access token ID instead. 
- # - # TODO https://github.com/matrix-org/synapse/issues/16042 - if to_insert_token_id: - self.db_pool.simple_insert_many_txn( - txn, - table="event_txn_id", - keys=( - "event_id", - "room_id", - "user_id", - "token_id", - "txn_id", - "inserted_ts", - ), - values=to_insert_token_id, - ) - + # Synapse relies on the device_id to scope transactions for events. if to_insert_device_id: self.db_pool.simple_insert_many_txn( txn, @@ -1671,7 +1638,7 @@ def _update_metadata_tables_txn( if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is int and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): # noqa: E721 self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2039,10 +2006,10 @@ def _store_retention_policy_for_room_txn( ): if ( "min_lifetime" in event.content - and type(event.content["min_lifetime"]) is not int + and type(event.content["min_lifetime"]) is not int # noqa: E721 ) or ( "max_lifetime" in event.content - and type(event.content["max_lifetime"]) is not int + and type(event.content["max_lifetime"]) is not int # noqa: E721 ): # Ignore the event if one of the values isn't an integer. return diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7e7648c95112..1eb313040ed9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -2022,25 +2022,6 @@ def get_next_event_to_expire_txn( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - async def get_event_id_from_transaction_id_and_token_id( - self, room_id: str, user_id: str, token_id: int, txn_id: str - ) -> Optional[str]: - """Look up if we have already persisted an event for the transaction ID, - returning the event ID if so. - """ - return await self.db_pool.simple_select_one_onecol( - table="event_txn_id", - keyvalues={ - "room_id": room_id, - "user_id": user_id, - "token_id": token_id, - "txn_id": txn_id, - }, - retcol="event_id", - allow_none=True, - desc="get_event_id_from_transaction_id_and_token_id", - ) - async def get_event_id_from_transaction_id_and_device_id( self, room_id: str, user_id: str, device_id: str, txn_id: str ) -> Optional[str]: @@ -2072,29 +2053,35 @@ async def get_already_persisted_events( """ mapping = {} - txn_id_to_event: Dict[Tuple[str, int, str], str] = {} + txn_id_to_event: Dict[Tuple[str, str, str, str], str] = {} for event in events: - token_id = getattr(event.internal_metadata, "token_id", None) + device_id = getattr(event.internal_metadata, "device_id", None) txn_id = getattr(event.internal_metadata, "txn_id", None) - if token_id and txn_id: + if device_id and txn_id: # Check if this is a duplicate of an event in the given events. - existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + existing = txn_id_to_event.get( + (event.room_id, event.sender, device_id, txn_id) + ) if existing: mapping[event.event_id] = existing continue # Check if this is a duplicate of an event we've already # persisted.
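With the token-scoped table gone, event deduplication is keyed on `(room_id, sender, device_id, txn_id)`, as the rewritten `get_already_persisted_events` shows. A minimal in-memory sketch of that behaviour (illustrative only; the real implementation lives in the storage layer):

```python
# In-memory stand-in for device-scoped transaction deduplication.
from typing import Dict, Optional, Tuple

DedupKey = Tuple[str, str, str, str]  # (room_id, sender, device_id, txn_id)


class TxnDeduplicator:
    def __init__(self) -> None:
        self._seen: Dict[DedupKey, str] = {}

    def check_and_record(
        self,
        room_id: str,
        sender: str,
        device_id: Optional[str],
        txn_id: Optional[str],
        event_id: str,
    ) -> Optional[str]:
        """Return the event ID already recorded for this transaction, if any."""
        if device_id is None or txn_id is None:
            return None  # nothing to scope the transaction by
        key = (room_id, sender, device_id, txn_id)
        existing = self._seen.get(key)
        if existing is None:
            self._seen[key] = event_id
        return existing


d = TxnDeduplicator()
assert d.check_and_record("!room", "@user", "DEV", "txn1", "$event1") is None
assert d.check_and_record("!room", "@user", "DEV", "txn1", "$event2") == "$event1"
```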
- existing = await self.get_event_id_from_transaction_id_and_token_id( - event.room_id, event.sender, token_id, txn_id + existing = await self.get_event_id_from_transaction_id_and_device_id( + event.room_id, event.sender, device_id, txn_id ) if existing: mapping[event.event_id] = existing - txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = existing else: - txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = event.event_id return mapping diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 54d40e7a3ab0..5a01ec213759 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary -from twisted.internet.interfaces import IReactorCore +from twisted.internet.task import LoopingCall from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -26,6 +26,7 @@ LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.types import ISynapseReactor from synapse.util import Clock from synapse.util.stringutils import random_string @@ -358,7 +359,7 @@ class Lock: def __init__( self, - reactor: IReactorCore, + reactor: ISynapseReactor, clock: Clock, store: LockStore, read_write: bool, @@ -377,19 +378,25 @@ def __init__( self._table = "worker_read_write_locks" if read_write else "worker_locks" - self._looping_call = clock.looping_call( + # We might be called from a non-main thread, so we defer setting up the + # looping call. + self._looping_call: Optional[LoopingCall] = None + reactor.callFromThread(self._setup_looping_call) + + self._dropped = False + + def _setup_looping_call(self) -> None: + self._looping_call = self._clock.looping_call( self._renew, _RENEWAL_INTERVAL_MS, - store, - clock, - read_write, - lock_name, - lock_key, - token, + self._store, + self._clock, + self._read_write, + self._lock_name, + self._lock_key, + self._token, ) - self._dropped = False - @staticmethod @wrap_as_background_process("Lock._renew") async def _renew( @@ -459,7 +466,7 @@ async def release(self) -> None: if self._dropped: return - if self._looping_call.running: + if self._looping_call and self._looping_call.running: self._looping_call.stop() await self._store.db_pool.simple_delete( @@ -486,8 +493,9 @@ def __del__(self) -> None: # We should not be dropped without the lock being released (unless # we're shutting down), but if we are then let's at least stop # renewing the lock. - if self._looping_call.running: - self._looping_call.stop() + if self._looping_call and self._looping_call.running: + # We might be called from a non-main thread. 
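The `Lock` changes here apply the usual Twisted constraint that a `LoopingCall` must be started and stopped on the reactor thread, marshalling the work there with `callFromThread`. A hedged sketch of the pattern (`RenewingLock` is a simplified stand-in, not Synapse's `Lock`, and assumes a running reactor):

```python
# Sketch: deferring LoopingCall setup/teardown to the reactor thread.
from typing import Optional

from twisted.internet import reactor
from twisted.internet.task import LoopingCall


class RenewingLock:
    def __init__(self) -> None:
        self._looping_call: Optional[LoopingCall] = None
        # __init__ may run on any thread, so set up the timer on the reactor.
        reactor.callFromThread(self._setup_looping_call)

    def _setup_looping_call(self) -> None:
        self._looping_call = LoopingCall(self._renew)
        self._looping_call.start(30.0)  # renew every 30 seconds

    def _renew(self) -> None:
        """Would refresh the lock's expiry in storage here."""

    def release(self) -> None:
        # Guard: setup may not have run yet if we were created very recently.
        if self._looping_call and self._looping_call.running:
            reactor.callFromThread(self._looping_call.stop)
```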
+ self._reactor.callFromThread(self._looping_call.stop) if self._reactor.running: logger.error( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index c13c0bc7d725..bec0dc2afeeb 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -88,7 +88,6 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, - msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) return filtered_rules diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 649d3c8e9f96..422f11f59e9e 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 80 # remember to update the list below when updating +SCHEMA_VERSION = 81 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -114,19 +114,15 @@ Changes in SCHEMA_VERSION = 80 - The event_txn_id_device_id is always written to for new events. - Add tables for the task scheduler. + +Changes in SCHEMA_VERSION = 81 + - The event_txn_id is no longer written to for new events. """ SCHEMA_COMPAT_VERSION = ( - # Queries against `event_stream_ordering` columns in membership tables must - # be disambiguated. - # - # The threads_id column must written to with non-null values for the - # event_push_actions, event_push_actions_staging, and event_push_summary tables. - # - # insertions to the column `full_user_id` of tables profiles and user_filters can no - # longer be null - 76 + # The `event_txn_id_device_id` must be written to for new events. 
+ 80 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bf7bd351e0cc..029eedcc6fae 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -470,7 +470,7 @@ def __init__(self) -> None: def deferred(self, key: KT) -> "defer.Deferred[VT]": if not self._deferred: self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - return self._deferred.observe().addCallback(lambda res: res.get(key)) + return self._deferred.observe().addCallback(lambda res: res[key]) def add_invalidation_callback( self, key: KT, callback: Optional[Callable[[], None]] diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 114130a08fe2..f7cead9e1206 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -51,9 +51,9 @@ def dependencies(self) -> Iterable[str]: DEV_EXTRAS = {"lint", "mypy", "test", "dev"} -RUNTIME_EXTRAS = ( - set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS -) +ALL_EXTRAS = metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra") +assert ALL_EXTRAS is not None +RUNTIME_EXTRAS = set(ALL_EXTRAS) - DEV_EXTRAS VERSION = metadata.version(DISTRIBUTION_NAME) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index cde4a0780fe7..f693ba2a8c0c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -291,7 +291,8 @@ def _on_enter(self, request_id: object) -> "defer.Deferred[None]": if self.metrics_name: rate_limit_reject_counter.labels(self.metrics_name).inc() raise LimitExceededError( - retry_after_ms=int(self.window_size / self.sleep_limit) + limiter_name="rc_federation", + retry_after_ms=int(self.window_size / self.sleep_limit), ) self.request_times.append(time_now) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 4aea64b338b4..9e89aeb74891 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -57,14 +57,13 @@ class TaskScheduler: the code launching the task. You can also specify the `result` (and/or an `error`) when returning from the function. - The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting - to launch now, the launch will still not happen before the next loop run. - - Tasks will be run on the worker specified with `run_background_tasks_on` config, - or the main one by default. + The reconciliation loop runs every minute, so this is not a precise scheduler. There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already full. In this regard, please take great care that scheduled tasks can actually finish. For now there is no mechanism to stop a running task if it is stuck. + + Tasks will be run on the worker specified with `run_background_tasks_on` config, + or the main one by default.
""" # Precision of the scheduler, evaluation of tasks to run will only happen @@ -85,7 +84,7 @@ def __init__(self, hs: "HomeServer"): self._actions: Dict[ str, Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], ] = {} @@ -98,11 +97,13 @@ def __init__(self, hs: "HomeServer"): "handle_scheduled_tasks", self._handle_scheduled_tasks, ) + else: + self.replication_client = hs.get_replication_command_handler() def register_action( self, function: Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], action_name: str, @@ -115,10 +116,9 @@ def register_action( calling `schedule_task` but rather in an `__init__` method. Args: - function: The function to be executed for this action. The parameters - passed to the function when launched are the `ScheduledTask` being run, - and a `first_launch` boolean to signal if it's a resumed task or the first - launch of it. The function should return a tuple of new `status`, `result` + function: The function to be executed for this action. The parameter + passed to the function when launched is the `ScheduledTask` being run. + The function should return a tuple of new `status`, `result` and `error` as specified in `ScheduledTask`. action_name: The name of the action to be associated with the function """ @@ -171,6 +171,12 @@ async def schedule_task( ) await self._store.insert_scheduled_task(task) + if status == TaskStatus.ACTIVE: + if self._run_background_tasks: + await self._launch_task(task) + else: + self.replication_client.send_new_active_task(task.id) + return task.id async def update_task( @@ -265,21 +271,13 @@ async def delete_task(self, id: str) -> None: Args: id: id of the task to delete """ - if self.task_is_running(id): - raise Exception(f"Task {id} is currently running and can't be deleted") + task = await self.get_task(id) + if task is None: + raise Exception(f"Task {id} does not exist") + if task.status == TaskStatus.ACTIVE: + raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def task_is_running(self, id: str) -> bool: - """Check if a task is currently running. - - Can only be called from the worker handling the task scheduling. 
- - Args: - id: id of the task to check - """ - assert self._run_background_tasks - return id in self._running_tasks - async def _handle_scheduled_tasks(self) -> None: """Main loop taking care of launching tasks and cleaning up old ones.""" await self._launch_scheduled_tasks() @@ -288,29 +286,11 @@ async def _handle_scheduled_tasks(self) -> None: async def _launch_scheduled_tasks(self) -> None: """Retrieve and launch scheduled tasks that should be running at that time.""" for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): - if not self.task_is_running(task.id): - if ( - len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=False) - else: - if ( - self._clock.time_msec() - > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS - ): - logger.warn( - f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" - ) + await self._launch_task(task) for task in await self.get_tasks( statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() ): - if ( - not self.task_is_running(task.id) - and len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=True) + await self._launch_task(task) running_tasks_gauge.set(len(self._running_tasks)) @@ -320,27 +300,27 @@ async def _clean_scheduled_tasks(self) -> None: statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] ): # FAILED and COMPLETE tasks should never be running - assert not self.task_is_running(task.id) + assert task.id not in self._running_tasks if ( self._clock.time_msec() > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS ): await self._store.delete_scheduled_task(task.id) - async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: + async def _launch_task(self, task: ScheduledTask) -> None: """Launch a scheduled task now. Args: task: the task to launch - first_launch: `True` if it's the first time is launched, `False` otherwise """ - assert task.action in self._actions + assert self._run_background_tasks + assert task.action in self._actions function = self._actions[task.action] async def wrapper() -> None: try: - (status, result, error) = await function(task, first_launch) + (status, result, error) = await function(task) except Exception: f = Failure() logger.error( @@ -360,6 +340,20 @@ async def wrapper() -> None: ) self._running_tasks.remove(task.id) + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + if ( + self._clock.time_msec() + > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS + ): + logger.warn( + f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" + ) + + if task.id in self._running_tasks: + return + self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) description = f"{task.id}-{task.action}" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ce96574915fd..dcd01d56885c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
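The test changes from here on are mechanical: the repo-local `simple_async_mock` helper is replaced with the standard library's `unittest.mock.AsyncMock` (Python 3.8+), whose call result is an awaitable resolving to `return_value`. A short self-contained illustration of the replacement pattern:

```python
# AsyncMock returns a coroutine; awaiting it yields the configured value.
import asyncio
from unittest.mock import AsyncMock

store = AsyncMock()
store.get_user_by_access_token = AsyncMock(return_value=None)


async def check() -> None:
    assert await store.get_user_by_access_token("token") is None
    store.get_user_by_access_token.assert_awaited_once_with("token")


asyncio.run(check())
```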
-from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pymacaroons @@ -35,7 +35,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -60,16 +59,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = simple_async_mock(None) - self.store.is_support_user = simple_async_mock(False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.is_support_user = AsyncMock(return_value=False) def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = simple_async_mock(user_info) - self.store.mark_access_token_as_used = simple_async_mock(None) - self.store.get_user_locked_status = simple_async_mock(False) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -78,7 +77,7 @@ def test_get_user_by_req_user_valid_token(self) -> None: self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -91,7 +90,7 @@ def test_get_user_by_req_user_bad_token(self) -> None: def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -106,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None: token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -125,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "192.168.10.10" @@ -144,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "131.111.8.42" @@ -158,7 +157,7 @@ def 
test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -172,7 +171,7 @@ def test_get_user_by_req_appservice_bad_token(self) -> None: def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -190,8 +189,8 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -210,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -234,10 +233,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a truth-y value - self.store.get_device = simple_async_mock({"hidden": False}) + self.store.get_device = AsyncMock(return_value={"hidden": False}) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -266,10 +265,10 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. 
- self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a falsey value - self.store.get_device = simple_async_mock(None) + self.store.get_device = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -283,8 +282,8 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -292,9 +291,9 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non token_used=True, ) ) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) - self.store.get_user_locked_status = simple_async_mock(False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -304,8 +303,8 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -313,9 +312,9 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: token_used=True, ) ) - self.store.get_user_locked_status = simple_async_mock(False) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = AsyncMock(return_value=False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -324,7 +323,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -342,8 +341,8 @@ def test_get_user_from_macaroon(self) -> None: ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = simple_async_mock({"is_guest": True}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = 
pymacaroons.Macaroon( @@ -373,7 +372,7 @@ def test_blocking_mau(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(lots_of_users) + self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) e = self.get_failure( self.auth_blocking.check_auth_blocking(), ResourceLimitError @@ -383,25 +382,27 @@ def test_blocking_mau(self) -> None: self.assertEqual(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) + self.store.get_monthly_active_count = AsyncMock( + return_value=small_number_of_users + ) self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Support users allowed self.get_success( self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Bots not allowed self.get_failure( self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError, ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) @@ -412,9 +413,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips( self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -443,9 +444,9 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -473,7 +474,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips( def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = simple_async_mock(2) + self.store.get_monthly_active_count = AsyncMock(return_value=2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py new file mode 100644 index 000000000000..8e159029d9b0 --- /dev/null +++ b/tests/api/test_errors.py @@ -0,0 +1,43 @@ 
+# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.errors import LimitExceededError + +from tests import unittest + + +class LimitExceededErrorTestCase(unittest.TestCase): + def test_key_appears_in_context_but_not_error_dict(self) -> None: + err = LimitExceededError("needle") + serialised = json.dumps(err.error_dict(None)) + self.assertIn("needle", err.debug_context) + self.assertNotIn("needle", serialised) + + # Create a sub-class to avoid mutating the class-level property. + class LimitExceededErrorHeaders(LimitExceededError): + include_retry_after_header = True + + def test_limit_exceeded_header(self) -> None: + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) + assert err.headers is not None + self.assertEqual(err.headers.get("Retry-After"), "1") + + def test_limit_exceeded_rounding(self) -> None: + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) + assert err.headers is not None + self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index fa6c1c02ce95..a24638c9eff7 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,5 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService +from synapse.config.ratelimiting import RatelimitSettings from synapse.types import create_requester from tests import unittest @@ -10,8 +11,7 @@ def test_allowed_via_can_do_action(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -43,8 +43,11 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> Non limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -76,8 +79,11 @@ def test_allowed_appservice_via_can_requester_do_action(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -101,8 +107,7 @@ def test_allowed_via_ratelimit(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), 
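The two new header tests pin down the rounding rule: the JSON body keeps the millisecond value, while the `Retry-After` header is whole seconds rounded up, so a compliant client never retries before the limiter would allow it. A minimal sketch of that conversion, checked against the assertions above:

```python
import math

def retry_after_header(retry_after_ms: int) -> str:
    # Round up to whole seconds: never tell the client to retry early.
    return str(math.ceil(retry_after_ms / 1000))

assert retry_after_header(100) == "1"    # matches test_limit_exceeded_header
assert retry_after_header(3001) == "4"   # matches test_limit_exceeded_rounding
```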
) # Shouldn't raise @@ -128,8 +133,7 @@ def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -177,8 +181,7 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -208,8 +211,7 @@ def test_pruning(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -244,7 +246,11 @@ def test_db_user_override(self) -> None: ) ) - limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=store, + clock=self.clock, + cfg=RatelimitSettings("", per_second=0.1, burst_count=1), + ) # Shouldn't raise for _ in range(20): @@ -254,8 +260,11 @@ def test_multiple_actions(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=3, + ), ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -321,8 +330,7 @@ def test_rate_limit_burst_only_given_once(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings("", per_second=0.1, burst_count=3), ) def consume_at(time: float) -> bool: @@ -346,8 +354,11 @@ def test_record_action_which_doesnt_fill_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe two actions, leaving room in the bucket for one more. @@ -369,8 +380,11 @@ def test_record_action_which_fills_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe three actions, filling up the bucket. @@ -398,8 +412,11 @@ def test_record_action_which_overfills_bucket(self) -> None: limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe four actions, exceeding the bucket. diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 3c635e3dcbdb..75fb5fae6b92 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -96,7 +96,7 @@ async def get_json( ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -168,7 +168,7 @@ async def get_json( ) # We assign to a method, which mypy doesn't like. 
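The `ignore[assignment]` to `ignore[method-assign]` churn across these files tracks a mypy change: newer mypy releases (the 1.5 era, if memory serves) report assignments to methods under a dedicated `method-assign` error code rather than the general `assignment` one, so the old suppressions became stale. In miniature:

```python
from unittest.mock import Mock

class Api:
    def get_json(self, url: str) -> dict:
        return {}

api = Api()
# Older mypy reported this as [assignment]; newer releases use the more
# specific [method-assign] code, so the ignore comment must be renamed:
api.get_json = Mock(return_value={})  # type: ignore[method-assign]
```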
- self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -215,7 +215,7 @@ async def post_json_get_json( return RESPONSE # We assign to a method, which mypy doesn't like. - self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[method-assign] MISSING_KEYS = [ # Known user, known device, missing algorithm. diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 66753c60c4b1..6ac5fc1ae7c4 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -13,14 +13,13 @@ # limitations under the License. import re from typing import Any, Generator -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from synapse.appservice import ApplicationService, Namespace from tests import unittest -from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: @@ -43,8 +42,8 @@ def setUp(self) -> None: ) self.store = Mock() - self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) @defer.inlineCallbacks def test_regex_user_id_prefix_match( @@ -127,10 +126,10 @@ def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, Non self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#irc_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -182,10 +181,10 @@ def test_regex_alias_no_match( self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertFalse( ( yield defer.ensureDeferred( @@ -205,8 +204,10 @@ def test_regex_multiple_matches( ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_barfoo:matrix.org"] + ) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( yield self.service.is_interested_in_event( @@ -235,10 +236,10 @@ def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, No def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: 
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_local_users_in_room = simple_async_mock( - ["@alice:here", "@irc_fo:here", "@bob:here"] + self.store.get_local_users_in_room = AsyncMock( + return_value=["@alice:here", "@irc_fo:here", "@bob:here"] ) - self.store.get_aliases_for_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e2a3bad065da..445919417e63 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Sequence, Tuple, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias @@ -37,7 +37,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -62,10 +61,12 @@ def test_single_service_up_txn_sent(self) -> None: txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(True) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=True) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -89,10 +90,10 @@ def test_single_service_down(self) -> None: events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = simple_async_mock( - ApplicationServiceState.DOWN + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -118,10 +119,12 @@ def test_single_service_up_txn_not_sent(self) -> None: txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - self.store.set_appservice_state = simple_async_mock(True) - txn.send = simple_async_mock(False) # fails to send - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + self.store.set_appservice_state = AsyncMock(return_value=True) + txn.send = AsyncMock(return_value=False) # fails to send + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -150,7 +153,7 @@ def setUp(self) -> None: self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = simple_async_mock() + self.callback = AsyncMock() self.recoverer = _Recoverer( clock=cast(Clock, self.clock), as_api=self.as_api, @@ -174,8 +177,8 @@ def take_txn( self.recoverer.recover() # shouldn't have called anything prior to 
waiting for exp backoff self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=None) # wait for exp backoff self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) @@ -202,8 +205,8 @@ def take_txn( self.recoverer.recover() self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(False) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=False) + txn.complete = AsyncMock(return_value=None) self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) @@ -216,7 +219,7 @@ def take_txn( self.assertEqual(3, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - txn.send = simple_async_mock(True) # successfully send the txn + txn.send = AsyncMock(return_value=True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEqual(1, txn.send.call_count) # new mock reset call count @@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() - self.txn_ctrl.send = simple_async_mock() + self.txn_ctrl.send = AsyncMock() # Replace instantiated _TransactionController instances with our Mock self.scheduler.txn_ctrl = self.txn_ctrl diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index f12147eaa000..0c27dd21e2b8 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -12,11 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import RatelimitSettings from tests.unittest import TestCase from tests.utils import default_config +class ParseRatelimitSettingsTestcase(TestCase): + def test_depth_1(self) -> None: + cfg = { + "a": { + "per_second": 5, + "burst_count": 10, + } + } + parsed = RatelimitSettings.parse(cfg, "a") + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + def test_depth_2(self) -> None: + cfg = { + "a": { + "b": { + "per_second": 5, + "burst_count": 10, + }, + } + } + parsed = RatelimitSettings.parse(cfg, "a.b") + self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10)) + + def test_missing(self) -> None: + parsed = RatelimitSettings.parse( + {}, "a", defaults={"per_second": 5, "burst_count": 10} + ) + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + class RatelimitConfigTestCase(TestCase): def test_parse_rc_federation(self) -> None: config_dict = default_config("test") diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2be341ac7b84..f93ba5d4cf0c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -13,7 +13,7 @@ # limitations under the License. 
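The new config tests define `RatelimitSettings.parse` by example: the key is a dotted path into the config dict, and `defaults` fills in when the path is absent. A sketch of the lookup those three tests imply (the real implementation lives in `synapse.config.ratelimiting`; this merely restates the contract):

```python
from typing import Any, Dict, Optional, Tuple

def parse_ratelimit(
    cfg: Dict[str, Any], key: str, defaults: Optional[Dict[str, float]] = None
) -> Tuple[str, float, float]:
    node: Any = cfg
    for part in key.split("."):          # "a.b" walks cfg["a"]["b"]
        node = node.get(part, {}) if isinstance(node, dict) else {}
    if not isinstance(node, dict):
        node = {}
    merged = {**(defaults or {}), **node}
    return (key, merged["per_second"], merged["burst_count"])

assert parse_ratelimit(
    {"a": {"b": {"per_second": 5, "burst_count": 10}}}, "a.b"
) == ("a.b", 5, 10)
assert parse_ratelimit(
    {}, "a", defaults={"per_second": 5, "burst_count": 10}
) == ("a", 5, 10)
```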
import time from typing import Any, Dict, List, Optional, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr import canonicaljson @@ -45,7 +45,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import logcontext_clean, override_config @@ -291,7 +290,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: with a null `ts_valid_until_ms` """ mock_fetcher = Mock() - mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) + mock_fetcher.get_keys = AsyncMock(return_value={}) key1 = signedjson.key.generate_signing_key("1") r = self.hs.get_datastores().main.store_server_signature_keys( diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 6fb1f1bd6e31..0fcfe38efada 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Iterable, List, Optional, Set, Tuple, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -30,7 +30,6 @@ from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config -from tests.test_utils import simple_async_mock from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) hs = self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 129d7cfd93f5..73a2766bafcb 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
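A recurring shape in the federation-facing tests: the transport client is built with `Mock(spec=[...])` to restrict its attribute surface, and each async method is then replaced with an explicit `AsyncMock`. The spec list alone does not make attributes awaitable, hence the two-step setup:

```python
import asyncio
from unittest.mock import AsyncMock, Mock

# spec=[...] limits the mock to the named attributes (anything else raises
# AttributeError), but the child mocks it creates are not awaitable, so
# each async method still needs an explicit AsyncMock:
transport = Mock(spec=["send_transaction"])
transport.send_transaction = AsyncMock(return_value={})

async def main() -> None:
    assert await transport.send_transaction({"edus": []}) == {}
    transport.send_transaction.assert_awaited_once()

asyncio.run(main())
```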
-from unittest.mock import Mock +from unittest.mock import AsyncMock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin @@ -20,7 +20,6 @@ from synapse.types import JsonDict, UserID, create_requester from tests import unittest -from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -58,7 +57,7 @@ def test_complexity_simple(self) -> None: async def get_current_state_event_counts(room_id: str) -> int: return int(500 * 1.23) - store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( @@ -75,9 +74,9 @@ def test_join_too_large(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -106,9 +105,9 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -143,16 +142,16 @@ def test_join_too_large_once_joined(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) # Artificially raise the complexity async def get_current_state_event_counts(room_id: str) -> int: return 600 - self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] d = handler._remote_join( create_requester(u1), @@ -200,9 +199,9 @@ def test_join_too_large_no_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] 
- return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -230,9 +229,9 @@ def test_join_too_large_admin(self) -> None: fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index b290b020a274..75ae740b435d 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,6 +1,6 @@ from typing import Callable, Collection, List, Optional, Tuple from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -19,7 +19,7 @@ from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection from tests.unittest import FederatingHomeserverTestCase @@ -50,8 +50,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # This mock is crucial for destination_rooms to be populated. # TODO: this seems to no longer be the case---tests pass with this mock # commented out. - state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] - return_value=make_awaitable({"test", "host2"}) + state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] + return_value={"test", "host2"} ) # whenever send_transaction is called, record the pdu data @@ -436,7 +436,7 @@ def test_catch_up_on_synapse_startup(self) -> None: def wake_destination_track(destination: str) -> None: woken.add(destination) - self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] + self.federation_sender.wake_destination = wake_destination_track # type: ignore[method-assign] # We wait quite long so that all dests can be woken up, since there is a delay # between them. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 9e104fd96aeb..7bd3d06859f6 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
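A detail worth calling out in these federation-sender conversions: `Mock(return_value=make_awaitable(x))` handed back one pre-built awaitable for every call, which is why the helper existed at all (a plain coroutine can only be awaited once). `AsyncMock` sidesteps the problem by minting a fresh coroutine per call:

```python
import asyncio
from unittest.mock import AsyncMock

async def main() -> None:
    get_hosts = AsyncMock(return_value={"test", "host2"})
    # Each call produces a new coroutine, so repeated awaits are safe:
    assert await get_hosts("!room:test") == {"test", "host2"}
    assert await get_hosts("!room:test") == {"test", "host2"}

asyncio.run(main())
```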
from typing import Callable, FrozenSet, List, Optional, Set -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -29,7 +29,6 @@ from synapse.types import JsonDict, ReadReceipt from synapse.util import Clock -from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase @@ -43,15 +42,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.federation_transport_client = Mock(spec=["send_transaction"]) + self.federation_transport_client.send_transaction = AsyncMock() hs = self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] - return_value=make_awaitable({"test", "host2"}) + hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] + return_value={"test", "host2"} ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign] hs.get_storage_controllers().state.get_current_hosts_in_room ) @@ -64,7 +64,7 @@ def default_config(self) -> JsonDict: def test_send_receipts(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -104,7 +104,7 @@ def test_send_receipts(self) -> None: def test_send_receipts_thread(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} # Create receipts for: # @@ -180,7 +180,7 @@ def test_send_receipts_with_backoff(self) -> None: """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -276,6 +276,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.federation_transport_client = Mock( spec=["send_transaction", "query_user_devices"] ) + self.federation_transport_client.send_transaction = AsyncMock() + self.federation_transport_client.query_user_devices = AsyncMock() return self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) @@ -317,13 +319,13 @@ async def get_current_hosts_in_room(room_id: str) -> Set[str]: self.record_transaction ) - def record_transaction( + async def record_transaction( self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None - ) -> "defer.Deferred[JsonDict]": + ) -> JsonDict: assert json_cb is not None data = json_cb() self.edus.extend(data["edus"]) - return defer.succeed({}) + return {} def test_send_device_updates(self) -> None: """Basic case: each device update should result in an EDU""" @@ -354,15 +356,11 @@ def test_dont_send_device_updates_for_remote_users(self) -> None: # Send the server a device list EDU for the other user, this will cause # it to try and 
resync the device lists. - self.federation_transport_client.query_user_devices.return_value = ( - make_awaitable( - { - "stream_id": "1", - "user_id": "@user2:host2", - "devices": [{"device_id": "D1"}], - } - ) - ) + self.federation_transport_client.query_user_devices.return_value = { + "stream_id": "1", + "user_id": "@user2:host2", + "devices": [{"device_id": "D1"}], + } self.get_success( self.device_handler.device_list_updater.incoming_device_list_update( @@ -533,7 +531,7 @@ def test_unreachable_server(self) -> None: recovery """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") # create devices u1 = self.register_user("user", "pass") @@ -578,7 +576,7 @@ def test_prune_outbound_device_pokes1(self) -> None: This case tests the behaviour when the server has never been reachable. """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") # create devices u1 = self.register_user("user", "pass") @@ -636,7 +634,7 @@ def test_prune_outbound_device_pokes2(self) -> None: # now the server goes offline mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 70209ab09011..3f42f79f26db 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -218,7 +218,7 @@ async def approve_all_signature_checking( ) -> EventBase: return pdu - homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment] + homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[method-assign] approve_all_signature_checking ) @@ -229,7 +229,7 @@ async def _check_event_auth( ) -> None: pass - homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9014e60577c7..46d022092e82 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,7 +13,7 @@ # limitations under the License. 
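The offline-server tests get simpler too: instead of a lambda returning `defer.fail(...)`, the transport mock's `side_effect` is set to an exception instance, which `AsyncMock` raises when awaited:

```python
import asyncio
from unittest.mock import AsyncMock

async def main() -> None:
    send_transaction = AsyncMock(side_effect=AssertionError("fail"))
    try:
        await send_transaction("txn")
    except AssertionError as e:
        assert str(e) == "fail"   # raised at await time, no Deferred plumbing

asyncio.run(main())
```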
from typing import Dict, Iterable, List, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -36,7 +36,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection, make_awaitable, simple_async_mock +from tests.test_utils import event_injection from tests.unittest import override_config from tests.utils import MockClock @@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): def setUp(self) -> None: self.mock_store = Mock() - self.mock_as_api = Mock() + self.mock_as_api = AsyncMock() self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) - self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) - self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( - None - ) + self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) + self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) + self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -69,21 +67,25 @@ def test_notify_interested_services(self) -> None: self._mkservice(is_interested_in_event=False), ] - self.mock_as_api.query_user.return_value = make_awaitable(True) + self.mock_as_api.query_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable([]) + self.mock_store.get_user_by_id = AsyncMock(return_value=[]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, {})), - make_awaitable((1, {event.event_id: 0})), - ] - self.mock_store.get_events_as_list.side_effect = [ - make_awaitable([]), - make_awaitable([event]), - ] + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, {}), + (1, {event.event_id: 0}), + ] + ) + self.mock_store.get_events_as_list = AsyncMock( + side_effect=[ + [], + [event], + ] + ) self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -95,14 +97,16 @@ def test_query_user_exists_unknown_user(self) -> None: services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable(None) + self.mock_store.get_user_by_id = AsyncMock(return_value=None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, {event.event_id: 0})), - ] - self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] + self.mock_as_api.query_user.return_value = True + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, {event.event_id: 0}), + ] + ) + self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]]) self.handler.notify_interested_services(RoomStreamToken(None, 0)) 
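Where a store method must return different values on successive calls, the list form of `side_effect` carries over directly from the old `[make_awaitable(a), make_awaitable(b)]` pattern: each awaited call consumes the next item in order.

```python
import asyncio
from unittest.mock import AsyncMock

async def main() -> None:
    get_stream = AsyncMock(side_effect=[(0, {}), (1, {"$event1": 0})])
    assert await get_stream() == (0, {})              # first call
    assert await get_stream() == (1, {"$event1": 0})  # second call

asyncio.run(main())
```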
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -112,13 +116,15 @@ def test_query_user_exists_known_user(self) -> None: services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) + self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, [event], {event.event_id: 0})), - ] + self.mock_as_api.query_user.return_value = True + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, [event], {event.event_id: 0}), + ] + ) self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -141,10 +147,10 @@ def test_query_room_alias_exists(self) -> None: self._mkservice_alias(is_room_alias_in_namespace=False), ] - self.mock_as_api.query_alias.return_value = make_awaitable(True) + self.mock_as_api.query_alias = AsyncMock(return_value=True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = make_awaitable( - Mock(room_id=room_id, servers=servers) + self.mock_store.get_association_from_room_alias = AsyncMock( + return_value=Mock(room_id=room_id, servers=servers) ) result = self.successResultOf( @@ -177,7 +183,7 @@ def test_get_3pe_protocols_no_protocols(self) -> None: def test_get_3pe_protocols_protocol_no_response(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + self.mock_as_api.get_3pe_protocol.return_value = None response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -189,9 +195,10 @@ def test_get_3pe_protocols_protocol_no_response(self) -> None: def test_get_3pe_protocols_select_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) @@ -205,9 +212,10 @@ def test_get_3pe_protocols_select_one_protocol(self) -> None: def test_get_3pe_protocols_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -222,9 +230,10 @@ def test_get_3pe_protocols_multiple_protocol(self) -> None: service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["other-protocol"]) self.mock_store.get_app_services.return_value = [service_one, service_two] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + 
self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -287,13 +296,11 @@ def test_notify_interested_services_ephemeral(self) -> None: interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( - 579 - ) + self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as.return_value = ( - make_awaitable(([event], None)) + self.event_source.sources.receipt.get_new_events_as = AsyncMock( + return_value=([event], None) ) self.handler.notify_interested_services_ephemeral( @@ -317,13 +324,11 @@ def test_notify_interested_services_ephemeral_out_of_order(self) -> None: services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( - 580 - ) + self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as.return_value = ( - make_awaitable(([event], None)) + self.event_source.sources.receipt.get_new_events_as = AsyncMock( + return_value=([event], None) ) self.handler.notify_interested_services_ephemeral( @@ -350,9 +355,7 @@ def _mkservice( A mock representing the ApplicationService. """ service = Mock() - service.is_interested_in_event.return_value = make_awaitable( - is_interested_in_event - ) + service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols @@ -396,12 +399,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events - self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] + self.send_mock = AsyncMock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -894,12 +897,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # will be sent over the wire - self.put_json = simple_async_mock() - hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + self.put_json = AsyncMock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -1000,8 +1003,8 
@@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track what's going out - self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + self.send_mock = AsyncMock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method. # Define an application service for the tests self._service_token = "VERYSECRET" diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 036dbbc45ba5..413ff8795bef 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock import pymacaroons @@ -25,7 +25,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): @@ -166,8 +165,8 @@ def test_mau_limits_disabled(self) -> None: def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.large_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.large_number_of_users ) self.get_failure( @@ -177,8 +176,8 @@ def test_mau_limits_exceeded_large(self) -> None: ResourceLimitError, ) - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.large_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.large_number_of_users ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) @@ -191,8 +190,8 @@ def test_mau_limits_parity(self) -> None: self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. 
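One subtlety behind these conversions: a bare `AsyncMock()` does not resolve to `None` — awaiting it yields an auto-created mock, which is truthy. That is why calls that previously relied on the helper's implicit `None` (`user_last_seen_monthly_active`, `is_trial_user` earlier in this diff) now spell out `return_value=None` or `return_value=False` wherever the code under test branches on the result:

```python
import asyncio
from unittest.mock import AsyncMock

async def main() -> None:
    implicit = AsyncMock()                   # awaits to an auto-created mock
    explicit = AsyncMock(return_value=None)  # awaits to None
    assert (await implicit()) is not None    # the default result is truthy!
    assert (await explicit()) is None

asyncio.run(main())
```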
- self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.auth_blocking._max_mau_value ) # If not in monthly active cohort @@ -208,8 +207,8 @@ def test_mau_limits_parity(self) -> None: self.assertIsNone(self.token_login(token)) # If in monthly active cohort - self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.clock.time_msec()) + self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock( + return_value=self.clock.time_msec() ) self.get_success( self.auth_handler.create_access_token_for_user_id( @@ -224,8 +223,8 @@ def test_mau_limits_parity(self) -> None: def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.small_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.small_number_of_users ) # Ensure does not raise exception self.get_success( @@ -234,8 +233,8 @@ def test_mau_limits_not_exceeded(self) -> None: ) ) - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.small_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.small_number_of_users ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 63aad0d10c2f..8582b1cd1e9e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -61,7 +60,7 @@ def test_map_cas_user_to_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -89,7 +88,7 @@ def test_map_cas_user_to_existing_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. 
cas_response = CasResponse("test_user", {}) @@ -129,7 +128,7 @@ def test_map_cas_user_to_invalid_localpart(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -160,7 +159,7 @@ def test_required_attributes(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index e1e58fa6e648..55a4f95ef32b 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -32,7 +32,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config user1 = "@boris:aaa" @@ -41,7 +40,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.Mock() + self.appservice_api = mock.AsyncMock() hs = self.setup_test_homeserver( "server", application_service_api=self.appservice_api, @@ -123,50 +122,50 @@ def test_get_devices_by_user(self) -> None: self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }, - device_map["xyz"], + }.items(), + device_map["xyz"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }, - device_map["fco"], + }.items(), + device_map["fco"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, - device_map["abc"], + }.items(), + device_map["abc"].items(), ) def test_get_device(self) -> None: self._record_users() res = self.get_success(self.handler.get_device(user1, "abc")) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, - res, + }.items(), + res.items(), ) def test_delete_device(self) -> None: @@ -375,13 +374,11 @@ def test_on_federation_query_user_devices_appservice(self) -> None: ) # Setup a response. - self.appservice_api.query_keys.return_value = make_awaitable( - { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} - } + self.appservice_api.query_keys.return_value = { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} } - ) + } # Request all devices. res = self.get_success( diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 90aec484c48c..367d94eca3dd 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
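The `test_device.py` hunk swaps the long-deprecated `assertDictContainsSubset` for `assertLessEqual` over `dict.items()`: item views compare like sets of key/value pairs, so `a.items() <= b.items()` is exactly the subset test the old assertion performed.

```python
expected = {"user_id": "@boris:aaa", "device_id": "xyz"}
actual = {"user_id": "@boris:aaa", "device_id": "xyz", "last_seen_ip": None}

# Every (key, value) pair of `expected` must appear in `actual`:
assert expected.items() <= actual.items()
assert not ({"device_id": "abc"}.items() <= actual.items())  # value differs
```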
from typing import Any, Awaitable, Callable, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -27,14 +27,13 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = Mock() + self.mock_federation = AsyncMock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -73,9 +72,10 @@ def test_get_local_association(self) -> None: self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self) -> None: - self.mock_federation.make_query.return_value = make_awaitable( - {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} - ) + self.mock_federation.make_query.return_value = { + "room_id": "!8765qwer:test", + "servers": ["test", "remote"], + } result = self.get_success(self.handler.get_association(self.remote_room)) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 2eaffe511ee4..c5556f284491 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from typing import Dict, Iterable from unittest import mock from parameterized import parameterized @@ -31,13 +31,12 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.Mock() + self.appservice_api = mock.AsyncMock() return self.setup_test_homeserver( federation_client=mock.Mock(), application_service_api=self.appservice_api ) @@ -801,29 +800,27 @@ def test_query_devices_remote_no_sync(self) -> None: remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] - return_value=make_awaitable( - { - "device_keys": {remote_user_id: {}}, - "master_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - }, - "self_signing_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, - } + self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "device_keys": {remote_user_id: {}}, + "master_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - } - ) + }, + "self_signing_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + } + }, + } ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -874,34 +871,29 @@ def test_query_devices_remote_sync(self) -> None: # 
Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. - self.store.get_rooms_for_user = mock.Mock( - return_value=make_awaitable({"some_room_id"}) - ) + self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"}) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] - return_value=make_awaitable( - { + self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key }, - } - ) + }, + } ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -987,20 +979,20 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> Non mock_get_rooms = mock.patch.object( self.store, "get_rooms_for_user", - new_callable=mock.MagicMock, - return_value=make_awaitable(["some_room_id"]), + new_callable=mock.AsyncMock, + return_value=["some_room_id"], ) mock_get_users = mock.patch.object( self.store, "get_users_server_still_shares_room_with", - new_callable=mock.MagicMock, - return_value=make_awaitable({remote_user_id}), + new_callable=mock.AsyncMock, + return_value={remote_user_id}, ) mock_request = mock.patch.object( self.hs.get_federation_client(), "query_user_devices", - new_callable=mock.MagicMock, - return_value=make_awaitable(response_body), + new_callable=mock.AsyncMock, + return_value=response_body, ) with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: @@ -1060,8 +1052,9 @@ def test_query_appservice(self) -> None: ) # Setup a response, but only for device 2. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_2: otk}}, + [(local_user, device_id_1, "alg1", 1)], ) # we shouldn't have any unused fallback keys yet @@ -1127,9 +1120,10 @@ def test_query_appservice_with_fallback(self) -> None: ) # Setup a response. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) - ) + response: Dict[str, Dict[str, Dict[str, JsonDict]]] = { + local_user: {device_id_1: {**as_otk, **as_fallback_key}} + } + self.appservice_api.claim_client_keys.return_value = (response, []) # Claim OTKs, which will ask the appservice and do nothing else. claim_res = self.get_success( @@ -1171,8 +1165,9 @@ def test_query_appservice_with_fallback(self) -> None: self.assertEqual(fallback_res, ["alg1"]) # The appservice will return only the OTK. 
- self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: as_otk}}, []) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_1: as_otk}}, + [], ) # Claim OTKs, which should return the OTK from the appservice and the @@ -1234,8 +1229,9 @@ def test_query_appservice_with_fallback(self) -> None: self.assertEqual(fallback_res, ["alg1"]) # Finally, return only the fallback key from the appservice. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: as_fallback_key}}, []) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_1: as_fallback_key}}, + [], ) # Claim OTKs, which will return only the fallback key from the database. @@ -1350,13 +1346,11 @@ def test_query_local_devices_appservice(self) -> None: ) # Setup a response. - self.appservice_api.query_keys.return_value = make_awaitable( - { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} - } + self.appservice_api.query_keys.return_value = { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} } - ) + } # Request all devices. res = self.get_success(self.handler.query_local_devices({local_user: None})) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 5f11d5df11ad..21d63ab1f297 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -14,7 +14,7 @@ import logging from typing import Collection, Optional, cast from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor @@ -40,7 +40,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection logger = logging.getLogger(__name__) @@ -370,15 +370,15 @@ def test_backfill_ignores_known_events(self) -> None: # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. - federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) - self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] + federation_client_backfill_mock = AsyncMock(return_value=[event]) + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. 
persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign] persist_events_and_notify_mock ) @@ -631,33 +631,29 @@ def test_failed_partial_join_is_clean(self) -> None: }, RoomVersions.V10, ) - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - "example.com", - membership_event, - RoomVersions.V10, - ) + mock_make_membership_event = AsyncMock( + return_value=( + "example.com", + membership_event, + RoomVersions.V10, ) ) - mock_send_join = Mock( - return_value=make_awaitable( - SendJoinResult( - membership_event, - "example.com", - state=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - auth_chain=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - partial_state=True, - servers_in_room={"example.com"}, - ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + membership_event, + "example.com", + state=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + auth_chain=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + partial_state=True, + servers_in_room={"example.com"}, ) ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 23f1b33b2fda..70e6a7e142f1 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -35,7 +35,7 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): @@ -50,6 +50,10 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) + self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock() + self.mock_federation_transport_client.get_room_state = mock.AsyncMock() + self.mock_federation_transport_client.get_event = mock.AsyncMock() + self.mock_federation_transport_client.backfill = mock.AsyncMock() return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client ) @@ -198,20 +202,14 @@ async def get_event( ) # we expect an outbound request to /state_ids, so stub that out - self.mock_federation_transport_client.get_room_state_ids.return_value = ( - make_awaitable( - { - "pdu_ids": [e.event_id for e in state_at_prev_event], - "auth_chain_ids": [], - } - ) - ) + self.mock_federation_transport_client.get_room_state_ids.return_value = { + "pdu_ids": [e.event_id for e in state_at_prev_event], + "auth_chain_ids": [], + } # we also expect an outbound request to /state self.mock_federation_transport_client.get_room_state.return_value = ( - make_awaitable( - StateRequestResponse(auth_events=[], state=state_at_prev_event) - ) + StateRequestResponse(auth_events=[], state=state_at_prev_event) ) # we have to bump the clock a bit, to keep the retry logic in @@ -273,26 +271,23 @@ def test_process_pulled_event_records_failed_backfill_attempts( room_version = self.get_success(main_store.get_room_version(room_id)) # We expect an outbound request to /state_ids, so stub that out - 
self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( - { - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - "pdu_ids": [], - "auth_chain_ids": [], - } - ) + self.mock_federation_transport_client.get_room_state_ids.return_value = { + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + "pdu_ids": [], + "auth_chain_ids": [], + } + # We also expect an outbound request to /state - self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( - StateRequestResponse( - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - auth_events=[], - state=[], - ) + self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], ) pulled_event = make_event_from_dict( @@ -545,25 +540,23 @@ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( ) # We expect an outbound request to /backfill, so stub that out - self.mock_federation_transport_client.backfill.return_value = make_awaitable( - { - "origin": self.OTHER_SERVER_NAME, - "origin_server_ts": 123, - "pdus": [ - # This is one of the important aspects of this test: we include - # `pulled_event_without_signatures` so it fails the signature check - # when we filter down the backfill response down to events which - # have valid signatures in - # `_check_sigs_and_hash_for_pulled_events_and_fetch` - pulled_event_without_signatures.get_pdu_json(), - # Then later when we process this valid signature event, when we - # fetch the missing `prev_event`s, we want to make sure that we - # backoff and don't try and fetch `pulled_event_without_signatures` - # again since we know it just had an invalid signature. - pulled_event.get_pdu_json(), - ], - } - ) + self.mock_federation_transport_client.backfill.return_value = { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. 
+ pulled_event.get_pdu_json(), + ], + } # Keep track of the count and make sure we don't make any of these requests event_endpoint_requested_count = 0 @@ -731,15 +724,13 @@ def test_backfill_process_previously_failed_pull_attempt_event_in_the_background ) # We expect an outbound request to /backfill, so stub that out - self.mock_federation_transport_client.backfill.return_value = make_awaitable( - { - "origin": self.OTHER_SERVER_NAME, - "origin_server_ts": 123, - "pdus": [ - pulled_event.get_pdu_json(), - ], - } - ) + self.mock_federation_transport_client.backfill.return_value = { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + pulled_event.get_pdu_json(), + ], + } # The function under test: try to backfill and process the pulled event with LoggingContext("test"): diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9691d66b48a0..1c5897c84e49 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,18 +46,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._persist_event_storage_controller = persistence self.user_id = self.register_user("tester", "foobar") - self.access_token = self.login("tester", "foobar") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - - info = self.get_success( - self.hs.get_datastores().main.get_user_by_access_token( - self.access_token, - ) - ) - assert info is not None - self.token_id = info.token_id + device_id = "dev-1" + access_token = self.login("tester", "foobar", device_id=device_id) + self.room_id = self.helper.create_room_as(self.user_id, tok=access_token) - self.requester = create_requester(self.user_id, access_token_id=self.token_id) + self.requester = create_requester(self.user_id, device_id=device_id) def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: # Create a member event we can use as an auth_event diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 3baeb28e620f..534cae7f893f 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -39,7 +39,7 @@ from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import mock_getRawHeaders @@ -148,7 +148,7 @@ def _assertParams(self) -> None: def test_inactive_token(self) -> None: """The handler should return a 403 where the token is inactive.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": False}, @@ -167,7 +167,7 @@ def test_inactive_token(self) -> None: def test_active_no_scope(self) -> None: """The handler should return a 403 where no scope is given.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": True}, @@ -186,7 +186,7 @@ def test_active_no_scope(self) -> None: def test_active_user_no_subject(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, @@ -205,7 +205,7 @@ def 
test_active_user_no_subject(self) -> None: def test_active_no_user_scope(self) -> None: """The handler should return a 500 when no subject is present.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -228,7 +228,7 @@ def test_active_no_user_scope(self) -> None: def test_active_admin_not_user(self) -> None: """The handler should raise when the scope has admin right but not user.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -252,7 +252,7 @@ def test_active_admin_not_user(self) -> None: def test_active_admin(self) -> None: """The handler should return a requester with admin rights.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -282,7 +282,7 @@ def test_active_admin(self) -> None: def test_active_admin_highest_privilege(self) -> None: """The handler should resolve to the most permissive scope.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -314,7 +314,7 @@ def test_active_admin_highest_privilege(self) -> None: def test_active_user(self) -> None: """The handler should return a requester with normal user rights.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -345,7 +345,7 @@ def test_active_user_admin_impersonation(self) -> None: """The handler should return a requester with normal user rights and an user ID matching the one specified in query param `user_id`""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -379,7 +379,7 @@ def test_active_user_admin_impersonation(self) -> None: def test_active_user_with_device(self) -> None: """The handler should return a requester with normal user rights and a device ID.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -409,7 +409,7 @@ def test_active_user_with_device(self) -> None: def test_multiple_devices(self) -> None: """The handler should raise an error if multiple devices are found in the scope.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -434,7 +434,7 @@ def test_multiple_devices(self) -> None: def test_active_guest_not_allowed(self) -> None: """The handler should return an insufficient scope error.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -464,7 +464,7 @@ def test_active_guest_not_allowed(self) -> None: def test_active_guest_allowed(self) -> None: """The handler should return a requester with guest user rights and a device ID.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -500,19 +500,19 @@ def test_unavailable_introspection_endpoint(self) -> None: request.requestHeaders.getRawHeaders = mock_getRawHeaders() # The introspection endpoint is returning an error. 
- self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse(code=500, body=b"Internal Server Error") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint request fails. - self.http_client.request = simple_async_mock(raises=Exception()) + self.http_client.request = AsyncMock(side_effect=Exception()) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) # The introspection endpoint does not return a JSON object. - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload=["this is an array", "not an object"] ) @@ -521,7 +521,7 @@ def test_unavailable_introspection_endpoint(self) -> None: self.assertEqual(error.value.code, 503) # The introspection endpoint does not return valid JSON. - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse(code=200, body=b"this is not valid JSON") ) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) @@ -529,7 +529,7 @@ def test_unavailable_introspection_endpoint(self) -> None: def test_introspection_token_cache(self) -> None: access_token = "open_sesame" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={"active": "true", "scope": "guest", "jti": access_token}, @@ -560,7 +560,7 @@ def test_introspection_token_cache(self) -> None: # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a # token with a soon-to-expire `exp` field to the cache - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ @@ -641,7 +641,7 @@ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: def test_cross_signing(self) -> None: """Try uploading device keys with OAuth delegation enabled.""" - self.http_client.request = simple_async_mock( + self.http_client.request = AsyncMock( return_value=FakeResponse.json( code=200, payload={ diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 0a8bae54fbea..e797aaae00dd 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons @@ -28,7 +28,7 @@ from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config @@ -157,15 +157,15 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: sso_handler = hs.get_sso_handler() # Mock the render error method. 
self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error # type: ignore[assignment] + sso_handler.render_error = self.render_error # type: ignore[method-assign] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 auth_handler = hs.get_auth_handler() # Mock the complete SSO login method. - self.complete_sso_login = simple_async_mock() - auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + self.complete_sso_login = AsyncMock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[method-assign] return hs diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 394006f5f314..11ec8c7f116f 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -16,7 +16,7 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Type, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -32,7 +32,6 @@ from tests import unittest from tests.server import FakeChannel -from tests.test_utils import make_awaitable from tests.unittest import override_config # Login flows we expect to appear in the list after the normal ones. @@ -187,7 +186,7 @@ def password_only_auth_provider_login_test_body(self) -> None: self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -209,13 +208,13 @@ def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" # log in twice, to get two devices - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -229,7 +228,7 @@ def password_only_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -243,7 +242,7 @@ def local_user_fallback_login_test_body(self) -> None: self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) @@ 
-260,7 +259,7 @@ def local_user_fallback_ui_auth_test_body(self) -> None: self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -303,7 +302,7 @@ def no_local_user_fallback_login_test_body(self) -> None: self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -325,7 +324,7 @@ def no_local_user_fallback_ui_auth_test_body(self) -> None: self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -342,7 +341,7 @@ def no_local_user_fallback_ui_auth_test_body(self) -> None: mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -396,9 +395,7 @@ def custom_auth_provider_login_test_body(self) -> None: self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:test", None) - ) + mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:test", channel.json_body["user_id"]) @@ -447,9 +444,7 @@ def custom_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:test", None) - ) + mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -460,8 +455,8 @@ def custom_auth_provider_ui_auth_test_body(self) -> None: mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = make_awaitable( - ("@localuser:test", None) + mock_password_provider.check_auth = AsyncMock( + return_value=("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -478,10 +473,10 @@ def test_custom_auth_provider_callback(self) -> None: self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self) -> None: - callback = Mock(return_value=make_awaitable(None)) + callback = AsyncMock(return_value=None) - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:test", callback) + mock_password_provider.check_auth = AsyncMock( + return_value=("@user:test", 
callback) ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @@ -616,8 +611,8 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = make_awaitable( - ("@localuser:test", None) + mock_password_provider.check_auth = AsyncMock( + return_value=("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @@ -835,11 +830,11 @@ def _test_3pid_allowed(self, username: str, registration: bool) -> None: username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] - return_value=make_awaitable(0), + self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign] + return_value=0 ) - m = Mock(return_value=make_awaitable(False)) + m = AsyncMock(return_value=False) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] self.register_user(username, "password") @@ -869,7 +864,7 @@ def _test_3pid_allowed(self, username: str, registration: bool) -> None: m.assert_called_once_with("email", "foo@test.com", registration) - m = Mock(return_value=make_awaitable(True)) + m = AsyncMock(return_value=True) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 1aebcc16adc5..a987267308ee 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -524,6 +524,7 @@ def default_config(self) -> JsonDict: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = f"@test:{self.hs.config.server.server_name}" + self.device_id = "dev-1" # Move the reactor to the initial time. self.reactor.advance(1000) @@ -608,7 +609,10 @@ def test_restored_presence_online_after_sync( self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2) self.get_success( presence_handler.user_syncing( - self.user_id, sync_state != PresenceState.OFFLINE, sync_state + self.user_id, + self.device_id, + sync_state != PresenceState.OFFLINE, + sync_state, ) ) @@ -632,6 +636,7 @@ def test_restored_presence_online_after_sync( class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): user_id = "@test:server" user_id_obj = UserID.from_string(user_id) + device_id = "dev-1" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() @@ -641,13 +646,20 @@ def test_external_process_timeout(self) -> None: """Test that if an external process doesn't update the records for a while we time out their syncing users presence. """ - process_id = "1" - # Notify handler that a user is now syncing. + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. 
+        worker_to_sync_against = self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+        )
+        worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
         self.get_success(
-            self.presence_handler.update_external_syncs_row(
-                process_id, self.user_id, True, self.clock.time_msec()
-            )
+            worker_presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )

         # Check that if we wait a while without telling the handler the user has
@@ -701,7 +713,7 @@ def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
         # Mark user as offline
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.OFFLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE}
             )
         )
@@ -733,7 +745,7 @@ def test_user_reset_online_with_no_status(self) -> None:
         # Mark user as online again
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.ONLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE}
             )
         )
@@ -762,7 +774,7 @@ def test_set_presence_from_syncing_not_set(self) -> None:
         self.get_success(
             self.presence_handler.user_syncing(
-                self.user_id, False, PresenceState.ONLINE
+                self.user_id, self.device_id, False, PresenceState.ONLINE
             )
         )
@@ -779,7 +791,9 @@ def test_set_presence_from_syncing_is_set(self) -> None:
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)

         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )

         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -793,7 +807,9 @@ def test_set_presence_from_syncing_keeps_status(self) -> None:
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)

         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )

         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -820,7 +836,7 @@ def test_set_presence_from_syncing_keeps_busy(
         # This is used to test that presence changes get replicated from workers
         # to the main process correctly.
         worker_to_sync_against = self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+            "synapse.app.generic_worker", {"worker_name": "synchrotron"}
         )

         # Set presence to BUSY
@@ -831,8 +847,9 @@ def test_set_presence_from_syncing_keeps_busy(
         # /presence/*.
         self.get_success(
             worker_to_sync_against.get_presence_handler().user_syncing(
-                self.user_id, True, PresenceState.ONLINE
-            )
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )

         # Check against the main process that the user's presence did not change.
@@ -840,6 +857,21 @@ def test_set_presence_from_syncing_keeps_busy(
         # we should still be busy
         self.assertEqual(state.state, PresenceState.BUSY)

+        # Advance such that the device would be discarded if it was not busy,
+        # then pump so that the _handle_timeouts function gets called.
+        self.reactor.advance(IDLE_TIMER / 1000)
+        self.reactor.pump([5])
+
+        # The account should still be busy.
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        self.assertEqual(state.state, PresenceState.BUSY)
+
+        # Ensure that a /presence call can set the user *off* busy.
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) + + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) + self.assertEqual(state.state, PresenceState.ONLINE) + def _set_presencestate_with_status_msg( self, state: str, status_msg: Optional[str] ) -> None: @@ -852,6 +884,7 @@ def _set_presencestate_with_status_msg( self.get_success( self.presence_handler.set_state( self.user_id_obj, + self.device_id, {"presence": state, "status_msg": status_msg}, ) ) @@ -1093,7 +1126,9 @@ def test_remote_joins(self) -> None: # Mark test2 as online, test will be offline with a last_active of 0 self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test2:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -1140,7 +1175,9 @@ def test_remote_gets_presence_when_local_user_joins(self) -> None: # Mark test as online self.get_success( self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) @@ -1148,7 +1185,9 @@ def test_remote_gets_presence_when_local_user_joins(self) -> None: # Note we don't join them to the room yet self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test2:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index ec2f5d30bea9..f9b292b9ece1 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -26,7 +26,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class ProfileTestCase(unittest.HomeserverTestCase): @@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = Mock() + self.mock_federation = AsyncMock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -135,9 +134,7 @@ def test_set_my_name_noauth(self) -> None: ) def test_get_other_name(self) -> None: - self.mock_federation.make_query.return_value = make_awaitable( - {"displayname": "Alice"} - ) + self.mock_federation.make_query.return_value = {"displayname": "Alice"} displayname = self.get_success(self.handler.get_displayname(self.alice)) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 54eeec228e20..e9fbf32c7ce9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,7 +13,7 @@ # limitations under the License. 
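The same conversion recurs throughout these test files: a `Mock` whose result had to be pre-wrapped with `make_awaitable` becomes an `AsyncMock` holding the bare value, because `AsyncMock` wraps its `return_value` in a fresh coroutine on every call. A minimal self-contained sketch of the pattern (the `make_query` name mirrors the profile test above; the values are invented):

from unittest.mock import AsyncMock
import asyncio

mock_federation = AsyncMock()
# The bare value is enough; each call returns a new awaitable wrapping it.
mock_federation.make_query.return_value = {"displayname": "Alice"}

async def main() -> None:
    assert await mock_federation.make_query() == {"displayname": "Alice"}
    # A second call awaits cleanly too, with no one-shot Deferred to manage.
    assert await mock_federation.make_query() == {"displayname": "Alice"}

asyncio.run(main())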
from typing import Any, Collection, List, Optional, Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -38,7 +38,6 @@ ) from synapse.util import Clock -from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -203,24 +202,22 @@ def test_mau_limits_when_disabled(self) -> None: @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = Mock( # type: ignore[assignment] - return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) + self.store.count_monthly_users = AsyncMock( # type: ignore[method-assign] + return_value=self.hs.config.server.max_mau_value - 1 ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self) -> None: - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -229,15 +226,13 @@ def test_get_or_create_user_mau_blocked(self) -> None: @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self) -> None: - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError @@ -292,7 +287,7 @@ def test_auto_create_auto_join_where_auto_create_is_false(self) -> None: @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: room_alias_str = "#room:test" - self.store.is_real_user = Mock(return_value=make_awaitable(False)) + self.store.is_real_user = AsyncMock(return_value=False) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -304,8 +299,8 @@ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] - self.store.is_real_user = Mock(return_value=make_awaitable(True)) + self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[method-assign] + self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = 
self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_directory_handler() @@ -319,8 +314,8 @@ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> N def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] - self.store.is_real_user = Mock(return_value=make_awaitable(True)) + self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[method-assign] + self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 41199ffa297f..3e28117e2c0f 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -16,7 +16,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request -from tests.test_utils import make_awaitable from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -154,25 +153,21 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: None, ) - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - self.OTHER_SERVER_NAME, - join_event, - self.hs.config.server.default_room_version, - ) + mock_make_membership_event = AsyncMock( + return_value=( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, ) ) - mock_send_join = Mock( - return_value=make_awaitable( - SendJoinResult( - join_event, - self.OTHER_SERVER_NAME, - state=[create_event], - auth_chain=[create_event], - partial_state=False, - servers_in_room=frozenset(), - ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=frozenset(), ) ) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index b5c772a7aedf..00f4e181e81a 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, Optional, Set, Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -25,7 +25,6 @@ from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. 
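As the remote-join mocks above show, structured results such as tuples or `SendJoinResult` instances move into `AsyncMock(return_value=...)` unchanged, and the await-aware assertion helpers come along for free. A self-contained sketch with invented stand-in values:

from unittest.mock import AsyncMock
import asyncio

# Hypothetical stand-in for mock_make_membership_event above:
# (origin, membership event, room version).
make_membership_event = AsyncMock(
    return_value=("example.com", {"type": "m.room.member"}, "10")
)

async def main() -> None:
    origin, event, room_version = await make_membership_event("!room:test")
    assert origin == "example.com"
    # Unlike a plain Mock, AsyncMock records awaits as well as calls.
    make_membership_event.assert_awaited_once_with("!room:test")

asyncio.run(main())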
@@ -134,7 +133,7 @@ def test_map_saml_response_to_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -164,7 +163,7 @@ def test_map_saml_response_to_existing_user(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -206,11 +205,11 @@ def test_map_saml_response_to_invalid_localpart(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -227,9 +226,9 @@ def test_map_saml_response_to_user_retries(self) -> None: # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -312,7 +311,7 @@ def test_attribute_requirements(self) -> None: # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index 8b6e4a40b620..a066745d70b8 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,19 +13,40 @@ # limitations under the License. 
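On the recurring `# type: ignore[assignment]` to `# type: ignore[method-assign]` churn: recent mypy reports monkey-patching of a bound method under the dedicated `method-assign` error code (split out of `assignment` around mypy 1.0), so these ignores are narrowed to match. Roughly the situation being silenced, using a hypothetical handler class:

from unittest.mock import AsyncMock

class AuthHandler:
    async def complete_sso_login(self) -> None:
        ...

auth_handler = AuthHandler()
# mypy flags overwriting a method on an instance; narrowing the ignore to
# method-assign keeps genuine assignment errors visible elsewhere.
auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]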
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch

 from zope.interface import implementer

 from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
 from twisted.mail import interfaces, smtp

 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase, override_config


+def TestingESMTPTLSClientFactory(
+    contextFactory: ContextFactory,
+    _connectWrapped: bool,
+    wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+    """We use this to pass through in testing without using TLS, while
+    saving the context information to check that it would have happened.
+
+    Note that this is what the MemoryReactor does on connectSSL.
+    It only saves the contextFactory, but starts the connection with the
+    underlying Factory.
+    See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+    wrappedProtocol._testingContextFactory = contextFactory  # type: ignore[attr-defined]
+    return wrappedProtocol
+
+
 @implementer(interfaces.IMessageDelivery)
 class _DummyMessageDelivery:
     def __init__(self) -> None:
@@ -75,7 +96,13 @@ def connectionLost(self) -> None:
         pass


-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+    ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
     def test_send_email(self) -> None:
         """Happy-path test that we can send email to a non-TLS server."""
         h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ def test_send_email(self) -> None:
         (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
             0
         ]
-        self.assertEqual(host, "localhost")
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 25)

         # wire it up to an SMTP server
@@ -105,7 +132,9 @@ def test_send_email(self) -> None:
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
@@ -118,6 +147,10 @@ def test_send_email(self) -> None:
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)

+    @patch(
+        "synapse.handlers.send_email.TLSMemoryBIOFactory",
+        TestingESMTPTLSClientFactory,
+    )
     @override_config(
         {
             "email": {
@@ -135,17 +168,23 @@ def test_send_email_force_tls(self) -> None:
             )
         )
         # there should be an attempt to connect to localhost:465
-        self.assertEqual(len(self.reactor.sslClients), 1)
+        self.assertEqual(len(self.reactor.tcpClients), 1)
         (
             host,
             port,
             client_factory,
-            contextFactory,
             _timeout,
             _bindAddress,
-        ) = self.reactor.sslClients[0]
-        self.assertEqual(host, "localhost")
+        ) = self.reactor.tcpClients[0]
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 465)
+        # We need to make sure that TLS is happening
+        self.assertIsInstance(
+            client_factory._wrappedFactory._testingContextFactory,
+            ClientTLSOptions,
+        )
+        # And since we use endpoints, they go through reactor.connectTCP,
+        # which works differently to connectSSL on the testing reactor.

         # wire it up to an SMTP server
         message_delivery = _DummyMessageDelivery()
@@
-160,7 +199,9 @@ def test_send_email_force_tls(self) -> None: FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class( + "TCP", self.reactor.lookups["localhost"], 1234 + ), ) ) @@ -172,3 +213,11 @@ def test_send_email_force_tls(self) -> None: user, msg = message_delivery.messages.pop() self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + + +class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): + ip_class = IPv6Address + + def setUp(self) -> None: + super().setUp() + self.reactor.lookups["localhost"] = "::1" diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 9f035a02dc69..948d04fc323f 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, Mock, patch from twisted.test.proto_helpers import MemoryReactor @@ -29,7 +29,6 @@ import tests.unittest import tests.utils -from tests.test_utils import make_awaitable class SyncTestCase(tests.unittest.HomeserverTestCase): @@ -253,8 +252,8 @@ def test_ban_wins_race_with_join(self) -> None: mocked_get_prev_events = patch.object( self.hs.get_datastores().main, "get_prev_events_for_room", - new_callable=MagicMock, - return_value=make_awaitable([last_room_creation_event_id]), + new_callable=AsyncMock, + return_value=[last_room_creation_event_id], ) with mocked_get_prev_events: self.helper.join(room_id, eve, tok=eve_token) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5da1d95f0b22..2a295da3a0b7 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -15,7 +15,7 @@ import json from typing import Dict, List, Set -from unittest.mock import ANY, Mock, call +from unittest.mock import ANY, AsyncMock, Mock, call from netaddr import IPSet @@ -33,7 +33,6 @@ from tests import unittest from tests.server import ThreadedMemoryReactorClock -from tests.test_utils import make_awaitable from tests.unittest import override_config # Some local users to test with @@ -74,11 +73,11 @@ def make_homeserver( # we mock out the keyring so as to skip the authentication check on the # federation API call. 
mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = make_awaitable(True) + mock_keyring.verify_json_for_server = AsyncMock(return_value=True) # we mock out the federation client too - self.mock_federation_client = Mock(spec=["put_json"]) - self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) + self.mock_federation_client = AsyncMock(spec=["put_json"]) + self.mock_federation_client.put_json.return_value = (200, "OK") self.mock_federation_client.agent = MatrixFederationAgent( reactor, tls_client_options_factory=None, @@ -121,20 +120,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.datastore = hs.get_datastores().main - self.datastore.get_destination_retry_timings = Mock( - return_value=make_awaitable(None) - ) + self.datastore.get_destination_retry_timings = AsyncMock(return_value=None) - self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] - return_value=make_awaitable((0, [])) + self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign] + return_value=(0, []) ) - self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[method-assign] + return_value=None ) - self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self.datastore.get_received_txn_response = AsyncMock( # type: ignore[method-assign] + return_value=None ) self.room_members: List[UserID] = [] @@ -146,25 +143,25 @@ async def check_user_in_room(room_id: str, requester: Requester) -> None: raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + hs.get_auth().check_user_in_room = Mock( # type: ignore[method-assign] side_effect=check_user_in_room ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[method-assign] side_effect=check_host_in_room ) async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) @@ -173,27 +170,25 @@ async def get_users_in_room(room_id: str) -> Set[str]: self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] - side_effect=( - # we deliberately return a non-None stream pos to avoid - # doing an initial_sync - lambda: make_awaitable(1) - ) + self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[method-assign] + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync + return_value=1 ) - 
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] - self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] - side_effect=lambda: 0 + self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] + return_value=0 ) - self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kargs: make_awaitable(([], 0)) + self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] + return_value=([], 0) ) - self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kargs: make_awaitable(None) + self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] + return_value=None ) - self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kwargs: make_awaitable(None) + self.datastore.set_received_txn_response = AsyncMock( # type: ignore[method-assign] + return_value=None ) def test_started_typing_local(self) -> None: diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 430209705e23..b5f15aa7d425 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Tuple -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from urllib.parse import quote from twisted.test.proto_helpers import MemoryReactor @@ -30,7 +30,7 @@ from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -471,7 +471,7 @@ def test_handle_user_deactivated_regular_user(self) -> None: self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) + mock_remove_from_user_dir = AsyncMock(return_value=None) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 6a0b5fc0bd56..0d17f2fe5be4 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,8 +14,8 @@ import base64 import logging import os -from typing import Any, Awaitable, Callable, Generator, List, Optional, cast -from unittest.mock import Mock, patch +from typing import Generator, List, Optional, cast +from unittest.mock import AsyncMock, patch import treq from netaddr import IPSet @@ -41,7 +41,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent -from synapse.http.federation.srv_resolver import Server +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import ( WELL_KNOWN_MAX_SIZE, WellKnownResolver, @@ -68,21 +68,11 @@ logger = logging.getLogger(__name__) -# 
Once Async Mocks or lambdas are supported this can go away. -def generate_resolve_service( - result: List[Server], -) -> Callable[[Any], Awaitable[List[Server]]]: - async def resolve_service(_: Any) -> List[Server]: - return result - - return resolve_service - - class MatrixFederationAgentTests(unittest.TestCase): def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() - self.mock_resolver = Mock() + self.mock_resolver = AsyncMock(spec=SrvResolver) config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] @@ -636,7 +626,7 @@ def test_get_hostname_bad_cert(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar") @@ -722,7 +712,7 @@ def test_get_no_srv_no_well_known(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -776,7 +766,7 @@ def test_get_well_known(self) -> None: """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -840,7 +830,7 @@ def test_get_well_known_redirect(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -930,7 +920,7 @@ def test_get_invalid_well_known(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -986,7 +976,7 @@ def test_get_well_known_unsigned_cert(self) -> None: # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -1037,9 +1027,9 @@ def test_get_hostname_srv(self) -> None: """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"srvtarget", port=8443)] - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"srvtarget", port=8443) + ] self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") @@ -1094,9 +1084,9 @@ def test_get_well_known_srv(self) -> None: self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"srvtarget", port=8443)] - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"srvtarget", 
port=8443) + ] self._handle_well_known_connection( client_factory, @@ -1137,7 +1127,7 @@ def test_idna_servername(self) -> None: """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -1201,9 +1191,9 @@ def test_idna_srv_target(self) -> None: """test the behaviour when the target of a SRV record has idna chars""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"xn--trget-3qa.com", port=8443)] # tĂ¢rget.com - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"xn--trget-3qa.com", port=8443) + ] # tĂ¢rget.com self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request( @@ -1407,12 +1397,10 @@ def test_srv_fallbacks(self) -> None: """Test that other SRV results are tried if the first one fails.""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index fa27f1279a95..c379853e20ef 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -164,7 +164,7 @@ def test_with_request_context(self) -> None: # Call requestReceived to finish instantiating the object. request.content = BytesIO() # Partially skip some internal processing of SynapseRequest. - request._started_processing = Mock() # type: ignore[assignment] + request._started_processing = Mock() # type: ignore[method-assign] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index fe631d7ecbd8..172fc3a736df 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -33,7 +33,6 @@ from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -70,7 +69,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. 
self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, @@ -234,7 +233,7 @@ def test_get_user_ip_and_agents__no_user_found(self) -> None: def test_sending_events_into_room(self) -> None: """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent - self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment] + self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[method-assign] spec=[], side_effect=self.event_creation_handler.create_and_send_nonmember_event, ) @@ -579,10 +578,8 @@ def test_update_room_membership_remote_join(self) -> None: """Test that the module API can join a remote room.""" # Necessary to fake a remote join. fake_stream_id = 1 - mocked_remote_join = simple_async_mock( - return_value=("fake-event-id", fake_stream_id) - ) - self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] + mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) + self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[method-assign] fake_remote_host = f"{self.module_api.server_name}-remote" # Given that the join is to be faked, we expect the relevant join event not to diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 829b9df83d4e..7c23b77e0a11 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from parameterized import parameterized @@ -28,7 +28,6 @@ from synapse.types import JsonDict, create_requester from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -191,7 +190,7 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: # Mock the method which calculates push rules -- we do this instead of # e.g. checking the results in the database because we want to ensure # that code isn't even running. - bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] + bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[method-assign] # Ensure no actions are generated! 
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -382,7 +381,6 @@ def test_room_mentions(self) -> None: ) ) - @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}}) def test_suppress_edits(self) -> None: """Under the default push rules, event edits should not generate notifications.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index f7c6417a09fd..af25815fa56e 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -58,7 +58,7 @@ def patch__eq__(cls: object) -> Callable[[], None]: def unpatch() -> None: if eq is not None: - cls.__eq__ = eq # type: ignore[assignment] + cls.__eq__ = eq # type: ignore[method-assign] return unpatch diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index a324b4d31dde..9b28cd474fbf 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from netaddr import IPSet @@ -26,7 +26,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import get_clock -from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -62,7 +61,7 @@ def test_send_event_single_sender(self) -> None: new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.return_value = make_awaitable({}) + mock_client.put_json = AsyncMock(return_value={}) mock_client.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -93,7 +92,7 @@ def test_send_event_sharded(self) -> None: new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.return_value = make_awaitable({}) + mock_client1.put_json = AsyncMock(return_value={}) mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -108,7 +107,7 @@ def test_send_event_sharded(self) -> None: ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.return_value = make_awaitable({}) + mock_client2.put_json = AsyncMock(return_value={}) mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -162,7 +161,7 @@ def test_send_typing_sharded(self) -> None: new typing EDUs. 
""" mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.return_value = make_awaitable({}) + mock_client1.put_json = AsyncMock(return_value={}) mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", @@ -177,7 +176,7 @@ def test_send_typing_sharded(self) -> None: ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.return_value = make_awaitable({}) + mock_client2.put_json = AsyncMock(return_value={}) mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index feb81844aee9..2f6bd0d74faa 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,7 +18,7 @@ import urllib.parse from binascii import unhexlify from typing import List, Optional -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class @@ -45,7 +45,7 @@ from tests import unittest from tests.server import FakeSite, make_request -from tests.test_utils import SMALL_PNG, make_awaitable +from tests.test_utils import SMALL_PNG from tests.unittest import override_config @@ -71,8 +71,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs.config.registration.registration_shared_secret = "shared" - self.hs.get_media_repository = Mock() # type: ignore[assignment] - self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment] + self.hs.get_media_repository = Mock() # type: ignore[method-assign] + self.hs.get_deactivate_account_handler = Mock() # type: ignore[method-assign] return self.hs @@ -419,8 +419,8 @@ def test_register_mau_limit_reached(self) -> None: store = self.hs.get_datastores().main # Set monthly active users to the limit - store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1834,8 +1834,8 @@ def test_create_user_mau_limit_reached_active_admin(self) -> None: ) # Set monthly active users to the limit - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1871,8 +1871,8 @@ def test_create_user_mau_limit_reached_passive_admin(self) -> None: handler = self.hs.get_registration_handler() # Set monthly active users to the limit - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 6c04e6c56cc2..4c69d224b81a 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -50,7 +50,7 @@ async def check_username( ) handler = self.hs.get_registration_handler() - 
handler.check_username = check_username # type: ignore[assignment] + handler.check_username = check_username # type: ignore[method-assign] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index ac19f3c6daef..e9f495e20671 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1346,7 +1346,7 @@ async def post_json( return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[method-assign] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py index d5b0640e7aec..481db9a687c3 100644 --- a/tests/rest/client/test_account_data.py +++ b/tests/rest/client/test_account_data.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock from synapse.rest import admin from synapse.rest.client import account_data, login, room from tests import unittest -from tests.test_utils import make_awaitable class AccountDataTestCase(unittest.HomeserverTestCase): @@ -32,7 +31,7 @@ def test_on_account_data_updated_callback(self) -> None: """Tests that the on_account_data_updated module callback is called correctly when a user's account data changes. """ - mocked_callback = Mock(return_value=make_awaitable(None)) + mocked_callback = AsyncMock(return_value=None) self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( mocked_callback ) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 54df2a252c01..141e0f57a33b 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -45,7 +45,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() # type: ignore[assignment] + hs.get_federation_handler = Mock() # type: ignore[method-assign] return hs diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index a2d5d340be35..90a8df147c7c 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -65,14 +65,14 @@ def test_add_filter_for_other_user(self) -> None: def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[assignment] + self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine # type: ignore[assignment] + self.hs.is_mine = _is_mine # type: ignore[method-assign] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ffbc13bb8df3..a2a65895647f 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -169,7 +169,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # which 
sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "account": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_address(self) -> None: @@ -189,12 +190,15 @@ def test_POST_ratelimiting_per_address(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -217,7 +221,8 @@ def test_POST_ratelimiting_per_address(self) -> None: # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account(self) -> None: @@ -234,12 +239,15 @@ def test_POST_ratelimiting_per_account(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -262,7 +270,8 @@ def test_POST_ratelimiting_per_account(self) -> None: # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: @@ -279,12 +288,15 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. 
- self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -569,8 +581,9 @@ def test_spam_checker_deny(self) -> None: body, ) self.assertEqual(channel.code, 403, channel.result) - self.assertDictContainsSubset( - {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body + self.assertLessEqual( + {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}.items(), + channel.json_body.items(), ) diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index 700f6587a007..41ceb3db51a4 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase @@ -45,7 +44,7 @@ def prepare( def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = simple_async_mock({}) + fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=fed_transport_client, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index e12098102b96..66b387cea37e 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
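# [Editor's aside -- illustrative sketch, not part of the patch.] The
# assertDictContainsSubset -> assertLessEqual conversions above rely on dict
# items views being set-like: `a.items() <= b.items()` holds exactly when
# every key/value pair of `a` also appears in `b`. assertDictContainsSubset
# was long deprecated and is removed in Python 3.12, hence the rewrite.
# The error-code values here are illustrative.
expected = {"errcode": "M_LIMIT_EXCEEDED", "extra": "value"}
actual = {"errcode": "M_LIMIT_EXCEEDED", "extra": "value", "error": "denied"}
assert expected.items() <= actual.items()  # subset: holds
assert not actual.items() <= expected.items()  # superset: does not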
from http import HTTPStatus -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -23,7 +23,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.presence_handler = Mock(spec=PresenceHandler) - self.presence_handler.set_state.return_value = make_awaitable(None) + self.presence_handler.set_state = AsyncMock(return_value=None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b228dba8613d..c33393dc284b 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -75,7 +75,7 @@ def test_POST_appservice_registration_valid(self) -> None: self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" @@ -136,7 +136,7 @@ def test_POST_user_valid(self) -> None: "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: @@ -157,7 +157,7 @@ def test_POST_guest_registration(self) -> None: det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False @@ -267,7 +267,7 @@ def test_POST_registration_requires_token(self) -> None: "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) # Check the `completed` counter has been incremented and pending is 0 res = self.get_success( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 9bfe913e451e..61773fb28c32 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -28,7 +28,6 @@ from tests import unittest from tests.server import FakeChannel -from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event from tests.unittest import override_config @@ -264,7 +263,8 @@ def test_ignore_invalid_room(self) -> None: # Disable the validation to pretend this came over federation. 
with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new=lambda self, event: make_awaitable(None), + new_callable=AsyncMock, + return_value=None, ): # Generate a various relations from a different room. self.get_success( @@ -570,7 +570,7 @@ def test_edit_reply(self) -> None: ) self.assertEqual(200, channel.code, channel.json_body) event_result = channel.json_body - self.assertDictContainsSubset(original_body, event_result["content"]) + self.assertLessEqual(original_body.items(), event_result["content"].items()) # also check /context, which returns the *edited* event channel = self.make_request( @@ -587,14 +587,14 @@ def test_edit_reply(self) -> None: (context_result, "/context"), ): # The reference metadata should still be intact. - self.assertDictContainsSubset( + self.assertLessEqual( { "m.relates_to": { "event_id": self.parent_id, "rel_type": "m.reference", } - }, - result_event_dict["content"], + }.items(), + result_event_dict["content"].items(), desc, ) @@ -1300,7 +1300,8 @@ def test_nested_thread(self) -> None: # not an event the Client-Server API will allow.. with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new=lambda self, event: make_awaitable(None), + new_callable=AsyncMock, + return_value=None, ): # Create a sub-thread off the thread, which is not allowed. self._send_relation( @@ -1371,9 +1372,11 @@ def test_thread_edit_latest_event(self) -> None: latest_event_in_thread = thread_summary["latest_event"] # The latest event in the thread should have the edit appear under the # bundled aggregations. - self.assertDictContainsSubset( - {"event_id": edit_event_id, "sender": "@alice:test"}, - latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE], + self.assertLessEqual( + {"event_id": edit_event_id, "sender": "@alice:test"}.items(), + latest_event_in_thread["unsigned"]["m.relations"][ + RelationTypes.REPLACE + ].items(), ) def test_aggregation_get_event_for_annotation(self) -> None: @@ -1636,9 +1639,9 @@ def test_redact_relation_thread(self) -> None: ################################################## self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 3, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 3, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # The latest event is the last sent event. self.assertEqual( @@ -1657,9 +1660,9 @@ def test_redact_relation_thread(self) -> None: # The thread should still exist, but the latest event should be updated. self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 2, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 2, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1676,9 +1679,9 @@ def test_redact_relation_thread(self) -> None: # Nothing should have changed (except the thread count). 
self.assertEqual(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 1, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 1, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1773,12 +1776,12 @@ def test_redact_parent_thread(self) -> None: event_ids = self._get_related_events() relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) - self.assertDictContainsSubset( + self.assertLessEqual( { "count": 1, "current_user_participated": True, - }, - relations[RelationTypes.THREAD], + }.items(), + relations[RelationTypes.THREAD].items(), ) self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 88e579dc393f..47c1d38ad7dd 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -20,7 +20,7 @@ import json from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from unittest.mock import Mock, call, patch +from unittest.mock import AsyncMock, Mock, call, patch from urllib import parse as urlparse from parameterized import param, parameterized @@ -52,7 +52,6 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase -from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event from tests.unittest import override_config @@ -69,15 +68,15 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: "red", ) - self.hs.get_federation_handler = Mock() # type: ignore[assignment] - self.hs.get_federation_handler.return_value.maybe_backfill = Mock( - return_value=make_awaitable(None) + self.hs.get_federation_handler = Mock() # type: ignore[method-assign] + self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock( + return_value=None ) async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[method-assign] return self.hs @@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=Mock()) + return self.setup_test_homeserver(federation_client=AsyncMock()) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") @@ -2385,7 +2384,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] + self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined] search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -2413,7 +2412,7 @@ def test_fallback(self) -> None: # with a 404, when using search filters. 
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), - make_awaitable({}), + {}, ) search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -3413,17 +3412,17 @@ def test_threepid_invite_spamcheck_deprecated(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] + return_value=None, ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) + mock = AsyncMock(return_value=True, spec=lambda *x: None) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( mock ) @@ -3451,7 +3450,7 @@ def test_threepid_invite_spamcheck_deprecated(self) -> None: # Now change the return value of the callback to deny any invite and test that # we can't send the invite. - mock.return_value = make_awaitable(False) + mock.return_value = False channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3477,18 +3476,18 @@ def test_threepid_invite_spamcheck(self) -> None: # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] + return_value=None, ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = Mock( - return_value=make_awaitable(synapse.module_api.NOT_SPAM), + mock = AsyncMock( + return_value=synapse.module_api.NOT_SPAM, spec=lambda *x: None, ) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( @@ -3519,7 +3518,7 @@ def test_threepid_invite_spamcheck(self) -> None: # Now change the return value of the callback to deny any invite and test that # we can't send the invite. 
We pick an arbitrary error code to be able to check # that the same code has been returned - mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + mock.return_value = Codes.CONSENT_NOT_GIVEN channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3538,7 +3537,7 @@ def test_threepid_invite_spamcheck(self) -> None: make_invite_mock.assert_called_once() # Run variant with `Tuple[Codes, dict]`. - mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"}) channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 8d2cdf875150..9aecf88e4160 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -84,7 +84,7 @@ def test_invite(self) -> None: def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( # type: ignore[assignment] + identity_handler.lookup_3pid = Mock( # type: ignore[method-assign] side_effect=AssertionError("This should not get called") ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e5ba5a970639..57eb713b150a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -13,7 +13,7 @@ # limitations under the License. import threading from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -33,7 +33,6 @@ from synapse.util.frozenutils import unfreeze from tests import unittest -from tests.test_utils import make_awaitable if TYPE_CHECKING: from synapse.module_api import ModuleApi @@ -118,7 +117,7 @@ async def approve_all_signature_checking( async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: pass - hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return hs @@ -477,7 +476,7 @@ async def test_fn( def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" - on_new_event = Mock(make_awaitable(None)) + on_new_event = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append( on_new_event ) @@ -580,7 +579,7 @@ def test_on_profile_update(self) -> None: avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -641,7 +640,7 @@ def test_on_profile_update_admin(self) -> None: avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -682,7 +681,7 @@ def test_on_user_deactivation_status_changed(self) -> None: correctly when processing a user's deactivation. """ # Register a mocked callback. 
- deactivation_mock = Mock(return_value=make_awaitable(None)) + deactivation_mock = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append( deactivation_mock, @@ -690,7 +689,7 @@ def test_on_user_deactivation_status_changed(self) -> None: # Also register a mocked callback for profile updates, to check that the # deactivation code calls it in a way that let modules know the user is being # deactivated. - profile_mock = Mock(return_value=make_awaitable(None)) + profile_mock = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( profile_mock, ) @@ -740,7 +739,7 @@ def test_on_user_deactivation_status_changed_admin(self) -> None: well as a reactivation. """ # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) @@ -794,7 +793,7 @@ def test_check_can_deactivate_user(self) -> None: correctly when processing a user's deactivation. """ # Register a mocked callback. - deactivation_mock = Mock(return_value=make_awaitable(False)) + deactivation_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -840,7 +839,7 @@ def test_check_can_deactivate_user_admin(self) -> None: correctly when processing a user's deactivation triggered by a server admin. """ # Register a mocked callback. - deactivation_mock = Mock(return_value=make_awaitable(False)) + deactivation_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -879,7 +878,7 @@ def test_check_can_shutdown_room(self) -> None: correctly when processing an admin's shutdown room request. """ # Register a mocked callback. - shutdown_mock = Mock(return_value=make_awaitable(False)) + shutdown_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_shutdown_room_callbacks.append( shutdown_mock, @@ -915,7 +914,7 @@ def test_on_threepid_bind(self) -> None: associating a 3PID to an account. """ # Register a mocked callback. - threepid_bind_mock = Mock(return_value=make_awaitable(None)) + threepid_bind_mock = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) @@ -957,11 +956,9 @@ def test_on_add_and_remove_user_third_party_identifier(self) -> None: just before associating and removing a 3PID to/from an account. """ # Pretend to be a Synapse module and register both callbacks as mocks. 
- on_add_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) - ) - on_remove_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) + on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None) + on_remove_user_third_party_identifier_callback_mock = AsyncMock( + return_value=None ) self.hs.get_module_api().register_third_party_rules_callbacks( on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock, @@ -1021,8 +1018,8 @@ def test_on_remove_user_third_party_identifier_is_called_on_deactivate( when a user is deactivated and their third-party ID associations are deleted. """ # Pretend to be a Synapse module and register both callbacks as mocks. - on_remove_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) + on_remove_user_third_party_identifier_callback_mock = AsyncMock( + return_value=None ) self.hs.get_module_api().register_third_party_rules_callbacks( on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index d8dc56261ac1..951a3cbc43e9 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -14,7 +14,7 @@ from http import HTTPStatus from typing import Any, Generator, Tuple, cast -from unittest.mock import Mock, call +from unittest.mock import AsyncMock, Mock, call from twisted.internet import defer, reactor as _reactor @@ -24,7 +24,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.utils import MockClock reactor = cast(ISynapseReactor, _reactor) @@ -53,7 +52,7 @@ def setUp(self) -> None: def test_executes_given_function( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) res = yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" ) @@ -64,7 +63,7 @@ def test_executes_given_function( def test_deduplicates_based_on_key( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute_request( self.mock_request, @@ -168,7 +167,7 @@ def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": @defer.inlineCallbacks def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" ) diff --git a/tests/server.py b/tests/server.py index ff03d2886476..08633fe640f4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
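# [Editor's aside -- illustrative sketch, not part of the patch.] The many
# `type: ignore[assignment]` -> `type: ignore[method-assign]` updates track a
# newer mypy release that reports monkey-patching a method under its own
# dedicated error code rather than the generic `assignment` code, so the
# pre-existing ignores went stale. The names below are hypothetical.
class _Handler:
    async def check_username(self, username: str) -> None:
        """Accept any username by default."""

async def _reject(username: str) -> None:
    raise RuntimeError("User ID already taken")

handler = _Handler()
# Recent mypy flags this as "Cannot assign to a method" [method-assign];
# older releases reported it under [assignment].
handler.check_username = _reject  # type: ignore[method-assign]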
import hashlib +import ipaddress import json import logging import os @@ -45,7 +45,7 @@ from typing_extensions import ParamSpec from zope.interface import implementer -from twisted.internet import address, threads, udp +from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError @@ -567,6 +568,8 @@ def connectTCP( conn = super().connectTCP( host, port, factory, timeout=timeout, bindAddress=None ) + if self.lookups and host in self.lookups: + validate_connector(conn, self.lookups[host]) callback = self._tcp_callbacks.get((host, port)) if callback: @@ -599,6 +602,55 @@ def advance(self, amount: float) -> None: super().advance(0) +def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: + """Try to validate the obtained connector as would happen when + synapse is running and the connection is established. + + This method raises a useful exception when necessary; otherwise it + does nothing. + + This is in order to help catch quirks related to reactor.connectTCP, + since when called directly, the connector's destination will be of type + IPv4Address, with the hostname as the literal host that was given (which + could be an IPv6-only host or an IPv6 literal). + + But when reactor.connectTCP is called *through* e.g. an Endpoint, the + connector's destination will contain the specific IP address with the + correct network stack class. + + Note that testing code paths that use connectTCP directly should not be + affected by this check, unless they specifically add a test with a + matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of + type ThreadedMemoryReactorClock. + For an example of implementing such tests, see tests/handlers/test_send_email.py. + """ + destination = connector.getDestination() + + # We use address.IPv{4,6}Address to check what the reactor thinks it + # is sending, but check for validity with ipaddress.IPv{4,6}Address, + # because the latter fail with IPs on the wrong network stack. + cls_mapping = { + address.IPv4Address: ipaddress.IPv4Address, + address.IPv6Address: ipaddress.IPv6Address, + } + + cls = cls_mapping.get(destination.__class__) + + if cls is not None: + try: + cls(expected_ip) + except Exception as exc: + raise ValueError( + "Invalid IP type and resolution for %s. Expected %s to be %s" + % (destination, expected_ip, cls.__name__) + ) from exc + else: + raise ValueError( + "Unknown address type %s for %s" + % (destination.__class__.__name__, destination) + ) + + class ThreadPool: """ Threadless thread pool. @@ -670,7 +722,7 @@ def runInteraction( **kwargs, ) - pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runWithConnection = runWithConnection # type: ignore[method-assign] pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index d2bfa53eda49..17f428bfc5e5 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
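# [Editor's aside -- illustrative sketch, not part of the patch.] How the
# validate_connector() helper added above catches a mismatched network stack:
# a direct reactor.connectTCP records an IPv4Address whose host is whatever
# string was passed, so if the test's lookup table resolved that host to an
# IPv6 literal, parsing the expected IP with ipaddress.IPv4Address raises,
# and the helper turns that into a descriptive ValueError.
import ipaddress
from twisted.internet import address

destination = address.IPv4Address("TCP", "localhost", 25)  # the reactor's view
expected_ip = "::1"  # what reactor.lookups["localhost"] resolved to
try:
    ipaddress.IPv4Address(expected_ip)  # mirrors cls(expected_ip) above
except ValueError:
    print(f"wrong stack: {destination} cannot carry {expected_ip}")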
from typing import Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -29,7 +29,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -69,24 +68,22 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: assert isinstance(rlsn, ResourceLimitsServerNotices) self._rlsn = rlsn - self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(1000) - ) - self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment] - return_value=make_awaitable(Mock()) + self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000) + self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[method-assign] + return_value=Mock() ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=make_awaitable("!something:localhost") + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = ( + AsyncMock(return_value="!something:localhost") ) - self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( - return_value=make_awaitable("!something:localhost") + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock( + return_value="!something:localhost" ) - self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] + self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign] + self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign] @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self) -> None: @@ -103,14 +100,14 @@ def test_maybe_send_server_notice_to_user_flag_off(self) -> None: def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -125,16 +122,16 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> No """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - 
self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -145,8 +142,8 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -158,8 +155,8 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -171,12 +168,10 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) - ) - self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) + self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -189,8 +184,8 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -204,8 +199,8 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -223,22 +218,22 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. 
""" - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) - self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] - return_value=make_awaitable((True, [])) + self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[method-assign] + return_value=(True, []) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -284,11 +279,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self) -> None: - self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) + self.store.get_monthly_active_count = AsyncMock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(1000) - ) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -327,7 +320,7 @@ def test_no_invite_without_notice(self) -> None: hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. """ - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m user_id = self.register_user("user", "password") diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index f541f1d6be1e..650b4941bab6 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -132,6 +132,7 @@ def test_timeout_lock(self) -> None: # We simulate the process getting stuck by cancelling the looping call # that keeps the lock active. + assert lock._looping_call lock._looping_call.stop() # Wait for the lock to timeout. @@ -403,6 +404,7 @@ def test_timeout_lock(self) -> None: # We simulate the process getting stuck by cancelling the looping call # that keeps the lock active. + assert lock._looping_call lock._looping_call.stop() # Wait for the lock to timeout. 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 71302facd14d..cbce26a725c8 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -15,7 +15,7 @@ import os import tempfile from typing import List, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import yaml @@ -35,7 +35,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): @@ -339,7 +338,7 @@ def test_get_oldest_unsent_txn(self) -> None: # we aren't testing store._base stuff here, so mock this out # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] + self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[method-assign] self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) self.get_success(self._insert_txn(service.id, 10, events)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index a4a823a25242..abf7d0564d81 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from unittest.mock import Mock +import logging +from unittest.mock import AsyncMock, Mock import yaml @@ -32,7 +32,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable, simple_async_mock from tests.unittest import override_config @@ -331,6 +330,28 @@ async def update_short(progress: JsonDict, count: int) -> int: self.update_handler.side_effect = update_short self.get_success(self.updates.do_next_background_update(False)) + def test_failed_update_logs_exception_details(self) -> None: + needle = "RUH ROH RAGGY" + + def failing_update(progress: JsonDict, count: int) -> int: + raise Exception(needle) + + self.update_handler.side_effect = failing_update + self.update_handler.reset_mock() + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": "{}"}, + ) + ) + + with self.assertLogs(level=logging.ERROR) as logs: + # Expect the failing update to be surfaced as a RuntimeError + self.get_failure(self.updates.run_background_updates(False), RuntimeError) + + self.assertTrue(any(needle in log for log in logs.output), logs.output) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -348,8 +369,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock out the AsyncContextManager class MockCM: - __aenter__ = simple_async_mock(return_value=None) - __aexit__ = simple_async_mock(return_value=None) + __aenter__ = AsyncMock(return_value=None) + __aexit__ = AsyncMock(return_value=None) self._update_ctx_manager = MockCM @@ -363,9 +384,9 @@ class MockCM: # Register the callbacks with more mocks self.hs.get_module_api().register_background_update_controller_callbacks( on_update=self._on_update, - min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), - default_batch_size=Mock( - return_value=make_awaitable(self._default_batch_size), +
min_batch_size=AsyncMock(return_value=self._default_batch_size), + default_batch_size=AsyncMock( + return_value=self._default_batch_size, ), ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 209d68b40ba9..6b9692c48625 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock from parameterized import parameterized @@ -30,7 +30,6 @@ from tests import unittest from tests.server import make_request -from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -66,15 +65,15 @@ def test_insert_new_client_ip(self) -> None: ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, - }, - r, + }.items(), + r.items(), ) def test_insert_new_client_ip_none_device_id(self) -> None: @@ -443,9 +442,7 @@ def test_adding_monthly_active_user_when_full(self) -> None: lots_of_users = 100 user_id = "@user:server" - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id" @@ -529,15 +526,15 @@ def test_devices_last_seen_bg_update(self) -> None: ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, - }, - r, + }.items(), + r.items(), ) # Register the background update to run again. 
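[Editor's note: the `assertDictContainsSubset` → `assertLessEqual` rewrites in this and the following files rely on dict `.items()` views supporting set comparisons; `assertDictContainsSubset` was deprecated in Python 3.2 and removed in 3.12, which this PR adds to CI. A standalone sketch:

    expected = {"ip": "ip", "last_seen": 0}
    actual = {"ip": "ip", "last_seen": 0, "user_agent": "user_agent"}

    # items() views behave like sets of (key, value) pairs, so `<=` is a
    # subset test: every expected pair must be present in the actual dict.
    assert expected.items() <= actual.items()
    assert not {"ip": "other"}.items() <= actual.items()

Hence `self.assertLessEqual(expected.items(), actual.items())` in the tests, which also reports both views on failure.]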
@@ -564,15 +561,15 @@ def test_devices_last_seen_bg_update(self) -> None: ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }, - r, + }.items(), + r.items(), ) def test_old_user_ips_pruned(self) -> None: @@ -643,15 +640,15 @@ def test_old_user_ips_pruned(self) -> None: ) r = result2[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }, - r, + }.items(), + r.items(), ) def test_invalid_user_agents_are_ignored(self) -> None: @@ -780,13 +777,13 @@ def _runtest( self.store.get_last_client_ip_by_device(self.user_id, device_id) ) r = result[(self.user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": self.user_id, "device_id": device_id, "ip": expected_ip, "user_agent": "Mozzila pizza", "last_seen": 123456100, - }, - r, + }.items(), + r.items(), ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f03807c8f9d4..58ab41cf2670 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -58,13 +58,13 @@ def test_store_new_device(self) -> None: res = self.get_success(self.store.get_device("user_id", "device_id")) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device_id", "display_name": "display_name", - }, - res, + }.items(), + res.items(), ) def test_get_devices_by_user(self) -> None: @@ -80,21 +80,21 @@ def test_get_devices_by_user(self) -> None: res = self.get_success(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device1", "display_name": "display_name 1", - }, - res["device1"], + }.items(), + res["device1"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device2", "display_name": "display_name 2", - }, - res["device2"], + }.items(), + res["device2"].items(), ) def test_count_devices_by_users(self) -> None: diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 5fde3b9c7879..2033377b5247 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ def test_key_without_device_name(self) -> None: self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset(json, dev) + self.assertLessEqual(json.items(), dev.items()) def test_reupload_key(self) -> None: now = 1470174257070 @@ -71,8 +71,12 @@ def test_get_key_with_device_name(self) -> None: self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset( - {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev + self.assertLessEqual( + { + "key": "value", + "unsigned": {"device_display_name": "display_name"}, + }.items(), + dev.items(), ) def test_multiple_devices(self) -> None: diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 282773837907..49366440ce10 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # 
limitations under the License. from typing import Any, Dict, List -from unittest.mock import Mock +from unittest.mock import AsyncMock from twisted.test.proto_helpers import MemoryReactor @@ -21,7 +21,6 @@ from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -253,7 +252,7 @@ def test_populate_monthly_users_is_guest(self) -> None: ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -261,24 +260,22 @@ def test_populate_monthly_users_is_guest(self) -> None: self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(None) - ) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] + self.store.user_last_seen_monthly_active = AsyncMock( + return_value=self.hs.get_clock().time_msec() ) d = self.store.populate_monthly_active_users("user_id") @@ -359,7 +356,7 @@ def test_track_monthly_users_without_cap(self) -> None: @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 71ec74eadc91..1e27f2c275a1 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -44,13 +44,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_get_room(self) -> None: res = self.get_success(self.store.get_room(self.room.to_string())) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "is_public": True, - }, - res, + }.items(), + res.items(), ) def test_get_room_unknown_room(self) -> None: @@ -59,13 +59,13 @@ 
def test_get_room_unknown_room(self) -> None: def test_get_room_with_stats(self) -> None: res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "public": True, - }, - res, + }.items(), + res.items(), ) def test_get_room_with_stats_unknown_room(self) -> None: diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 0e3fc2a77f05..29be8cdbd0e8 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -22,7 +22,6 @@ PartialStateEventsTracker, ) -from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -124,16 +123,17 @@ def test_cancellation(self) -> None: class PartialCurrentStateTrackerTestCase(TestCase): def setUp(self) -> None: self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + self.mock_store.is_partial_state_room = mock.AsyncMock() self.tracker = PartialCurrentStateTracker(self.mock_store) def test_does_not_block_for_full_state_rooms(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + self.mock_store.is_partial_state_room.return_value = False self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_blocks_for_partial_room_state(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + self.mock_store.is_partial_state_room.return_value = True d = ensureDeferred(self.tracker.await_full_state("room_id")) @@ -156,7 +156,7 @@ async def is_partial_state_room(room_id: str) -> bool: self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_cancellation(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + self.mock_store.is_partial_state_room.return_value = True d1 = ensureDeferred(self.tracker.await_full_state("room_id")) self.assertNoResult(d1) diff --git a/tests/test_federation.py b/tests/test_federation.py index 6d15ac759785..f8ade6da3852 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Collection, List, Optional, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -31,7 +31,6 @@ from synapse.util.retryutils import NotRetryingDestination from tests import unittest -from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -81,7 +80,7 @@ async def _check_event_auth( ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] self.client = self.hs.get_federation_client() async def _check_sigs_and_hash_for_pulled_events_and_fetch( @@ -191,12 +190,12 @@ def query_user_devices( # Register the mock on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. 
store = self.hs.get_datastores().main - store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) + store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -241,27 +240,24 @@ def test_cross_signing_keys_retry(self) -> None: # Register mock device list retrieval on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock( # type: ignore[assignment] - return_value=make_awaitable( - { + federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key }, - } - ) + }, + } ) # Resync the device list. diff --git a/tests/test_state.py b/tests/test_state.py index eded38c7669d..9c8679cc1dc9 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -714,7 +714,7 @@ def test_resolve_state_conflict( store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( @@ -773,7 +773,7 @@ def test_standard_depth_conflict( store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 52424aa08713..64a49488c654 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -85,7 +85,9 @@ def test_ui_auth(self) -> None: } } self.assertIsInstance(channel.json_body["params"], dict) - self.assertDictContainsSubset(channel.json_body["params"], expected_params) + self.assertLessEqual( + channel.json_body["params"].items(), expected_params.items() + ) # We have to complete the dummy auth stage before completing the terms stage request_data = { diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index c8cc841d9540..fa731426cda5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -18,10 +18,8 @@ import json import sys import warnings -from asyncio import Future from binascii import unhexlify -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar -from unittest.mock import Mock +from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar import attr import zope.interface @@ -57,27 +55,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def 
make_awaitable(result: TV) -> Awaitable[TV]: - """ - Makes an awaitable, suitable for mocking an `async` function. - This uses Futures as they can be awaited multiple times so can be returned - to multiple callers. - """ - future: Future[TV] = Future() - future.set_result(result) - return future - - def setup_awaitable_errors() -> Callable[[], None]: """ Convert warnings from non-awaited coroutines into errors. """ warnings.simplefilter("error", RuntimeWarning) - # unraisablehook was added in Python 3.8. - if not hasattr(sys, "unraisablehook"): - return lambda: None - # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook @@ -100,18 +83,6 @@ def cleanup() -> None: return cleanup -def simple_async_mock( - return_value: Optional[TV] = None, raises: Optional[Exception] = None - ) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - # Type ignore: it does not fully implement IResponse, but is good enough for tests @zope.interface.implementer(IResponse) @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/tests/unittest.py b/tests/unittest.py index b0721e060c40..5d3640d8ac24 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase): servlets: List of servlet registration functions. user_id (str): The user ID to assume if auth is hijacked. hijack_auth: Whether to hijack auth to return the user specified - in user_id. + in user_id. """ hijack_auth: ClassVar[bool] = True @@ -395,9 +395,9 @@ async def get_requester(*args: Any, **kwargs: Any) -> Requester: ) # Type ignore: mypy doesn't like us assigning to methods.
- self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 91cac9822af4..05983ed434b1 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -60,11 +60,9 @@ def check_called_first(res: int) -> int: observer1.addBoth(check_called_first) # store the results - results: List[Optional[ObservableDeferred[int]]] = [None, None] + results: List[Optional[int]] = [None, None] - def check_val( - res: ObservableDeferred[int], idx: int - ) -> ObservableDeferred[int]: + def check_val(res: int, idx: int) -> int: results[idx] = res return res @@ -93,14 +91,14 @@ def check_called_first(res: int) -> int: observer1.addBoth(check_called_first) # store the results - results: List[Optional[ObservableDeferred[str]]] = [None, None] + results: List[Optional[Failure]] = [None, None] - def check_val(res: ObservableDeferred[str], idx: int) -> None: + def check_failure(res: Failure, idx: int) -> None: results[idx] = res return None - observer1.addErrback(check_val, 0) - observer2.addErrback(check_val, 1) + observer1.addErrback(check_failure, 0) + observer2.addErrback(check_failure, 1) try: raise Exception("gah!") diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 3a97559bf04b..8665aeb50c0f 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -22,10 +22,11 @@ from synapse.util import Clock from synapse.util.task_scheduler import TaskScheduler -from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase, override_config -class TestTaskScheduler(unittest.HomeserverTestCase): +class TestTaskScheduler(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.task_scheduler = hs.get_task_scheduler() self.task_scheduler.register_action(self._test_task, "_test_task") @@ -34,7 +35,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.task_scheduler.register_action(self._resumable_task, "_resumable_task") async def _test_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # This test task will copy the parameters to the result result = None @@ -77,7 +78,7 @@ def test_schedule_task(self) -> None: self.assertIsNone(task) async def _sleeping_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: # Sleep for a second await deferLater(self.reactor, 1, lambda: None) @@ -85,24 +86,18 @@ async def _sleeping_task( def test_schedule_lot_of_tasks(self) -> None: """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior.""" - timestamp = self.clock.time_msec() + 30 * 1000 task_ids = [] for i 
in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1): task_ids.append( self.get_success( self.task_scheduler.schedule_task( "_sleeping_task", - timestamp=timestamp, params={"val": i}, ) ) ) - # The timestamp being 30s after now the task should been executed - # after the first scheduling loop is run - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - - # This is to give the time to the sleeping tasks to finish + # Give the active tasks time to finish self.reactor.advance(1) # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one @@ -120,10 +115,11 @@ def test_schedule_lot_of_tasks(self) -> None: ) scheduled_tasks = [ - t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED + t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE ] self.assertEquals(len(scheduled_tasks), 1) + # We need to wait for the next run of the scheduler loop self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) self.reactor.advance(1) @@ -138,7 +134,7 @@ def test_schedule_lot_of_tasks(self) -> None: ) async def _raising_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: raise Exception("raising") @@ -146,15 +142,13 @@ def test_schedule_raising_task(self) -> None: """Schedule a task raising an exception and check that it runs to failure and reports the exception content.""" task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task")) - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.FAILED) self.assertEqual(task.error, "raising") async def _resumable_task( - self, task: ScheduledTask, first_launch: bool + self, task: ScheduledTask ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: if task.result and "in_progress" in task.result: return TaskStatus.COMPLETE, {"success": True}, None @@ -169,8 +163,6 @@ def test_schedule_resumable_task(self) -> None: """Schedule a resumable task and check that it gets properly resumed and completes after simulating a synapse restart.""" task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task")) - self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) - task = self.get_success(self.task_scheduler.get_task(task_id)) assert task is not None self.assertEqual(task.status, TaskStatus.ACTIVE) @@ -184,3 +176,33 @@ def test_schedule_resumable_task(self) -> None: self.assertEqual(task.status, TaskStatus.COMPLETE) assert task.result is not None self.assertTrue(task.result.get("success")) + + +class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.task_scheduler = hs.get_task_scheduler() + self.task_scheduler.register_action(self._test_task, "_test_task") + + async def _test_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return (TaskStatus.COMPLETE, None, None) + + @override_config({"run_background_tasks_on": "worker1"}) + def test_schedule_task(self) -> None: + """Check that a task scheduled to run now is launched right away on the background worker.""" + bg_worker_hs = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={"worker_name": "worker1"}, + ) + bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task") +
task_id = self.get_success( + self.task_scheduler.schedule_task( + "_test_task", + ) + ) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE)
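[Editor's note: the scheduler test updates above track a behaviour change in `TaskScheduler`: a task scheduled without a `timestamp` is now launched immediately rather than on the next `SCHEDULE_INTERVAL_MS` loop, and the task over the concurrency limit is observed as `ACTIVE` rather than `SCHEDULED`. A condensed sketch of the resulting call pattern, using only APIs that appear in this diff and assuming it runs inside a `HomeserverTestCase` with `_test_task` registered as above:

    task_id = self.get_success(
        self.task_scheduler.schedule_task("_test_task", params={"val": 1})
    )

    # No reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) is
    # needed any more: the task starts as soon as it is scheduled.
    task = self.get_success(self.task_scheduler.get_task(task_id))
    assert task is not None
    assert task.status == TaskStatus.COMPLETE
]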