diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..83ddd568c207 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,8 @@ +# Black reformatting (#5482). +32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 + +# Target Python 3.5 with black (#8664). +aff1eb7c671b0a3813407321d2702ec46c71fa56 + +# Update black to 20.8b1 (#9381). +0a00b7ff14890987f09112a2ae696c61001e6cf1 diff --git a/CHANGES.md b/CHANGES.md index d9ecbac44050..d3e8fbae34ab 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,9 +1,143 @@ -Synapse 1.xx.0 -============== +Synapse 1.30.0rc1 (2021-03-16) +============================== + +Note that this release deprecates the ability for appservices to +call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice +developers should use a `type` value of `m.login.application_service` as +per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions). +In future releases, calling this endpoint with an access token - but without a `m.login.application_service` +type - will fail. + +Features +-------- + +- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573)) +- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) +- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549)) +- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601)) +- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617)) +- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626)) + + +Bugfixes +-------- + +- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473)) +- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583)) +- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567)) +- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587)) +- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597)) +- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620)) +- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623)) + + +Updates to the Docker image +--------------------------- + +- Use jemalloc if available in docker. ([\#8553](https://github.com/matrix-org/synapse/issues/8553)) + + +Improved Documentation +---------------------- + +- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508)) +- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550)) +- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571)) +- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580)) +- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) + + +Deprecations and Removals +------------------------- + +- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) +- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559)) + + +Internal Changes +---------------- + +- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458)) +- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520)) +- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523)) +- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618)) +- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541)) +- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560)) +- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561)) +- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562)) +- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563)) +- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568)) +- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576)) +- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586)) +- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590)) +- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596)) +- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) +- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619)) + + +Synapse 1.29.0 (2021-03-08) +=========================== Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change. +No significant changes. + + +Synapse 1.29.0rc1 (2021-03-04) +============================== + +Features +-------- + +- Add rate limiters to cross-user key sharing requests. ([\#8957](https://github.com/matrix-org/synapse/issues/8957)) +- Add `order_by` to the admin API `GET /_synapse/admin/v1/users//media`. Contributed by @dklimpel. ([\#8978](https://github.com/matrix-org/synapse/issues/8978)) +- Add some configuration settings to make users' profile data more private. ([\#9203](https://github.com/matrix-org/synapse/issues/9203)) +- The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. ([\#9372](https://github.com/matrix-org/synapse/issues/9372)) +- Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. ([\#9383](https://github.com/matrix-org/synapse/issues/9383), [\#9385](https://github.com/matrix-org/synapse/issues/9385)) +- Add support for regenerating thumbnails if they have been deleted but the original image is still stored. ([\#9438](https://github.com/matrix-org/synapse/issues/9438)) +- Add support for `X-Forwarded-Proto` header when using a reverse proxy. ([\#9472](https://github.com/matrix-org/synapse/issues/9472), [\#9501](https://github.com/matrix-org/synapse/issues/9501), [\#9512](https://github.com/matrix-org/synapse/issues/9512), [\#9539](https://github.com/matrix-org/synapse/issues/9539)) + + +Bugfixes +-------- + +- Fix a bug where users' pushers were not all deleted when they deactivated their account. ([\#9285](https://github.com/matrix-org/synapse/issues/9285), [\#9516](https://github.com/matrix-org/synapse/issues/9516)) +- Fix a bug where a lot of unnecessary presence updates were sent when joining a room. ([\#9402](https://github.com/matrix-org/synapse/issues/9402)) +- Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. ([\#9416](https://github.com/matrix-org/synapse/issues/9416)) +- Fix a bug in single sign-on which could cause a "No session cookie found" error. ([\#9436](https://github.com/matrix-org/synapse/issues/9436)) +- Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. ([\#9440](https://github.com/matrix-org/synapse/issues/9440)) +- Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. ([\#9449](https://github.com/matrix-org/synapse/issues/9449)) +- Fix deleting pushers when using sharded pushers. ([\#9465](https://github.com/matrix-org/synapse/issues/9465), [\#9466](https://github.com/matrix-org/synapse/issues/9466), [\#9479](https://github.com/matrix-org/synapse/issues/9479), [\#9536](https://github.com/matrix-org/synapse/issues/9536)) +- Fix missing startup checks for the consistency of certain PostgreSQL sequences. ([\#9470](https://github.com/matrix-org/synapse/issues/9470)) +- Fix a long-standing bug where the media repository could leak file descriptors while previewing media. ([\#9497](https://github.com/matrix-org/synapse/issues/9497)) +- Properly purge the event chain cover index when purging history. ([\#9498](https://github.com/matrix-org/synapse/issues/9498)) +- Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. ([\#9503](https://github.com/matrix-org/synapse/issues/9503)) +- Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. ([\#9506](https://github.com/matrix-org/synapse/issues/9506)) +- Prevent presence background jobs from running when presence is disabled. ([\#9530](https://github.com/matrix-org/synapse/issues/9530)) +- Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. ([\#9537](https://github.com/matrix-org/synapse/issues/9537)) + + +Improved Documentation +---------------------- + +- Update the example systemd config to propagate reloads to individual units. ([\#9463](https://github.com/matrix-org/synapse/issues/9463)) + + +Internal Changes +---------------- + +- Add documentation and type hints to `parse_duration`. ([\#9432](https://github.com/matrix-org/synapse/issues/9432)) +- Remove vestiges of `uploads_path` configuration setting. ([\#9462](https://github.com/matrix-org/synapse/issues/9462)) +- Add a comment about systemd-python. ([\#9464](https://github.com/matrix-org/synapse/issues/9464)) +- Test that we require validated email for email pushers. ([\#9496](https://github.com/matrix-org/synapse/issues/9496)) +- Allow python to generate bytecode for synapse. ([\#9502](https://github.com/matrix-org/synapse/issues/9502)) +- Fix incorrect type hints. ([\#9515](https://github.com/matrix-org/synapse/issues/9515), [\#9518](https://github.com/matrix-org/synapse/issues/9518)) +- Add type hints to device and event report admin API. ([\#9519](https://github.com/matrix-org/synapse/issues/9519)) +- Add type hints to user admin API. ([\#9521](https://github.com/matrix-org/synapse/issues/9521)) +- Bump the versions of mypy and mypy-zope used for static type checking. ([\#9529](https://github.com/matrix-org/synapse/issues/9529)) + + Synapse 1.28.0 (2021-02-25) =========================== diff --git a/MANIFEST.in b/MANIFEST.in index 120ce5b776bd..25d1cb758e52 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -20,9 +20,10 @@ recursive-include scripts * recursive-include scripts-dev * recursive-include synapse *.pyi recursive-include tests *.py -include tests/http/ca.crt -include tests/http/ca.key -include tests/http/server.key +recursive-include tests *.pem +recursive-include tests *.p8 +recursive-include tests *.crt +recursive-include tests *.key recursive-include synapse/res * recursive-include synapse/static *.css diff --git a/README.rst b/README.rst index d872b11f57d7..6a1e7135909f 100644 --- a/README.rst +++ b/README.rst @@ -183,8 +183,9 @@ Using a reverse proxy with Synapse It is recommended to put a reverse proxy such as `nginx `_, `Apache `_, -`Caddy `_ or -`HAProxy `_ in front of Synapse. One advantage of +`Caddy `_, +`HAProxy `_ or +`relayd `_ in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. diff --git a/UPGRADE.rst b/UPGRADE.rst index e852b806c282..8bc2ff91ab2d 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -98,9 +98,9 @@ will log a warning on each received request. To avoid the warning, administrators using a reverse proxy should ensure that the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to -indicate the protocol used by the client. See the [reverse proxy -documentation](docs/reverse_proxy.md), where the example configurations have -been updated to show how to set this header. +indicate the protocol used by the client. See the `reverse proxy documentation +`_, where the example configurations have been updated to +show how to set this header. (Users of `Caddy `_ are unaffected, since we believe it sets `X-Forwarded-Proto` by default.) @@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted "ACS location" (also known as "allowed callback URLs") at the identity provider. + The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to + ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity + provider uses this property to validate or otherwise identify Synapse, its configuration + will need to be updated to use the new URL. Alternatively you could create a new, separate + "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in + the existing "EntityDescriptor" as they were. + Changes to HTML templates ------------------------- diff --git a/changelog.d/8957.feature b/changelog.d/8957.feature deleted file mode 100644 index fa8961f8408b..000000000000 --- a/changelog.d/8957.feature +++ /dev/null @@ -1 +0,0 @@ -Add rate limiters to cross-user key sharing requests. diff --git a/changelog.d/8978.feature b/changelog.d/8978.feature deleted file mode 100644 index 042e257bf06a..000000000000 --- a/changelog.d/8978.feature +++ /dev/null @@ -1 +0,0 @@ -Add `order_by` to the admin API `GET /_synapse/admin/v1/users//media`. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/9203.feature b/changelog.d/9203.feature deleted file mode 100644 index 36b66a47a854..000000000000 --- a/changelog.d/9203.feature +++ /dev/null @@ -1 +0,0 @@ -Add some configuration settings to make users' profile data more private. diff --git a/changelog.d/9285.bugfix b/changelog.d/9285.bugfix deleted file mode 100644 index 81188c5473eb..000000000000 --- a/changelog.d/9285.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where users' pushers were not all deleted when they deactivated their account. diff --git a/changelog.d/9372.feature b/changelog.d/9372.feature deleted file mode 100644 index 3cb01004c97b..000000000000 --- a/changelog.d/9372.feature +++ /dev/null @@ -1 +0,0 @@ -The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. diff --git a/changelog.d/9383.feature b/changelog.d/9383.feature deleted file mode 100644 index 8957c9cc5e4c..000000000000 --- a/changelog.d/9383.feature +++ /dev/null @@ -1 +0,0 @@ -Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. \ No newline at end of file diff --git a/changelog.d/9385.feature b/changelog.d/9385.feature deleted file mode 100644 index cbe3922de80a..000000000000 --- a/changelog.d/9385.feature +++ /dev/null @@ -1 +0,0 @@ - Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. \ No newline at end of file diff --git a/changelog.d/9402.bugfix b/changelog.d/9402.bugfix deleted file mode 100644 index 7729225ba2d5..000000000000 --- a/changelog.d/9402.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where a lot of unnecessary presence updates were sent when joining a room. diff --git a/changelog.d/9416.bugfix b/changelog.d/9416.bugfix deleted file mode 100644 index 4d79cb222801..000000000000 --- a/changelog.d/9416.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. \ No newline at end of file diff --git a/changelog.d/9432.misc b/changelog.d/9432.misc deleted file mode 100644 index 1e07da203332..000000000000 --- a/changelog.d/9432.misc +++ /dev/null @@ -1 +0,0 @@ -Add documentation and type hints to `parse_duration`. diff --git a/changelog.d/9436.bugfix b/changelog.d/9436.bugfix deleted file mode 100644 index a530516eed4e..000000000000 --- a/changelog.d/9436.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug in single sign-on which could cause a "No session cookie found" error. diff --git a/changelog.d/9438.feature b/changelog.d/9438.feature deleted file mode 100644 index a1f6be556376..000000000000 --- a/changelog.d/9438.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for regenerating thumbnails if they have been deleted but the original image is still stored. diff --git a/changelog.d/9440.bugfix b/changelog.d/9440.bugfix deleted file mode 100644 index 47b9842b37cf..000000000000 --- a/changelog.d/9440.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. diff --git a/changelog.d/9449.bugfix b/changelog.d/9449.bugfix deleted file mode 100644 index 54214a7e4a07..000000000000 --- a/changelog.d/9449.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. diff --git a/changelog.d/9462.misc b/changelog.d/9462.misc deleted file mode 100644 index 1b245bf85d3c..000000000000 --- a/changelog.d/9462.misc +++ /dev/null @@ -1 +0,0 @@ -Remove vestiges of `uploads_path` configuration setting. diff --git a/changelog.d/9463.doc b/changelog.d/9463.doc deleted file mode 100644 index c9cedd147d57..000000000000 --- a/changelog.d/9463.doc +++ /dev/null @@ -1 +0,0 @@ -Update the example systemd config to propagate reloads to individual units. diff --git a/changelog.d/9464.misc b/changelog.d/9464.misc deleted file mode 100644 index 39fcf85d4006..000000000000 --- a/changelog.d/9464.misc +++ /dev/null @@ -1 +0,0 @@ -Add a comment about systemd-python. diff --git a/changelog.d/9465.bugfix b/changelog.d/9465.bugfix deleted file mode 100644 index 2ab4f315c11f..000000000000 --- a/changelog.d/9465.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9466.bugfix b/changelog.d/9466.bugfix deleted file mode 100644 index 2ab4f315c11f..000000000000 --- a/changelog.d/9466.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9470.bugfix b/changelog.d/9470.bugfix deleted file mode 100644 index c1b7dbb17d98..000000000000 --- a/changelog.d/9470.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix missing startup checks for the consistency of certain PostgreSQL sequences. diff --git a/changelog.d/9472.feature b/changelog.d/9472.feature deleted file mode 100644 index 06cfd5d1999c..000000000000 --- a/changelog.d/9472.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for `X-Forwarded-Proto` header when using a reverse proxy. diff --git a/changelog.d/9479.bugfix b/changelog.d/9479.bugfix deleted file mode 100644 index 2ab4f315c11f..000000000000 --- a/changelog.d/9479.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9496.misc b/changelog.d/9496.misc deleted file mode 100644 index d5866c56f738..000000000000 --- a/changelog.d/9496.misc +++ /dev/null @@ -1 +0,0 @@ -Test that we require validated email for email pushers. diff --git a/changelog.d/9497.bugfix b/changelog.d/9497.bugfix deleted file mode 100644 index 598bcaab673b..000000000000 --- a/changelog.d/9497.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where the media repository could leak file descriptors while previewing media. diff --git a/changelog.d/9498.bugfix b/changelog.d/9498.bugfix deleted file mode 100644 index dce0ad0920e2..000000000000 --- a/changelog.d/9498.bugfix +++ /dev/null @@ -1 +0,0 @@ -Properly purge the event chain cover index when purging history. diff --git a/changelog.d/9501.feature b/changelog.d/9501.feature deleted file mode 100644 index 06cfd5d1999c..000000000000 --- a/changelog.d/9501.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for `X-Forwarded-Proto` header when using a reverse proxy. diff --git a/changelog.d/9502.misc b/changelog.d/9502.misc deleted file mode 100644 index c5c29672b67e..000000000000 --- a/changelog.d/9502.misc +++ /dev/null @@ -1 +0,0 @@ -Allow python to generate bytecode for synapse. \ No newline at end of file diff --git a/changelog.d/9503.bugfix b/changelog.d/9503.bugfix deleted file mode 100644 index 0868691389ed..000000000000 --- a/changelog.d/9503.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. diff --git a/changelog.d/9506.bugfix b/changelog.d/9506.bugfix deleted file mode 100644 index cc0d410e0f9f..000000000000 --- a/changelog.d/9506.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. diff --git a/changelog.d/9512.feature b/changelog.d/9512.feature deleted file mode 100644 index 06cfd5d1999c..000000000000 --- a/changelog.d/9512.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for `X-Forwarded-Proto` header when using a reverse proxy. diff --git a/changelog.d/9515.misc b/changelog.d/9515.misc deleted file mode 100644 index 14c7b78dd97e..000000000000 --- a/changelog.d/9515.misc +++ /dev/null @@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9516.bugfix b/changelog.d/9516.bugfix deleted file mode 100644 index 81188c5473eb..000000000000 --- a/changelog.d/9516.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where users' pushers were not all deleted when they deactivated their account. diff --git a/changelog.d/9518.misc b/changelog.d/9518.misc deleted file mode 100644 index 14c7b78dd97e..000000000000 --- a/changelog.d/9518.misc +++ /dev/null @@ -1 +0,0 @@ -Fix incorrect type hints. diff --git a/changelog.d/9519.misc b/changelog.d/9519.misc deleted file mode 100644 index caccc88a1996..000000000000 --- a/changelog.d/9519.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to device and event report admin API. \ No newline at end of file diff --git a/changelog.d/9521.misc b/changelog.d/9521.misc deleted file mode 100644 index 1424d9c188a7..000000000000 --- a/changelog.d/9521.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to user admin API. \ No newline at end of file diff --git a/changelog.d/9529.misc b/changelog.d/9529.misc deleted file mode 100644 index b9021a26b406..000000000000 --- a/changelog.d/9529.misc +++ /dev/null @@ -1 +0,0 @@ -Bump the versions of mypy and mypy-zope used for static type checking. diff --git a/changelog.d/9530.bugfix b/changelog.d/9530.bugfix deleted file mode 100644 index bb4db675d9d6..000000000000 --- a/changelog.d/9530.bugfix +++ /dev/null @@ -1 +0,0 @@ -Prevent presence background jobs from running when presence is disabled. \ No newline at end of file diff --git a/changelog.d/9536.bugfix b/changelog.d/9536.bugfix deleted file mode 100644 index 2ab4f315c11f..000000000000 --- a/changelog.d/9536.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9537.bugfix b/changelog.d/9537.bugfix deleted file mode 100644 index 033ab1c939d4..000000000000 --- a/changelog.d/9537.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. diff --git a/changelog.d/9539.feature b/changelog.d/9539.feature deleted file mode 100644 index 06cfd5d1999c..000000000000 --- a/changelog.d/9539.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for `X-Forwarded-Proto` header when using a reverse proxy. diff --git a/debian/changelog b/debian/changelog index 04cd0d0d68dc..3feefd8a1c31 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,9 +1,12 @@ -matrix-synapse-py3 (1.29.0) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.29.0) stable; urgency=medium [ Jonathan de Jong ] * Remove the python -B flag (don't generate bytecode) in scripts and documentation. - -- Synapse Packaging team Fri, 26 Feb 2021 14:41:31 +0100 + [ Synapse Packaging team ] + * New synapse release 1.29.0. + + -- Synapse Packaging team Mon, 08 Mar 2021 13:51:50 +0000 matrix-synapse-py3 (1.28.0) stable; urgency=medium diff --git a/docker/Dockerfile b/docker/Dockerfile index d619ee08ed82..def4501541fe 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \ libpq5 \ libwebp6 \ xmlsec1 \ + libjemalloc2 \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local diff --git a/docker/README.md b/docker/README.md index 7b138df4d3ee..3a7dc585e7b5 100644 --- a/docker/README.md +++ b/docker/README.md @@ -204,3 +204,8 @@ healthcheck: timeout: 10s retries: 3 ``` + +## Using jemalloc + +Jemalloc is embedded in the image and will be used instead of the default allocator. +You can read about jemalloc by reading the Synapse [README](../README.md) \ No newline at end of file diff --git a/docker/start.py b/docker/start.py index 0d2c590b8838..16d6a8208a5f 100755 --- a/docker/start.py +++ b/docker/start.py @@ -3,6 +3,7 @@ import codecs import glob import os +import platform import subprocess import sys @@ -213,6 +214,13 @@ def main(args, environ): if "-m" not in args: args = ["-m", synapse_worker] + args + jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),) + + if os.path.isfile(jemallocpath): + environ["LD_PRELOAD"] = jemallocpath + else: + log("Could not find %s, will not use" % (jemallocpath,)) + # if there are no config files passed to synapse, try adding the default file if not any(p.startswith("--config-path") or p.startswith("-c") for p in args): config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") @@ -248,9 +256,9 @@ def main(args, environ): args = ["python"] + args if ownership is not None: args = ["gosu", ownership] + args - os.execv("/usr/sbin/gosu", args) + os.execve("/usr/sbin/gosu", args, environ) else: - os.execv("/usr/local/bin/python", args) + os.execve("/usr/local/bin/python", args, environ) if __name__ == "__main__": diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 90faeaaef0d9..9dbec68c19e6 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -1,5 +1,7 @@ # Contents -- [List all media in a room](#list-all-media-in-a-room) +- [Querying media](#querying-media) + * [List all media in a room](#list-all-media-in-a-room) + * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user) - [Quarantine media](#quarantine-media) * [Quarantining media by ID](#quarantining-media-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) @@ -10,7 +12,11 @@ * [Delete local media by date or size](#delete-local-media-by-date-or-size) - [Purge Remote Media API](#purge-remote-media-api) -# List all media in a room +# Querying media + +These APIs allow extracting media information from the homeserver. + +## List all media in a room This API gets a list of known media in a room. However, it only shows media from unencrypted events or rooms. @@ -36,6 +42,12 @@ The API returns a JSON body like the following: } ``` +## List all media uploaded by a user + +Listing all media that has been uploaded by a local user can be achieved through +the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user) +Admin API. + # Quarantine media Quarantining media means that it is marked as inaccessible by users. It applies diff --git a/docs/openid.md b/docs/openid.md index 263bc9f6f88e..cfaafc50150f 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -226,7 +226,7 @@ Synapse config: oidc_providers: - idp_id: github idp_name: Github - idp_brand: "org.matrix.github" # optional: styling hint for clients + idp_brand: "github" # optional: styling hint for clients discover: false issuer: "https://github.com/" client_id: "your-client-id" # TO BE FILLED @@ -252,7 +252,7 @@ oidc_providers: oidc_providers: - idp_id: google idp_name: Google - idp_brand: "org.matrix.google" # optional: styling hint for clients + idp_brand: "google" # optional: styling hint for clients issuer: "https://accounts.google.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -299,7 +299,7 @@ Synapse config: oidc_providers: - idp_id: gitlab idp_name: Gitlab - idp_brand: "org.matrix.gitlab" # optional: styling hint for clients + idp_brand: "gitlab" # optional: styling hint for clients issuer: "https://gitlab.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -334,7 +334,7 @@ Synapse config: ```yaml - idp_id: facebook idp_name: Facebook - idp_brand: "org.matrix.facebook" # optional: styling hint for clients + idp_brand: "facebook" # optional: styling hint for clients discover: false issuer: "https://facebook.com" client_id: "your-client-id" # TO BE FILLED @@ -386,7 +386,7 @@ oidc_providers: config: subject_claim: "id" localpart_template: "{{ user.login }}" - display_name_template: "{{ user.full_name }}" + display_name_template: "{{ user.full_name }}" ``` ### XWiki @@ -401,8 +401,7 @@ oidc_providers: idp_name: "XWiki" issuer: "https://myxwikihost/xwiki/oidc/" client_id: "your-client-id" # TO BE FILLED - # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed - client_secret: "dontcare" + client_auth_method: none scopes: ["openid", "profile"] user_profile_method: "userinfo_endpoint" user_mapping_provider: @@ -410,3 +409,40 @@ oidc_providers: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" ``` + +## Apple + +Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account. + +You will need to create a new "Services ID" for SiWA, and create and download a +private key with "SiWA" enabled. + +As well as the private key file, you will need: + * Client ID: the "identifier" you gave the "Services ID" + * Team ID: a 10-character ID associated with your developer account. + * Key ID: the 10-character identifier for the key. + +https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more +documentation on setting up SiWA. + +The synapse config will look like this: + +```yaml + - idp_id: apple + idp_name: Apple + issuer: "https://appleid.apple.com" + client_id: "your-client-id" # Set to the "identifier" for your "ServicesID" + client_auth_method: "client_secret_post" + client_secret_jwt_key: + key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file + jwt_header: + alg: ES256 + kid: "KEYIDCODE" # Set to the 10-char Key ID + jwt_payload: + iss: TEAMIDCODE # Set to the 10-char Team ID + scopes: ["name", "email", "openid"] + authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post + user_mapping_provider: + config: + email_template: "{{ user.email }}" +``` diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 81e5a68a361b..860afd5a040a 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -3,8 +3,9 @@ It is recommended to put a reverse proxy such as [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html), [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html), -[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or -[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage +[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy), +[HAProxy](https://www.haproxy.org/) or +[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. @@ -162,6 +163,52 @@ backend matrix server matrix 127.0.0.1:8008 ``` +### Relayd + +``` +table { 127.0.0.1 } +table { 127.0.0.1 } + +http protocol "https" { + tls { no tlsv1.0, ciphers "HIGH" } + tls keypair "example.com" + match header set "X-Forwarded-For" value "$REMOTE_ADDR" + match header set "X-Forwarded-Proto" value "https" + + # set CORS header for .well-known/matrix/server, .well-known/matrix/client + # httpd does not support setting headers, so do it here + match request path "/.well-known/matrix/*" tag "matrix-cors" + match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*" + + pass quick path "/_matrix/*" forward to + pass quick path "/_synapse/client/*" forward to + + # pass on non-matrix traffic to webserver + pass forward to +} + +relay "https_traffic" { + listen on egress port 443 tls + protocol "https" + forward to port 8008 check tcp + forward to port 8080 check tcp +} + +http protocol "matrix" { + tls { no tlsv1.0, ciphers "HIGH" } + tls keypair "example.com" + block + pass quick path "/_matrix/*" forward to + pass quick path "/_synapse/client/*" forward to +} + +relay "matrix_federation" { + listen on egress port 8448 tls + protocol "matrix" + forward to port 8008 check tcp +} +``` + ## Homeserver Configuration You will also want to set `bind_addresses: ['127.0.0.1']` and diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 42cb17123e6a..aeb73b49667f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -89,8 +89,7 @@ pid_file: DATADIR/homeserver.pid # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to # 'false'. Note that profile data is also available via the federation -# API, so this setting is of limited value if federation is enabled on -# the server. +# API, unless allow_profile_lookup_over_federation is set to false. # #require_auth_for_profile_requests: true @@ -1787,7 +1786,26 @@ saml2_config: # # client_id: Required. oauth2 client id to use. # -# client_secret: Required. oauth2 client secret to use. +# client_secret: oauth2 client secret to use. May be omitted if +# client_secret_jwt_key is given, or if client_auth_method is 'none'. +# +# client_secret_jwt_key: Alternative to client_secret: details of a key used +# to create a JSON Web Token to be used as an OAuth2 client secret. If +# given, must be a dictionary with the following properties: +# +# key: a pem-encoded signing key. Must be a suitable key for the +# algorithm specified. Required unless 'key_file' is given. +# +# key_file: the path to file containing a pem-encoded signing key file. +# Required unless 'key' is given. +# +# jwt_header: a dictionary giving properties to include in the JWT +# header. Must include the key 'alg', giving the algorithm used to +# sign the JWT, such as "ES256", using the JWA identifiers in +# RFC7518. +# +# jwt_payload: an optional dictionary giving properties to include in +# the JWT payload. Normally this should include an 'iss' key. # # client_auth_method: auth method to use when exchanging the token. Valid # values are 'client_secret_basic' (default), 'client_secret_post' and @@ -1908,7 +1926,7 @@ oidc_providers: # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -2634,19 +2652,20 @@ user_directory: -# Local statistics collection. Used in populating the room directory. -# -# 'bucket_size' controls how large each statistics timeslice is. It can -# be defined in a human readable short form -- e.g. "1d", "1y". +# Settings for local room and user statistics collection. See +# docs/room_and_user_statistics.md. # -# 'retention' controls how long historical statistics will be kept for. -# It can be defined in a human readable short form -- e.g. "1d", "1y". -# -# -#stats: -# enabled: true -# bucket_size: 1d -# retention: 1y +stats: + # Uncomment the following to disable room and user statistics. Note that doing + # so may cause certain features (such as the room directory) not to work + # correctly. + # + #enabled: false + + # The size of each timeslice in the room_stats_historical and + # user_stats_historical tables, as a time period. Defaults to "1d". + # + #bucket_size: 1h # Server Notices room configuration diff --git a/docs/spam_checker.md b/docs/spam_checker.md index e615ac99103f..52947f605e13 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -14,6 +14,7 @@ The Python class is instantiated with two objects: * An instance of `synapse.module_api.ModuleApi`. It then implements methods which return a boolean to alter behavior in Synapse. +All the methods must be defined. There's a generic method for checking every event (`check_event_for_spam`), as well as some specific methods: @@ -24,6 +25,7 @@ well as some specific methods: * `user_may_publish_room` * `check_username_for_spam` * `check_registration_for_spam` +* `check_media_file_for_spam` The details of each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. @@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class. The `ModuleApi` class provides a way for the custom spam checker class to call back into the homeserver internals. +Additionally, a `parse_config` method is mandatory and receives the plugin config +dictionary. After parsing, It must return an object which will be +passed to `__init__` later. + ### Example ```python @@ -41,6 +47,10 @@ class ExampleSpamChecker: self.config = config self.api = api + @staticmethod + def parse_config(config): + return config + async def check_event_for_spam(self, foo): return False # allow all events @@ -59,7 +69,13 @@ class ExampleSpamChecker: async def check_username_for_spam(self, user_profile): return False # allow all usernames - async def check_registration_for_spam(self, email_threepid, username, request_info): + async def check_registration_for_spam( + self, + email_threepid, + username, + request_info, + auth_provider_id, + ): return RegistrationBehaviour.ALLOW # allow all registrations async def check_media_file_for_spam(self, file_wrapper, file_info): diff --git a/mypy.ini b/mypy.ini index 64ed45dac24c..e0685e097c1a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -69,6 +69,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/macaroons.py, synapse/util/stringutils.py, tests/replication, tests/test_utils, @@ -116,9 +117,6 @@ ignore_missing_imports = True [mypy-saml2.*] ignore_missing_imports = True -[mypy-unpaddedbase64] -ignore_missing_imports = True - [mypy-canonicaljson] ignore_missing_imports = True diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh index 189ca66535e6..9132160463a6 100755 --- a/scripts-dev/config-lint.sh +++ b/scripts-dev/config-lint.sh @@ -2,9 +2,14 @@ # Find linting errors in Synapse's default config file. # Exits with 0 if there are no problems, or another code otherwise. +# cd to the root of the repository +cd `dirname $0`/.. + +# Restore backup of sample config upon script exit +trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT + # Fix non-lowercase true/false values sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml -rm docs/sample_config.yaml.bak # Check if anything changed -git diff --exit-code docs/sample_config.yaml +diff docs/sample_config.yaml docs/sample_config.yaml.bak diff --git a/setup.cfg b/setup.cfg index f46e43fad0bb..5e301c2cd7d3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,7 @@ test_suite = tests [check-manifest] ignore = + .git-blame-ignore-revs contrib contrib/* docs/* diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 68c11d3b8b62..58dc06bd6632 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -17,7 +17,9 @@ """ from typing import Any, List, Optional, Type, Union -class RedisProtocol: +from twisted.internet import protocol + +class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... async def set( @@ -54,7 +56,7 @@ def lazyConnection( class ConnectionHandler: ... -class RedisFactory: +class RedisFactory(protocol.ReconnectingClientFactory): continueTrying: bool handler: RedisProtocol pool: List[RedisProtocol] diff --git a/synapse/__init__.py b/synapse/__init__.py index 869e860fb045..88be7db1966d 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ except ImportError: pass -__version__ = "1.28.0" +__version__ = "1.30.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 89e62b0e367b..e10e33fd23c4 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -39,6 +39,7 @@ from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -163,7 +164,7 @@ def get_public_keys(self, invite_event): async def get_user_by_req( self, - request: Request, + request: SynapseRequest, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, @@ -408,7 +409,7 @@ def _parse_and_validate_macaroon(self, token, rights="access"): raise _InvalidMacaroonException() try: - user_id = self.get_user_id_from_macaroon(macaroon) + user_id = get_value_from_macaroon(macaroon, "user_id") guest = False for caveat in macaroon.caveats: @@ -416,7 +417,12 @@ def _parse_and_validate_macaroon(self, token, rights="access"): guest = True self.validate_macaroon(macaroon, rights, user_id=user_id) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + except ( + pymacaroons.exceptions.MacaroonException, + KeyError, + TypeError, + ValueError, + ): raise InvalidClientTokenError("Invalid macaroon passed.") if rights == "access": @@ -424,27 +430,6 @@ def _parse_and_validate_macaroon(self, token, rights="access"): return user_id, guest - def get_user_id_from_macaroon(self, macaroon): - """Retrieve the user_id given by the caveats on the macaroon. - - Does *not* validate the macaroon. - - Args: - macaroon (pymacaroons.Macaroon): The macaroon to validate - - Returns: - (str) user id - - Raises: - InvalidClientCredentialsError if there is no user_id caveat in the - macaroon - """ - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix) :] - raise InvalidClientTokenError("No user caveat in macaroon") - def validate_macaroon(self, macaroon, type_string, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -465,21 +450,13 @@ def validate_macaroon(self, macaroon, type_string, user_id): v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") - v.satisfy_general(self._verify_expiry) + satisfy_expiry(v, self.clock.time_msec) # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) v.verify(macaroon, self._macaroon_secret_key) - def _verify_expiry(self, caveat): - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self.hs.get_clock().time_msec() - return now < expiry - def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 93c2aabcca6c..9d3bbe3b8b05 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -90,7 +90,7 @@ def __init__(self, hs): self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( - hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 402696671193..ba9cd63cf2d5 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -212,9 +212,8 @@ def ensure_directory(cls, dir_path): @classmethod def read_file(cls, file_path, config_name): - cls.check_file(file_path, config_name) - with open(file_path) as file_stream: - return file_stream.read() + """Deprecated: call read_file directly""" + return read_file(file_path, (config_name,)) def read_template(self, filename: str) -> jinja2.Template: """Load a template file from disk. @@ -894,4 +893,35 @@ def get_instance(self, key: str) -> str: return self._get_instance(key) -__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] +def read_file(file_path: Any, config_path: Iterable[str]) -> str: + """Check the given file exists, and read it into a string + + If it does not, emit an error indicating the problem + + Args: + file_path: the file to be read + config_path: where in the configuration file_path came from, so that a useful + error can be emitted if it does not exist. + Returns: + content of the file. + Raises: + ConfigError if there is a problem reading the file. + """ + if not isinstance(file_path, str): + raise ConfigError("%r is not a string", config_path) + + try: + os.stat(file_path) + with open(file_path) as file_stream: + return file_stream.read() + except OSError as e: + raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e + + +__all__ = [ + "Config", + "RootConfig", + "ShardedWorkerHandlingConfig", + "RoutableShardedWorkerHandlingConfig", + "read_file", +] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index db16c86f5084..e896fd34e23c 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig: class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... + +def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... diff --git a/synapse/config/logger.py b/synapse/config/logger.py index e56cf846f516..999aecce5c78 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -21,8 +21,10 @@ from string import Template import yaml +from zope.interface import implementer from twisted.logger import ( + ILogObserver, LogBeginner, STDLibLogObserver, eventAsText, @@ -227,7 +229,8 @@ def factory(*args, **kwargs): threadlocal = threading.local() - def _log(event): + @implementer(ILogObserver) + def _log(event: dict) -> None: if "log_text" in event: if event["log_text"].startswith("DNSDatagramProtocol starting on "): return diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index a27594befc87..2bfb537c15fb 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,7 +15,7 @@ # limitations under the License. from collections import Counter -from typing import Iterable, Optional, Tuple, Type +from typing import Iterable, Mapping, Optional, Tuple, Type import attr @@ -25,7 +25,7 @@ from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri -from ._base import Config, ConfigError +from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" @@ -97,7 +97,26 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # # client_id: Required. oauth2 client id to use. # - # client_secret: Required. oauth2 client secret to use. + # client_secret: oauth2 client secret to use. May be omitted if + # client_secret_jwt_key is given, or if client_auth_method is 'none'. + # + # client_secret_jwt_key: Alternative to client_secret: details of a key used + # to create a JSON Web Token to be used as an OAuth2 client secret. If + # given, must be a dictionary with the following properties: + # + # key: a pem-encoded signing key. Must be a suitable key for the + # algorithm specified. Required unless 'key_file' is given. + # + # key_file: the path to file containing a pem-encoded signing key file. + # Required unless 'key' is given. + # + # jwt_header: a dictionary giving properties to include in the JWT + # header. Must include the key 'alg', giving the algorithm used to + # sign the JWT, such as "ES256", using the JWA identifiers in + # RFC7518. + # + # jwt_payload: an optional dictionary giving properties to include in + # the JWT payload. Normally this should include an 'iss' key. # # client_auth_method: auth method to use when exchanging the token. Valid # values are 'client_secret_basic' (default), 'client_secret_post' and @@ -218,7 +237,7 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -240,7 +259,7 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # jsonschema definition of the configuration settings for an oidc identity provider OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", - "required": ["issuer", "client_id", "client_secret"], + "required": ["issuer", "client_id"], "properties": { "idp_id": { "type": "string", @@ -253,7 +272,12 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): "idp_icon": {"type": "string"}, "idp_brand": { "type": "string", - # MSC2758-style namespaced identifier + "minLength": 1, + "maxLength": 255, + "pattern": "^[a-z][a-z0-9_.-]*$", + }, + "idp_unstable_brand": { + "type": "string", "minLength": 1, "maxLength": 255, "pattern": "^[a-z][a-z0-9_.-]*$", @@ -262,6 +286,30 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): "issuer": {"type": "string"}, "client_id": {"type": "string"}, "client_secret": {"type": "string"}, + "client_secret_jwt_key": { + "type": "object", + "required": ["jwt_header"], + "oneOf": [ + {"required": ["key"]}, + {"required": ["key_file"]}, + ], + "properties": { + "key": {"type": "string"}, + "key_file": {"type": "string"}, + "jwt_header": { + "type": "object", + "required": ["alg"], + "properties": { + "alg": {"type": "string"}, + }, + "additionalProperties": {"type": "string"}, + }, + "jwt_payload": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + }, + }, "client_auth_method": { "type": "string", # the following list is the same as the keys of @@ -404,15 +452,31 @@ def _parse_oidc_config_dict( "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) ) from e + client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") + client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] + if client_secret_jwt_key_config is not None: + keyfile = client_secret_jwt_key_config.get("key_file") + if keyfile: + key = read_file(keyfile, config_path + ("client_secret_jwt_key",)) + else: + key = client_secret_jwt_key_config["key"] + client_secret_jwt_key = OidcProviderClientSecretJwtKey( + key=key, + jwt_header=client_secret_jwt_key_config["jwt_header"], + jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), + ) + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), idp_icon=idp_icon, idp_brand=oidc_config.get("idp_brand"), + unstable_idp_brand=oidc_config.get("unstable_idp_brand"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], - client_secret=oidc_config["client_secret"], + client_secret=oidc_config.get("client_secret"), + client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), @@ -427,6 +491,18 @@ def _parse_oidc_config_dict( ) +@attr.s(slots=True, frozen=True) +class OidcProviderClientSecretJwtKey: + # a pem-encoded signing key + key = attr.ib(type=str) + + # properties to include in the JWT header + jwt_header = attr.ib(type=Mapping[str, str]) + + # properties to include in the JWT payload. + jwt_payload = attr.ib(type=Mapping[str, str]) + + @attr.s(slots=True, frozen=True) class OidcProviderConfig: # a unique identifier for this identity provider. Used in the 'user_external_ids' @@ -442,6 +518,9 @@ class OidcProviderConfig: # Optional brand identifier for this IdP. idp_brand = attr.ib(type=Optional[str]) + # Optional brand identifier for the unstable API (see MSC2858). + unstable_idp_brand = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) @@ -452,8 +531,13 @@ class OidcProviderConfig: # oauth2 client id to use client_id = attr.ib(type=str) - # oauth2 client secret to use - client_secret = attr.ib(type=str) + # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate + # a secret. + client_secret = attr.ib(type=Optional[str]) + + # key to use to construct a JWT to use as a client secret. May be `None` if + # `client_secret` is set. + client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and diff --git a/synapse/config/server.py b/synapse/config/server.py index 2afca36e7d16..5f8910b6e149 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -841,8 +841,7 @@ def generate_config_section( # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to # 'false'. Note that profile data is also available via the federation - # API, so this setting is of limited value if federation is enabled on - # the server. + # API, unless allow_profile_lookup_over_federation is set to false. # #require_auth_for_profile_requests: true diff --git a/synapse/config/stats.py b/synapse/config/stats.py index b559bfa4113c..2258329a52d8 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -13,10 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys +import logging from ._base import Config +ROOM_STATS_DISABLED_WARN = """\ +WARNING: room/user statistics have been disabled via the stats.enabled +configuration setting. This means that certain features (such as the room +directory) will not operate correctly. Future versions of Synapse may ignore +this setting. + +To fix this warning, remove the stats.enabled setting from your configuration +file. +--------------------------------------------------------------------------------""" + +logger = logging.getLogger(__name__) + class StatsConfig(Config): """Stats Configuration @@ -28,30 +40,29 @@ class StatsConfig(Config): def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 * 1000 - self.stats_retention = sys.maxsize stats_config = config.get("stats", None) if stats_config: self.stats_enabled = stats_config.get("enabled", self.stats_enabled) self.stats_bucket_size = self.parse_duration( stats_config.get("bucket_size", "1d") ) - self.stats_retention = self.parse_duration( - stats_config.get("retention", "%ds" % (sys.maxsize,)) - ) + if not self.stats_enabled: + logger.warning(ROOM_STATS_DISABLED_WARN) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Local statistics collection. Used in populating the room directory. + # Settings for local room and user statistics collection. See + # docs/room_and_user_statistics.md. # - # 'bucket_size' controls how large each statistics timeslice is. It can - # be defined in a human readable short form -- e.g. "1d", "1y". - # - # 'retention' controls how long historical statistics will be kept for. - # It can be defined in a human readable short form -- e.g. "1d", "1y". - # - # - #stats: - # enabled: true - # bucket_size: 1d - # retention: 1y + stats: + # Uncomment the following to disable room and user statistics. Note that doing + # so may cause certain features (such as the room directory) not to work + # correctly. + # + #enabled: false + + # The size of each timeslice in the room_stats_historical and + # user_stats_historical tables, as a time period. Defaults to "1d". + # + #bucket_size: 1h """ diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 8cfc0bb3cbb1..a9185987a237 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.rest.media.v1._base import FileInfo @@ -27,6 +28,8 @@ import synapse.events import synapse.server +logger = logging.getLogger(__name__) + class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): @@ -190,6 +193,7 @@ async def check_registration_for_spam( email_threepid: Optional[dict], username: Optional[str], request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str] = None, ) -> RegistrationBehaviour: """Checks if we should allow the given registration request. @@ -198,6 +202,9 @@ async def check_registration_for_spam( username: The request user name, if any request_info: List of tuples of user agent and IP that were used during the registration process. + auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", + "cas". If any. Note this does not include users registered + via a password provider. Returns: Enum for how the request should be handled @@ -208,9 +215,25 @@ async def check_registration_for_spam( # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = await maybe_awaitable( - checker(email_threepid, username, request_info) - ) + # Provide auth_provider_id if the function supports it + checker_args = inspect.signature(checker) + if len(checker_args.parameters) == 4: + d = checker( + email_threepid, + username, + request_info, + auth_provider_id, + ) + elif len(checker_args.parameters) == 3: + d = checker(email_threepid, username, request_info) + else: + logger.error( + "Invalid signature for %s.check_registration_for_spam. Denying registration", + spam_checker.__module__, + ) + return RegistrationBehaviour.DENY + + behaviour = await maybe_awaitable(d) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2f832b47f65c..9839d3d0160b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,6 +22,7 @@ Awaitable, Callable, Dict, + Iterable, List, Optional, Tuple, @@ -90,16 +91,15 @@ "Time taken to process an event", ) - -last_pdu_age_metric = Gauge( - "synapse_federation_last_received_pdu_age", - "The age (in seconds) of the last PDU successfully received from the given domain", +last_pdu_ts_metric = Gauge( + "synapse_federation_last_received_pdu_time", + "The timestamp of the last PDU which was successfully received from the given domain", labelnames=("server_name",), ) class FederationServer(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() @@ -112,14 +112,15 @@ def __init__(self, hs): # with FederationHandlerRegistry. hs.get_directory_handler() - self._federation_ratelimiter = hs.get_federation_ratelimiter() - self._server_linearizer = Linearizer("fed_server") - self._transaction_linearizer = Linearizer("fed_txn_handler") + + # origins that we are currently processing a transaction from. + # a dict from origin to txn id. + self._active_transactions = {} # type: Dict[str, str] # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( - hs, "fed_txn_handler", timeout_ms=30000 + hs.get_clock(), "fed_txn_handler", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -129,10 +130,10 @@ def __init__(self, hs): # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache( - hs, "state_resp", timeout_ms=30000 + hs.get_clock(), "state_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( - hs, "state_ids_resp", timeout_ms=30000 + hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( @@ -169,6 +170,33 @@ async def on_incoming_transaction( logger.debug("[%s] Got transaction", transaction_id) + # Reject malformed transactions early: reject if too many PDUs/EDUs + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore + ): + logger.info("Transaction PDU or EDU count too large. Returning 400") + return 400, {} + + # we only process one transaction from each origin at a time. We need to do + # this check here, rather than in _on_incoming_transaction_inner so that we + # don't cache the rejection in _transaction_resp_cache (so that if the txn + # arrives again later, we can process it). + current_transaction = self._active_transactions.get(origin) + if current_transaction and current_transaction != transaction_id: + logger.warning( + "Received another txn %s from %s while still processing %s", + transaction_id, + origin, + current_transaction, + ) + return 429, { + "errcode": Codes.UNKNOWN, + "error": "Too many concurrent transactions", + } + + # CRITICAL SECTION: we must now not await until we populate _active_transactions + # in _on_incoming_transaction_inner. + # We wrap in a ResponseCache so that we de-duplicate retried # transactions. return await self._transaction_resp_cache.wrap( @@ -182,26 +210,18 @@ async def on_incoming_transaction( async def _on_incoming_transaction_inner( self, origin: str, transaction: Transaction, request_time: int ) -> Tuple[int, Dict[str, Any]]: - # Use a linearizer to ensure that transactions from a remote are - # processed in order. - with await self._transaction_linearizer.queue(origin): - # We rate limit here *after* we've queued up the incoming requests, - # so that we don't fill up the ratelimiter with blocked requests. - # - # This is important as the ratelimiter allows N concurrent requests - # at a time, and only starts ratelimiting if there are more requests - # than that being processed at a time. If we queued up requests in - # the linearizer/response cache *after* the ratelimiting then those - # queued up requests would count as part of the allowed limit of N - # concurrent requests. - with self._federation_ratelimiter.ratelimit(origin) as d: - await d - - result = await self._handle_incoming_transaction( - origin, transaction, request_time - ) + # CRITICAL SECTION: the first thing we must do (before awaiting) is + # add an entry to _active_transactions. + assert origin not in self._active_transactions + self._active_transactions[origin] = transaction.transaction_id # type: ignore - return result + try: + result = await self._handle_incoming_transaction( + origin, transaction, request_time + ) + return result + finally: + del self._active_transactions[origin] async def _handle_incoming_transaction( self, origin: str, transaction: Transaction, request_time: int @@ -227,19 +247,6 @@ async def _handle_incoming_transaction( logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore - # Reject if PDU count > 50 or EDU count > 100 - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): - - logger.info("Transaction PDU or EDU count too large. Returning 400") - - response = {} - await self.transaction_actions.set_response( - origin, transaction, 400, response - ) - return 400, response - # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. @@ -335,42 +342,48 @@ async def _handle_pdus_in_txn( # impose a limit to avoid going too crazy with ram/cpu. async def process_pdus_for_room(room_id: str): - logger.debug("Processing PDUs for %s", room_id) - try: - await self.check_server_matches_acl(origin_host, room_id) - except AuthError as e: - logger.warning("Ignoring PDUs for room %s from banned server", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - pdu_results[event_id] = e.error_dict() - return + with nested_logging_context(room_id): + logger.debug("Processing PDUs for %s", room_id) - for pdu in pdus_by_room[room_id]: - event_id = pdu.event_id - with pdu_process_time.time(): - with nested_logging_context(event_id): - try: - await self._handle_received_pdu(origin, pdu) - pdu_results[event_id] = {} - except FederationError as e: - logger.warning("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s", - event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), - ) + try: + await self.check_server_matches_acl(origin_host, room_id) + except AuthError as e: + logger.warning( + "Ignoring PDUs for room %s from banned server", room_id + ) + for pdu in pdus_by_room[room_id]: + event_id = pdu.event_id + pdu_results[event_id] = e.error_dict() + return + + for pdu in pdus_by_room[room_id]: + pdu_results[pdu.event_id] = await process_pdu(pdu) + + async def process_pdu(pdu: EventBase) -> JsonDict: + event_id = pdu.event_id + with pdu_process_time.time(): + with nested_logging_context(event_id): + try: + await self._handle_received_pdu(origin, pdu) + return {} + except FederationError as e: + logger.warning("Error handling PDU %s: %s", event_id, e) + return {"error": str(e)} + except Exception as e: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + return {"error": str(e)} await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if newest_pdu_ts and origin in self._federation_metrics_domains: - newest_pdu_age = self._clock.time_msec() - newest_pdu_ts - last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000) + last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000) return pdu_results @@ -448,18 +461,22 @@ async def on_state_ids_request( async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( self, room_id: str, event_id: str ) -> Dict[str, list]: if event_id: - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + pdus = await self.handler.get_state_for_pdu( + room_id, event_id + ) # type: Iterable[EventBase] else: pdus = (await self.state.get_current_state(room_id)).values() - auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + auth_chain = await self.store.get_auth_chain( + room_id, [pdu.event_id for pdu in pdus] + ) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], @@ -863,7 +880,9 @@ def __init__(self, hs: "HomeServer"): self.edu_handlers = ( {} ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] - self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + self.query_handlers = ( + {} + ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] # Map from type to instance names that we should route EDU handling to. # We randomly choose one instance from the list to route to for each new @@ -897,7 +916,7 @@ def register_edu_handler( self.edu_handlers[edu_type] = handler def register_query_handler( - self, query_type: str, handler: Callable[[dict], defer.Deferred] + self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] ): """Sets the handler callable that will be used to handle an incoming federation query of the given type. @@ -970,7 +989,7 @@ async def on_edu(self, edu_type: str, origin: str, content: dict): # Oh well, let's just log and move on. logger.warning("No handler registered for EDU type %s", edu_type) - async def on_query(self, query_type: str, args: dict): + async def on_query(self, query_type: str, args: dict) -> JsonDict: handler = self.query_handlers.get(query_type) if handler: return await handler(args) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index deb519f3efb6..cc0d765e5f2f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +import attr from prometheus_client import Counter from synapse.api.errors import ( @@ -93,6 +94,10 @@ def __init__( self._destination = destination self.transmission_loop_running = False + # Flag to signal to any running transmission loop that there is new data + # queued up to be sent. + self._new_data_to_send = False + # True whilst we are sending events that the remote homeserver missed # because it was unreachable. We start in this state so we can perform # catch-up at startup. @@ -108,7 +113,7 @@ def __init__( # destination (we are the only updater so this is safe) self._last_successful_stream_ordering = None # type: Optional[int] - # a list of pending PDUs + # a queue of pending PDUs self._pending_pdus = [] # type: List[EventBase] # XXX this is never actually used: see @@ -208,6 +213,10 @@ def attempt_new_transaction(self) -> None: transaction in the background. """ + # Mark that we (may) have new things to send, so that any running + # transmission loop will recheck whether there is stuff to send. + self._new_data_to_send = True + if self.transmission_loop_running: # XXX: this can get stuck on by a never-ending # request at which point pending_pdus just keeps growing. @@ -250,125 +259,41 @@ async def _transaction_transmission_loop(self) -> None: pending_pdus = [] while True: - # We have to keep 2 free slots for presence and rr_edus - limit = MAX_EDUS_PER_TRANSACTION - 2 - - device_update_edus, dev_list_id = await self._get_device_update_edus( - limit - ) - - limit -= len(device_update_edus) - - ( - to_device_edus, - device_stream_id, - ) = await self._get_to_device_message_edus(limit) - - pending_edus = device_update_edus + to_device_edus - - # BEGIN CRITICAL SECTION - # - # In order to avoid a race condition, we need to make sure that - # the following code (from popping the queues up to the point - # where we decide if we actually have any pending messages) is - # atomic - otherwise new PDUs or EDUs might arrive in the - # meantime, but not get sent because we hold the - # transmission_loop_running flag. - - pending_pdus = self._pending_pdus + self._new_data_to_send = False - # We can only include at most 50 PDUs per transactions - pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] + async with _TransactionQueueManager(self) as ( + pending_pdus, + pending_edus, + ): + if not pending_pdus and not pending_edus: + logger.debug("TX [%s] Nothing to send", self._destination) + + # If we've gotten told about new things to send during + # checking for things to send, we try looking again. + # Otherwise new PDUs or EDUs might arrive in the meantime, + # but not get sent because we hold the + # `transmission_loop_running` flag. + if self._new_data_to_send: + continue + else: + return - pending_edus.extend(self._get_rr_edus(force_flush=False)) - pending_presence = self._pending_presence - self._pending_presence = {} - if pending_presence: - pending_edus.append( - Edu( - origin=self._server_name, - destination=self._destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self._clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, + if pending_pdus: + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), ) - ) - pending_edus.extend( - self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self._pending_edus_keyed - ): - _, val = self._pending_edus_keyed.popitem() - pending_edus.append(val) - - if pending_pdus: - logger.debug( - "TX [%s] len(pending_pdus_by_dest[dest]) = %d", - self._destination, - len(pending_pdus), + await self._transaction_manager.send_new_transaction( + self._destination, pending_pdus, pending_edus ) - if not pending_pdus and not pending_edus: - logger.debug("TX [%s] Nothing to send", self._destination) - self._last_device_stream_id = device_stream_id - return - - # if we've decided to send a transaction anyway, and we have room, we - # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self._get_rr_edus(force_flush=True)) - - # END CRITICAL SECTION - - success = await self._transaction_manager.send_new_transaction( - self._destination, pending_pdus, pending_edus - ) - if success: sent_transactions_counter.inc() sent_edus_counter.inc(len(pending_edus)) for edu in pending_edus: sent_edus_by_type.labels(edu.edu_type).inc() - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if to_device_edus: - await self._store.delete_device_msgs_for_remote( - self._destination, device_stream_id - ) - # also mark the device updates as sent - if device_update_edus: - logger.info( - "Marking as sent %r %r", self._destination, dev_list_id - ) - await self._store.mark_as_sent_devices_by_remote( - self._destination, dev_list_id - ) - - self._last_device_stream_id = device_stream_id - self._last_device_list_stream_id = dev_list_id - - if pending_pdus: - # we sent some PDUs and it was successful, so update our - # last_successful_stream_ordering in the destinations table. - final_pdu = pending_pdus[-1] - last_successful_stream_ordering = ( - final_pdu.internal_metadata.stream_ordering - ) - assert last_successful_stream_ordering - await self._store.set_destination_last_successful_stream_ordering( - self._destination, last_successful_stream_ordering - ) - else: - break except NotRetryingDestination as e: logger.debug( "TX [%s] not ready for retry yet (next retry at %s) - " @@ -401,7 +326,7 @@ async def _transaction_transmission_loop(self) -> None: self._pending_presence = {} self._pending_rrs = {} - self._start_catching_up() + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -412,7 +337,6 @@ async def _transaction_transmission_loop(self) -> None: e, ) - self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -422,16 +346,12 @@ async def _transaction_transmission_loop(self) -> None: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -499,13 +419,10 @@ async def _catch_up_transmission_loop(self) -> None: rooms = [p.room_id for p in catchup_pdus] logger.info("Catching up rooms to %s: %r", self._destination, rooms) - success = await self._transaction_manager.send_new_transaction( + await self._transaction_manager.send_new_transaction( self._destination, catchup_pdus, [] ) - if not success: - return - sent_transactions_counter.inc() final_pdu = catchup_pdus[-1] self._last_successful_stream_ordering = cast( @@ -584,3 +501,135 @@ def _start_catching_up(self) -> None: """ self._catching_up = True self._pending_pdus = [] + + +@attr.s(slots=True) +class _TransactionQueueManager: + """A helper async context manager for pulling stuff off the queues and + tracking what was last successfully sent, etc. + """ + + queue = attr.ib(type=PerDestinationQueue) + + _device_stream_id = attr.ib(type=Optional[int], default=None) + _device_list_id = attr.ib(type=Optional[int], default=None) + _last_stream_ordering = attr.ib(type=Optional[int], default=None) + _pdus = attr.ib(type=List[EventBase], factory=list) + + async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: + # First we calculate the EDUs we want to send, if any. + + # We start by fetching device related EDUs, i.e device updates and to + # device messages. We have to keep 2 free slots for presence and rr_edus. + limit = MAX_EDUS_PER_TRANSACTION - 2 + + device_update_edus, dev_list_id = await self.queue._get_device_update_edus( + limit + ) + + if device_update_edus: + self._device_list_id = dev_list_id + else: + self.queue._last_device_list_stream_id = dev_list_id + + limit -= len(device_update_edus) + + ( + to_device_edus, + device_stream_id, + ) = await self.queue._get_to_device_message_edus(limit) + + if to_device_edus: + self._device_stream_id = device_stream_id + else: + self.queue._last_device_stream_id = device_stream_id + + pending_edus = device_update_edus + to_device_edus + + # Now add the read receipt EDU. + pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) + + # And presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} + + # Finally add any other types of EDUs if there is room. + pending_edus.extend( + self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self.queue._pending_edus_keyed + ): + _, val = self.queue._pending_edus_keyed.popitem() + pending_edus.append(val) + + # Now we look for any PDUs to send, by getting up to 50 PDUs from the + # queue + self._pdus = self.queue._pending_pdus[:50] + + if not self._pdus and not pending_edus: + return [], [] + + # if we've decided to send a transaction anyway, and we have room, we + # may as well send any pending RRs + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: + pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + + if self._pdus: + self._last_stream_ordering = self._pdus[ + -1 + ].internal_metadata.stream_ordering + assert self._last_stream_ordering + + return self._pdus, pending_edus + + async def __aexit__(self, exc_type, exc, tb): + if exc_type is not None: + # Failed to send transaction, so we bail out. + return + + # Successfully sent transactions, so we remove pending PDUs from the queue + if self._pdus: + self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :] + + # Succeeded to send the transaction so we record where we have sent up + # to in the various streams + + if self._device_stream_id: + await self.queue._store.delete_device_msgs_for_remote( + self.queue._destination, self._device_stream_id + ) + self.queue._last_device_stream_id = self._device_stream_id + + # also mark the device updates as sent + if self._device_list_id: + logger.info( + "Marking as sent %r %r", self.queue._destination, self._device_list_id + ) + await self.queue._store.mark_as_sent_devices_by_remote( + self.queue._destination, self._device_list_id + ) + self.queue._last_device_list_stream_id = self._device_list_id + + if self._last_stream_ordering: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. + await self.queue._store.set_destination_last_successful_stream_ordering( + self.queue._destination, self._last_stream_ordering + ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 763aff296c87..07b740c2f2fc 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -36,9 +36,9 @@ logger = logging.getLogger(__name__) -last_pdu_age_metric = Gauge( - "synapse_federation_last_sent_pdu_age", - "The age (in seconds) of the last PDU successfully sent to the given domain", +last_pdu_ts_metric = Gauge( + "synapse_federation_last_sent_pdu_time", + "The timestamp of the last PDU which was successfully sent to the given domain", labelnames=("server_name",), ) @@ -69,15 +69,12 @@ async def send_new_transaction( destination: str, pdus: List[EventBase], edus: List[Edu], - ) -> bool: + ) -> None: """ Args: destination: The destination to send to (e.g. 'example.org') pdus: In-order list of PDUs to send edus: List of EDUs to send - - Returns: - True iff the transaction was successful """ # Make a transaction-sending opentracing span. This span follows on from @@ -96,8 +93,6 @@ async def send_new_transaction( edu.strip_context() with start_active_span_follows_from("send_transaction", span_contexts): - success = True - logger.debug("TX [%s] _attempt_new_transaction", destination) txn_id = str(self._next_txn_id) @@ -152,45 +147,29 @@ def json_data_cb(): response = await self._transport_layer.send_transaction( transaction, json_data_cb ) - code = 200 except HttpResponseException as e: code = e.code response = e.response - if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", destination, txn_id, code - ) - raise e - - logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - - if code == 200: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warning( - "TX [%s] {%s} Remote returned error for %s: %s", - destination, - txn_id, - e_id, - r, - ) - else: - for p in pdus: + set_tag(tags.ERROR, True) + + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) + raise + + logger.info("TX [%s] {%s} got 200 response", destination, txn_id) + + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: logger.warning( - "TX [%s] {%s} Failed to send event %s", + "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, - p.event_id, + e_id, + r, ) - success = False - if success and pdus and destination in self._federation_metrics_domains: + if pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] - last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts - last_pdu_age_metric.labels(server_name=destination).set( - last_pdu_age / 1000 + last_pdu_ts_metric.labels(server_name=destination).set( + last_pdu.origin_server_ts / 1000 ) - - set_tag(tags.ERROR, not success) - return success diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 5ecb2da1acb7..132be238dd5f 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -73,7 +73,9 @@ async def start_listening(self) -> None: "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) + self.reactor.listenTCP( + self.hs.config.acme_port, srv, backlog=50, interface=host + ) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3978e4151851..fb5f8118f038 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -65,6 +65,7 @@ from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email @@ -170,6 +171,16 @@ class SsoLoginExtraAttributes: extra_attributes = attr.ib(type=JsonDict) +@attr.s(slots=True, frozen=True) +class LoginTokenAttributes: + """Data we store in a short-term login token""" + + user_id = attr.ib(type=str) + + # the SSO Identity Provider that the user authenticated with, to get this token + auth_provider_id = attr.ib(type=str) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -326,7 +337,8 @@ async def validate_user_via_ui_auth( user is too high to proceed """ - + if not requester.access_token_id: + raise ValueError("Cannot validate a user without an access token") if self._ui_auth_session_timeout: last_validated = await self.store.get_access_token_last_validated( requester.access_token_id @@ -1164,18 +1176,16 @@ async def _check_local_password(self, user_id: str, password: str) -> Optional[s return None return user_id - async def validate_short_term_login_token_and_get_user_id(self, login_token: str): - auth_api = self.hs.get_auth() - user_id = None + async def validate_short_term_login_token( + self, login_token: str + ) -> LoginTokenAttributes: try: - macaroon = pymacaroons.Macaroon.deserialize(login_token) - user_id = auth_api.get_user_id_from_macaroon(macaroon) - auth_api.validate_macaroon(macaroon, "login", user_id) + res = self.macaroon_gen.verify_short_term_login_token(login_token) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - await self.auth.check_auth_blocking(user_id) - return user_id + await self.auth.check_auth_blocking(res.user_id) + return res async def delete_access_token(self, access_token: str): """Invalidate a single access token @@ -1204,7 +1214,7 @@ async def delete_access_token(self, access_token: str): async def delete_access_tokens_for_user( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user @@ -1397,6 +1407,7 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s async def complete_sso_login( self, registered_user_id: str, + auth_provider_id: str, request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, @@ -1406,6 +1417,9 @@ async def complete_sso_login( Args: registered_user_id: The registered user ID to complete SSO login for. + auth_provider_id: The id of the SSO Identity provider that was used for + login. This will be stored in the login token for future tracking in + prometheus metrics. request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. @@ -1427,6 +1441,7 @@ async def complete_sso_login( self._complete_sso_login( registered_user_id, + auth_provider_id, request, client_redirect_url, extra_attributes, @@ -1437,6 +1452,7 @@ async def complete_sso_login( def _complete_sso_login( self, registered_user_id: str, + auth_provider_id: str, request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, @@ -1463,7 +1479,7 @@ def _complete_sso_login( # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id + registered_user_id, auth_provider_id=auth_provider_id ) # Append the login token to the original redirect URL (i.e. with its query @@ -1569,15 +1585,48 @@ def generate_access_token( return macaroon.serialize() def generate_short_term_login_token( - self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + self, + user_id: str, + auth_provider_id: str, + duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) + macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) return macaroon.serialize() + def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: + """Verify a short-term-login macaroon + + Checks that the given token is a valid, unexpired short-term-login token + minted by this server. + + Args: + token: the login token to verify + + Returns: + the user_id that this token is valid for + + Raises: + MacaroonVerificationFailedException if the verification failed + """ + macaroon = pymacaroons.Macaroon.deserialize(token) + user_id = get_value_from_macaroon(macaroon, "user_id") + auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = login") + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + satisfy_expiry(v, self.hs.get_clock().time_msec) + v.verify(macaroon, self.hs.config.key.macaroon_secret_key) + + return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 04972f9cf0b6..cb67589f7d0c 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -83,6 +83,7 @@ def __init__(self, hs: "HomeServer"): # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ead626a4d5a..598a66f74cf4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -201,7 +201,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: or pdu.internal_metadata.is_outlier() ) if already_seen: - logger.debug("[%s %s]: Already seen pdu", room_id, event_id) + logger.debug("Already seen pdu") return # do some initial sanity-checking of the event. In particular, make @@ -210,18 +210,14 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warning( - "[%s %s] Received event failed sanity checks", room_id, event_id - ) + logger.warning("Received event failed sanity checks") raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( - "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, - event_id, + "Queuing PDU from %s for now: join in progress", origin, ) self.room_queues[room_id].append((pdu, origin)) @@ -236,9 +232,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, - event_id, + "Ignoring PDU from %s as we're not in the room", origin, ) return None @@ -250,7 +244,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) + logger.debug("min_depth: %d", min_depth) prevs = set(pdu.prev_event_ids()) seen = await self.store.have_events_in_timeline(prevs) @@ -267,17 +261,13 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: # If we're missing stuff, ensure we only fetch stuff one # at a time. logger.info( - "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, - event_id, + "Acquiring room lock to fetch %d missing prev_events: %s", len(missing_prevs), shortstr(missing_prevs), ) with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, - event_id, + "Acquired room lock to fetch %d missing prev_events", len(missing_prevs), ) @@ -297,9 +287,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: if not prevs - seen: logger.info( - "[%s %s] Found all missing prev_events", - room_id, - event_id, + "Found all missing prev_events", ) if prevs - seen: @@ -329,9 +317,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: if sent_to_us_directly: logger.warning( - "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, - event_id, + "Rejecting: failed to fetch %d prev events: %s", len(prevs - seen), shortstr(prevs - seen), ) @@ -367,17 +353,16 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: # Ask the remote server for the states we don't # know about for p in prevs - seen: - logger.info( - "Requesting state at missing prev_event %s", - event_id, - ) + logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = await self._get_state_for_room( - origin, room_id, p, include_event_in_state=True + remote_state = ( + await self._get_state_after_missing_prev_event( + origin, room_id, p + ) ) remote_state_map = { @@ -414,10 +399,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( - "[%s %s] Error attempting to resolve state at missing " - "prev_events", - room_id, - event_id, + "Error attempting to resolve state at missing " "prev_events", exc_info=True, ) raise FederationError( @@ -454,9 +436,7 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): latest |= seen logger.info( - "[%s %s]: Requesting missing events between %s and %s", - room_id, - event_id, + "Requesting missing events between %s and %s", shortstr(latest), event_id, ) @@ -523,15 +503,11 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warning( - "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e - ) + logger.warning("Failed to get prev_events: %s", e) return logger.info( - "[%s %s]: Got %d prev_events: %s", - room_id, - event_id, + "Got %d prev_events: %s", len(missing_events), shortstr(missing_events), ) @@ -542,9 +518,7 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): for ev in missing_events: logger.info( - "[%s %s] Handling received prev_event %s", - room_id, - event_id, + "Handling received prev_event %s", ev.event_id, ) with nested_logging_context(ev.event_id): @@ -553,9 +527,7 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): except FederationError as e: if e.code == 403: logger.warning( - "[%s %s] Received prev_event %s failed history check.", - room_id, - event_id, + "Received prev_event %s failed history check.", ev.event_id, ) else: @@ -566,7 +538,6 @@ async def _get_state_for_room( destination: str, room_id: str, event_id: str, - include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. @@ -574,11 +545,9 @@ async def _get_state_for_room( destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. - include_event_in_state: if true, the event itself will be included in the - returned state event list. Returns: - A list of events in the state, possibly including the event itself, and + A list of events in the state, not including the event itself, and a list of events in the auth chain for the given event. """ ( @@ -590,9 +559,6 @@ async def _get_state_for_room( desired_events = set(state_event_ids + auth_event_ids) - if include_event_in_state: - desired_events.add(event_id) - event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -609,13 +575,6 @@ async def _get_state_for_room( event_map[e_id] for e_id in state_event_ids if e_id in event_map ] - if include_event_in_state: - remote_event = event_map.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) @@ -689,6 +648,131 @@ async def _get_events_from_store_or_dest( return fetched_events + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + # TODO: This function is basically the same as _get_state_for_room. Can + # we make backfill() use it, rather than having two code paths? I think the + # only difference is that backfill() persists the prev events separately. + + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_desired_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + async def _process_received_pdu( self, origin: str, @@ -707,10 +791,7 @@ async def _process_received_pdu( (ie, we are missing one or more prev_events), the resolved state at the event """ - room_id = event.room_id - event_id = event.event_id - - logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) + logger.debug("Processing event: %s", event) try: await self._handle_new_event(origin, event, state=state) @@ -871,7 +952,6 @@ async def backfill(self, dest, room_id, limit, extremities): destination=dest, room_id=room_id, event_id=e_id, - include_event_in_state=False, ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) @@ -1317,7 +1397,7 @@ async def send_invite(self, target_host, event): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + event.room_id, list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -1580,7 +1660,7 @@ async def on_send_join_request(self, origin, pdu): prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(state_ids) + auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) state = await self.store.get_events(list(prev_state_ids.values())) @@ -2219,7 +2299,7 @@ async def on_query_auth( # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + room_id, list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 71a507667296..13f8152283f2 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -48,7 +48,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.validator = EventValidator() self.snapshot_cache = ResponseCache( - hs, "initial_sync_cache" + hs.get_clock(), "initial_sync_cache" ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 07db1e31e425..6d8551a6d61b 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,13 +15,13 @@ # limitations under the License. import inspect import logging -from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union from urllib.parse import urlencode import attr import pymacaroons from authlib.common.security import generate_token -from authlib.jose import JsonWebToken +from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo @@ -28,20 +29,26 @@ from jinja2 import Environment, Template from pymacaroons.exceptions import ( MacaroonDeserializationException, + MacaroonInitException, MacaroonInvalidSignatureException, ) from typing_extensions import TypedDict from twisted.web.client import readBody +from twisted.web.http_headers import Headers from synapse.config import ConfigError -from synapse.config.oidc_config import OidcProviderConfig +from synapse.config.oidc_config import ( + OidcProviderClientSecretJwtKey, + OidcProviderConfig, +) from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart -from synapse.util import json_decoder +from synapse.util import Clock, json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry if TYPE_CHECKING: from synapse.server import HomeServer @@ -211,7 +218,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None: session_data = self._token_generator.verify_oidc_session_token( session, state ) - except (MacaroonDeserializationException, ValueError) as e: + except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e: logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return @@ -275,9 +282,21 @@ def __init__( self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method + + client_secret = None # type: Union[None, str, JwtClientSecret] + if provider.client_secret: + client_secret = provider.client_secret + elif provider.client_secret_jwt_key: + client_secret = JwtClientSecret( + provider.client_secret_jwt_key, + provider.client_id, + provider.issuer, + hs.get_clock(), + ) + self._client_auth = ClientAuth( provider.client_id, - provider.client_secret, + client_secret, provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method @@ -312,6 +331,9 @@ def __init__( # optional brand identifier for this auth provider self.idp_brand = provider.idp_brand + # Optional brand identifier for the unstable API (see MSC2858). + self.unstable_idp_brand = provider.unstable_idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) @@ -521,7 +543,7 @@ async def _exchange_code(self, code: str) -> Token: """ metadata = await self.load_metadata() token_endpoint = metadata.get("token_endpoint") - headers = { + raw_headers = { "Content-Type": "application/x-www-form-urlencoded", "User-Agent": self._http_client.user_agent, "Accept": "application/json", @@ -535,10 +557,10 @@ async def _exchange_code(self, code: str) -> Token: body = urlencode(args, True) # Fill the body/headers with credentials - uri, headers, body = self._client_auth.prepare( - method="POST", uri=token_endpoint, headers=headers, body=body + uri, raw_headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=raw_headers, body=body ) - headers = {k: [v] for (k, v) in headers.items()} + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) # Do the actual request # We're not using the SimpleHttpClient util methods as we don't want to @@ -745,7 +767,7 @@ async def handle_redirect_request( idp_id=self.idp_id, nonce=nonce, client_redirect_url=client_redirect_url.decode(), - ui_auth_session_id=ui_auth_session_id, + ui_auth_session_id=ui_auth_session_id or "", ), ) @@ -976,6 +998,81 @@ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: return str(remote_user_id) +# number of seconds a newly-generated client secret should be valid for +CLIENT_SECRET_VALIDITY_SECONDS = 3600 + +# minimum remaining validity on a client secret before we should generate a new one +CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600 + + +class JwtClientSecret: + """A class which generates a new client secret on demand, based on a JWK + + This implementation is designed to comply with the requirements for Apple Sign in: + https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048 + + It looks like those requirements are based on https://tools.ietf.org/html/rfc7523, + but it's worth noting that we still put the generated secret in the "client_secret" + field (or rather, whereever client_auth_method puts it) rather than in a + client_assertion field in the body as that RFC seems to require. + """ + + def __init__( + self, + key: OidcProviderClientSecretJwtKey, + oauth_client_id: str, + oauth_issuer: str, + clock: Clock, + ): + self._key = key + self._oauth_client_id = oauth_client_id + self._oauth_issuer = oauth_issuer + self._clock = clock + self._cached_secret = b"" + self._cached_secret_replacement_time = 0 + + def __str__(self): + # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls + # encode_client_secret_basic, which calls "{}".format(secret), which ends up + # here. + return self._get_secret().decode("ascii") + + def __bytes__(self): + # if client_auth_method is client_secret_post, then ClientAuth.prepare calls + # encode_client_secret_post, which ends up here. + return self._get_secret() + + def _get_secret(self) -> bytes: + now = self._clock.time() + + # if we have enough validity on our existing secret, use it + if now < self._cached_secret_replacement_time: + return self._cached_secret + + issued_at = int(now) + expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS + + # we copy the configured header because jwt.encode modifies it. + header = dict(self._key.jwt_header) + + # see https://tools.ietf.org/html/rfc7523#section-3 + payload = { + "sub": self._oauth_client_id, + "aud": self._oauth_issuer, + "iat": issued_at, + "exp": expires_at, + **self._key.jwt_payload, + } + logger.info( + "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload + ) + self._cached_secret = jwt.encode(header, payload, self._key.key) + self._cached_secret_replacement_time = ( + expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS + ) + return self._cached_secret + + class OidcSessionTokenGenerator: """Methods for generating and checking OIDC Session cookies.""" @@ -1020,10 +1117,9 @@ def generate_oidc_session_token( macaroon.add_first_party_caveat( "client_redirect_url = %s" % (session_data.client_redirect_url,) ) - if session_data.ui_auth_session_id: - macaroon.add_first_party_caveat( - "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) - ) + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) + ) now = self._clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) @@ -1046,7 +1142,7 @@ def verify_oidc_session_token( The data extracted from the session cookie Raises: - ValueError if an expected caveat is missing from the macaroon. + KeyError if an expected caveat is missing from the macaroon. """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -1057,26 +1153,16 @@ def verify_oidc_session_token( v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) - # Sometimes there's a UI auth session ID, it seems to be OK to attempt - # to always satisfy this. v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) - v.satisfy_general(self._verify_expiry) + satisfy_expiry(v, self._clock.time_msec) v.verify(macaroon, self._macaroon_secret_key) # Extract the session data from the token. - nonce = self._get_value_from_macaroon(macaroon, "nonce") - idp_id = self._get_value_from_macaroon(macaroon, "idp_id") - client_redirect_url = self._get_value_from_macaroon( - macaroon, "client_redirect_url" - ) - try: - ui_auth_session_id = self._get_value_from_macaroon( - macaroon, "ui_auth_session_id" - ) # type: Optional[str] - except ValueError: - ui_auth_session_id = None - + nonce = get_value_from_macaroon(macaroon, "nonce") + idp_id = get_value_from_macaroon(macaroon, "idp_id") + client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") + ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") return OidcSessionData( nonce=nonce, idp_id=idp_id, @@ -1084,33 +1170,6 @@ def verify_oidc_session_token( ui_auth_session_id=ui_auth_session_id, ) - def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: - """Extracts a caveat value from a macaroon token. - - Args: - macaroon: the token - key: the key of the caveat to extract - - Returns: - The extracted value - - Raises: - ValueError: if the caveat was not in the macaroon - """ - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - def _verify_expiry(self, caveat: str) -> bool: - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self._clock.time_msec() - return now < expiry - @attr.s(frozen=True, slots=True) class OidcSessionData: @@ -1125,8 +1184,8 @@ class OidcSessionData: # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) client_redirect_url = attr.ib(type=str) - # The session ID of the ongoing UI Auth (None if this is a login) - ui_auth_session_id = attr.ib(type=Optional[str], default=None) + # The session ID of the ongoing UI Auth ("" if this is a login) + ui_auth_session_id = attr.ib(type=str) UserAttributeDict = TypedDict( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 059064a4eb44..66dc886c8100 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -285,7 +285,7 @@ async def _purge_history( except Exception: f = Failure() logger.error( - "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3cda89657ef0..1abc8875cb87 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,7 +16,9 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple + +from prometheus_client import Counter from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -41,6 +43,19 @@ logger = logging.getLogger(__name__) +registration_counter = Counter( + "synapse_user_registrations_total", + "Number of new users registered (since restart)", + ["guest", "shadow_banned", "auth_provider"], +) + +login_counter = Counter( + "synapse_user_logins_total", + "Number of user logins (since restart)", + ["guest", "auth_provider"], +) + + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -67,6 +82,7 @@ def __init__(self, hs: "HomeServer"): ) else: self.device_handler = hs.get_device_handler() + self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime @@ -156,6 +172,7 @@ async def register_user( bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, + auth_provider_id: Optional[str] = None, ) -> str: """Registers a new client on the server. @@ -181,8 +198,9 @@ async def register_user( admin api, otherwise False. user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. + auth_provider_id: The SSO IdP the user used, if any. Returns: - The registere user_id. + The registered user_id. Raises: SynapseError if there was a problem registering. """ @@ -192,6 +210,7 @@ async def register_user( threepid, localpart, user_agent_ips or [], + auth_provider_id=auth_provider_id, ) if result == RegistrationBehaviour.DENY: @@ -280,6 +299,12 @@ async def register_user( # if user id is taken, just generate another fail_count += 1 + registration_counter.labels( + guest=make_guest, + shadow_banned=shadow_banned, + auth_provider=(auth_provider_id or ""), + ).inc() + if not self.hs.config.user_consent_at_registration: if not self.hs.config.auto_join_rooms_for_guests and make_guest: logger.info( @@ -638,6 +663,7 @@ async def register_device( initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, + auth_provider_id: Optional[str] = None, ) -> Tuple[str, str]: """Register a device for a user and generate an access token. @@ -648,21 +674,40 @@ async def register_device( device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: Tuple of device ID and access token """ + res = await self._register_device_client( + user_id=user_id, + device_id=device_id, + initial_display_name=initial_display_name, + is_guest=is_guest, + is_appservice_ghost=is_appservice_ghost, + ) - if self.hs.config.worker_app: - r = await self._register_device_client( - user_id=user_id, - device_id=device_id, - initial_display_name=initial_display_name, - is_guest=is_guest, - is_appservice_ghost=is_appservice_ghost, - ) - return r["device_id"], r["access_token"] + login_counter.labels( + guest=is_guest, + auth_provider=(auth_provider_id or ""), + ).inc() + + return res["device_id"], res["access_token"] + async def register_device_inner( + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + is_appservice_ghost: bool = False, + ) -> Dict[str, str]: + """Helper for register_device + + Does the bits that need doing on the main process. Not for use outside this + class and RegisterDeviceReplicationServlet. + """ + assert not self.hs.config.worker_app valid_until_ms = None if self.session_lifetime is not None: if is_guest: @@ -687,7 +732,7 @@ async def register_device( is_appservice_ghost=is_appservice_ghost, ) - return (registered_device_id, access_token) + return {"device_id": registered_device_id, "access_token": access_token} async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a488df10d678..4b3d0d72e387 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -121,7 +121,7 @@ def __init__(self, hs: "HomeServer"): # succession, only process the first attempt and return its result to # subsequent requests self._upgrade_response_cache = ResponseCache( - hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS + hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 14f14db4492b..8bfc46c6545e 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,10 +44,10 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( - hs, "room_list" + hs.get_clock(), "room_list" ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] self.remote_response_cache = ResponseCache( - hs, "remote_room_list", timeout_ms=30 * 1000 + hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] async def get_local_public_room_list( diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a9645b77d808..ec2ba11c7584 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"): # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 80e28bdcbea5..415b1c2d17c7 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -98,6 +98,11 @@ def idp_brand(self) -> Optional[str]: """Optional branding identifier""" return None + @property + def unstable_idp_brand(self) -> Optional[str]: + """Optional brand identifier for the unstable API (see MSC2858).""" + return None + @abc.abstractmethod async def handle_redirect_request( self, @@ -456,6 +461,7 @@ async def complete_sso_login_request( await self._auth_handler.complete_sso_login( user_id, + auth_provider_id, request, client_redirect_url, extra_login_attributes, @@ -605,6 +611,7 @@ async def _register_mapped_user( default_display_name=attributes.display_name, bind_emails=attributes.emails, user_agent_ips=[(user_agent, ip_address)], + auth_provider_id=auth_provider_id, ) await self._store.record_user_external_id( @@ -886,6 +893,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None: await self._auth_handler.complete_sso_login( user_id, + session.auth_provider_id, request, session.client_redirect_url, session.extra_login_attributes, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e8ed7b33f62..f50257cd5736 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -244,7 +244,7 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache( - hs, "sync" + hs.get_clock(), "sync" ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() diff --git a/synapse/http/client.py b/synapse/http/client.py index 72901e3f95e8..1e01e0a9f2b5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -39,12 +39,15 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, error as twisted_error, protocol, ssl +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import ( IAddress, IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, + ITCPTransport, ) +from twisted.internet.protocol import connectionDone from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -56,13 +59,20 @@ ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse +from twisted.web.iweb import ( + UNKNOWN_LENGTH, + IAgent, + IBodyProducer, + IPolicyForHTTPS, + IResponse, +) from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.types import ISynapseReactor from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred @@ -150,16 +160,17 @@ def __init__( def resolveHostName( self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 ) -> IResolutionReceiver: - - r = recv() addresses = [] # type: List[IAddress] def _callback() -> None: - r.resolutionBegan(None) - has_bad_ip = False - for i in addresses: - ip_address = IPAddress(i.host) + for address in addresses: + # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups + # should go through this path. + if not isinstance(address, (IPv4Address, IPv6Address)): + continue + + ip_address = IPAddress(address.host) if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist @@ -174,15 +185,15 @@ def _callback() -> None: # request, but all we can really do from here is claim that there were no # valid results. if not has_bad_ip: - for i in addresses: - r.addressResolved(i) - r.resolutionComplete() + for address in addresses: + recv.addressResolved(address) + recv.resolutionComplete() @provider(IResolutionReceiver) class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress: IHostResolution) -> None: - pass + recv.resolutionBegan(resolutionInProgress) @staticmethod def addressResolved(address: IAddress) -> None: @@ -196,10 +207,10 @@ def resolutionComplete() -> None: EndpointReceiver, hostname, portNumber=portNumber ) - return r + return recv -@implementer(IReactorPluggableNameResolver) +@implementer(ISynapseReactor) class BlacklistingReactorWrapper: """ A Reactor wrapper which will prevent DNS resolution to blacklisted IP @@ -324,7 +335,7 @@ def __init__( # filters out blacklisted IP addresses, to prevent DNS rebinding. self.reactor = BlacklistingReactorWrapper( hs.get_reactor(), self._ip_whitelist, self._ip_blacklist - ) + ) # type: ISynapseReactor else: self.reactor = hs.get_reactor() @@ -345,7 +356,7 @@ def __init__( contextFactory=self.hs.get_http_client_context_factory(), pool=pool, use_proxy=use_proxy, - ) + ) # type: IAgent if self._ip_blacklist: # If we have an IP blacklist, we then install the blacklisting Agent @@ -751,6 +762,8 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" + transport = None # type: Optional[ITCPTransport] + def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -762,18 +775,21 @@ def _maybe_fail(self): self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. + assert self.transport is not None self.transport.abortConnection() def dataReceived(self, data: bytes) -> None: self._maybe_fail() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: self._maybe_fail() class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" + transport = None # type: Optional[ITCPTransport] + def __init__( self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] ): @@ -796,9 +812,10 @@ def dataReceived(self, data: bytes) -> None: self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. + assert self.transport is not None self.transport.abortConnection() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: # If the maximum size was already exceeded, there's nothing to do. if self.deferred.called: return @@ -867,6 +884,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by return query_str.encode("utf8") +@implementer(IPolicyForHTTPS) class InsecureInterceptableContextFactory(ssl.ContextFactory): """ Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b07aa59c08e5..5935a125fd60 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -35,6 +35,7 @@ from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor from synapse.util import Clock logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ class MatrixFederationAgent: def __init__( self, - reactor: IReactorCore, + reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, ip_blacklist: IPSet, diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 4def7d763304..ecd63e6596b4 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -322,7 +322,8 @@ def _cache_period_from_headers( def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} - for hdr in headers.getRawHeaders(b"cache-control", []): + cache_control_headers = headers.getRawHeaders(b"cache-control") or [] + for hdr in cache_control_headers: for directive in hdr.split(b","): splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 0f107714eafc..5f01ebd3d472 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -59,7 +59,7 @@ start_active_span, tags, ) -from synapse.types import JsonDict +from synapse.types import ISynapseReactor, JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -237,14 +237,14 @@ def __init__(self, hs, tls_client_options_factory): # addresses, to prevent DNS rebinding. self.reactor = BlacklistingReactorWrapper( hs.get_reactor(), None, hs.config.federation_ip_range_blacklist - ) + ) # type: ISynapseReactor user_agent = hs.version_string if hs.config.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = user_agent.encode("ascii") - self.agent = MatrixFederationAgent( + federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, user_agent, @@ -254,7 +254,7 @@ def __init__(self, hs, tls_client_options_factory): # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, + federation_agent, ip_blacklist=hs.config.federation_ip_range_blacklist, ) @@ -534,9 +534,10 @@ async def _send_request( response.code, response_phrase, body ) - # Retry if the error is a 429 (Too Many Requests), - # otherwise just raise a standard HttpResponseException - if response.code == 429: + # Retry if the error is a 5xx or a 429 (Too Many + # Requests), otherwise just raise a standard + # `HttpResponseException` + if 500 <= response.code < 600 or response.code == 429: raise RequestSendFailed(exc, can_retry=True) from exc else: raise exc diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 174ca7be5ac8..643492ceaf83 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -32,8 +32,9 @@ TCP4ClientEndpoint, TCP6ClientEndpoint, ) -from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport +from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint from twisted.internet.protocol import Factory, Protocol +from twisted.internet.tcp import Connection from twisted.python.failure import Failure logger = logging.getLogger(__name__) @@ -52,7 +53,9 @@ class LogProducer: format: A callable to format the log record to a string. """ - transport = attr.ib(type=ITransport) + # This is essentially ITCPTransport, but that is missing certain fields + # (connected and registerProducer) which are part of the implementation. + transport = attr.ib(type=Connection) _format = attr.ib(type=Callable[[logging.LogRecord], str]) _buffer = attr.ib(type=deque) _paused = attr.ib(default=False, type=bool, init=False) @@ -149,8 +152,6 @@ def _connect(self) -> None: if self._connection_waiter: return - self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - def fail(failure: Failure) -> None: # If the Deferred was cancelled (e.g. during shutdown) do not try to # reconnect (this will cause an infinite loop of errors). @@ -163,9 +164,13 @@ def fail(failure: Failure) -> None: self._connect() def writer(result: Protocol) -> None: + # Force recognising transport as a Connection and not the more + # generic ITransport. + transport = result.transport # type: Connection # type: ignore + # We have a connection. If we already have a producer, and its # transport is the same, just trigger a resumeProducing. - if self._producer and result.transport is self._producer.transport: + if self._producer and transport is self._producer.transport: self._producer.resumeProducing() self._connection_waiter = None return @@ -177,14 +182,16 @@ def writer(result: Protocol) -> None: # Make a new producer and start it. self._producer = LogProducer( buffer=self._buffer, - transport=result.transport, + transport=transport, format=self.format, ) - result.transport.registerProducer(self._producer, True) + transport.registerProducer(self._producer, True) self._producer.resumeProducing() self._connection_waiter = None - self._connection_waiter.addCallbacks(writer, fail) + deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred + deferred.addCallbacks(writer, fail) + self._connection_waiter = deferred def _handle_pressure(self) -> None: """ diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 78e27bfb00ed..1a7ea4fa9629 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -669,7 +669,7 @@ def g(*args, **kwargs): return g -def run_in_background(f, *args, **kwargs): +def run_in_background(f, *args, **kwargs) -> defer.Deferred: """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs): if isinstance(res, types.CoroutineType): res = defer.ensureDeferred(res) + # At this point we should have a Deferred, if not then f was a synchronous + # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): - return res + return defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index db2d400b7e80..781e02fbbb31 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -203,11 +203,26 @@ def record_user_external_id( ) def generate_short_term_login_token( - self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + self, + user_id: str, + duration_in_ms: int = (2 * 60 * 1000), + auth_provider_id: str = "", ) -> str: - """Generate a login token suitable for m.login.token authentication""" + """Generate a login token suitable for m.login.token authentication + + Args: + user_id: gives the ID of the user that the token is for + + duration_in_ms: the time that the token will be valid for + + auth_provider_id: the ID of the SSO IdP that the user used to authenticate + to get this token, if any. This is encoded in the token so that + /login can report stats on number of successful logins by IdP. + """ return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, duration_in_ms + user_id, + auth_provider_id, + duration_in_ms, ) @defer.inlineCallbacks @@ -276,6 +291,7 @@ def complete_sso_login( """ self._auth_handler._complete_sso_login( registered_user_id, + "", request, client_redirect_url, ) @@ -286,6 +302,7 @@ async def complete_sso_login_async( request: SynapseRequest, client_redirect_url: str, new_user: bool = False, + auth_provider_id: str = "", ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that @@ -299,9 +316,15 @@ async def complete_sso_login_async( redirect them directly if whitelisted). new_user: set to true to use wording for the consent appropriate to a user who has just registered. + auth_provider_id: the ID of the SSO IdP which was used to log in. This + is used to track counts of sucessful logins by IdP. """ await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url, new_user=new_user + registered_user_id, + auth_provider_id, + request, + client_redirect_url, + new_user=new_user, ) @defer.inlineCallbacks diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 5fec2aaf5dc8..3dc06a79e865 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -16,8 +16,8 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional -from twisted.internet.base import DelayedCall from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, ThrottleParams @@ -66,7 +66,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer self.store = self.hs.get_datastore() self.email = pusher_config.pushkey - self.timed_call = None # type: Optional[DelayedCall] + self.timed_call = None # type: Optional[IDelayedCall] self.throttle_params = {} # type: Dict[str, ThrottleParams] self._inited = False diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 8a3f113e7640..b7aa0c280fe5 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -18,7 +18,7 @@ import re import urllib from inspect import signature -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -28,6 +28,9 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) _pending_outgoing_requests = Gauge( @@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): CACHE = True RETRY_ON_TIMEOUT = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 + hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 36071feb36df..4ec1bfa6eaf9 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,7 +61,7 @@ async def _handle_request(self, request, user_id): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] - device_id, access_token = await self.registration_handler.register_device( + res = await self.registration_handler.register_device_inner( user_id, device_id, initial_display_name, @@ -69,7 +69,7 @@ async def _handle_request(self, request, user_id): is_appservice_ghost=is_appservice_ghost, ) - return 200, {"device_id": device_id, "access_token": access_token} + return 200, res def register_servlets(hs, http_server): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index a7245da15232..a8894beadfd1 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -48,7 +48,7 @@ UserIpCommand, UserSyncCommand, ) -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, AccountDataStream, @@ -82,7 +82,7 @@ # the type of the entries in _command_queues_by_stream _StreamCommandQueue = Deque[ - Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] + Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection] ] @@ -174,7 +174,7 @@ def __init__(self, hs: "HomeServer"): # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) - self._connections = [] # type: List[AbstractConnection] + self._connections = [] # type: List[IReplicationConnection] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -197,7 +197,7 @@ def __init__(self, hs: "HomeServer"): # For each connection, the incoming stream names that have received a POSITION # from that connection. - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] LaterGauge( "synapse_replication_tcp_command_queue", @@ -220,7 +220,7 @@ def __init__(self, hs: "HomeServer"): self._server_notices_sender = hs.get_server_notices_sender() def _add_command_to_stream_queue( - self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: """Queue the given received command for processing @@ -267,7 +267,7 @@ async def _unsafe_process_queue(self, stream_name: str): async def _process_command( self, cmd: Union[PositionCommand, RdataCommand], - conn: AbstractConnection, + conn: IReplicationConnection, stream_name: str, ) -> None: if isinstance(cmd, PositionCommand): @@ -302,7 +302,7 @@ def start_replication(self, hs): hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, + hs.config.redis.redis_host.encode(), hs.config.redis.redis_port, self._factory, ) @@ -311,7 +311,7 @@ def start_replication(self, hs): self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" @@ -321,10 +321,10 @@ def get_streams_to_replicate(self) -> List[Stream]: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: AbstractConnection): + def send_positions_to_connection(self, conn: IReplicationConnection): """Send current position of all streams this process is source of to the connection. """ @@ -347,7 +347,7 @@ def send_positions_to_connection(self, conn: AbstractConnection): ) def on_USER_SYNC( - self, conn: AbstractConnection, cmd: UserSyncCommand + self, conn: IReplicationConnection, cmd: UserSyncCommand ) -> Optional[Awaitable[None]]: user_sync_counter.inc() @@ -359,21 +359,23 @@ def on_USER_SYNC( return None def on_CLEAR_USER_SYNC( - self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: if self._is_master: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None - def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): + def on_FEDERATION_ACK( + self, conn: IReplicationConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) def on_USER_IP( - self, conn: AbstractConnection, cmd: UserIpCommand + self, conn: IReplicationConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() @@ -395,7 +397,7 @@ async def _handle_user_ip(self, cmd: UserIpCommand): assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -412,7 +414,7 @@ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): self._add_command_to_stream_queue(conn, cmd) async def _process_rdata( - self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand ) -> None: """Process an RDATA command @@ -486,7 +488,7 @@ async def on_rdata( stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -496,7 +498,7 @@ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): self._add_command_to_stream_queue(conn, cmd) async def _process_position( - self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand ) -> None: """Process a POSITION command @@ -553,7 +555,9 @@ async def _process_position( self._streams_by_connection.setdefault(conn, set()).add(stream_name) - def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): + def on_REMOTE_SERVER_UP( + self, conn: IReplicationConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -576,7 +580,7 @@ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpComma # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: AbstractConnection): + def new_connection(self, connection: IReplicationConnection): """Called when we have a new connection.""" self._connections.append(connection) @@ -603,7 +607,7 @@ def new_connection(self, connection: AbstractConnection): UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: AbstractConnection): + def lost_connection(self, connection: IReplicationConnection): """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) @@ -624,7 +628,7 @@ def connected(self) -> bool: return bool(self._connections) def send_command( - self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None ): """Send a command to all connected connections. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e0b4ad314dfe..825900f64c1d 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,7 +46,6 @@ > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct @@ -54,8 +53,10 @@ from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from zope.interface import Interface, implementer from twisted.internet import task +from twisted.internet.tcp import Connection from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -121,6 +122,14 @@ class ConnectionStates: CLOSED = "closed" +class IReplicationConnection(Interface): + """An interface for replication connections.""" + + def send_command(cmd: Command): + """Send the command down the connection""" + + +@implementer(IReplicationConnection) class BaseReplicationStreamProtocol(LineOnlyReceiver): """Base replication protocol shared between client and server. @@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): (if they send a `PING` command) """ + # The transport is going to be an ITCPTransport, but that doesn't have the + # (un)registerProducer methods, those are only on the implementation. + transport = None # type: Connection + delimiter = b"\n" # Valid commands we expect to receive @@ -181,6 +194,7 @@ def connectionMade(self): connected_connections.append(self) # Register connection for metrics + assert self.transport is not None self.transport.registerProducer(self, True) # For the *Producing callbacks self._send_pending_commands() @@ -205,6 +219,7 @@ def send_ping(self): logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) + assert self.transport is not None self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: @@ -294,6 +309,7 @@ def handle_command(self, cmd: Command) -> None: def close(self): logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() + assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() @@ -391,6 +407,7 @@ def stopProducing(self): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): + assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: connection_close_counter.labels(reason.__class__.__name__).inc() @@ -495,20 +512,6 @@ def replicate(self): self.send_command(ReplicateCommand()) -class AbstractConnection(abc.ABC): - """An interface for replication connections.""" - - @abc.abstractmethod - def send_command(self, cmd: Command): - """Send the command down the connection""" - pass - - -# This tells python that `BaseReplicationStreamProtocol` implements the -# interface. -AbstractConnection.register(BaseReplicationStreamProtocol) - - # The following simply registers metrics for the replication connections pending_commands = LaterGauge( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 0e6155cf530a..2f4d407f9482 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -19,6 +19,11 @@ import attr import txredisapi +from zope.interface import implementer + +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.interfaces import IAddress, IConnector +from twisted.python.failure import Failure from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import ( @@ -32,7 +37,7 @@ parse_command_from_line, ) from synapse.replication.tcp.protocol import ( - AbstractConnection, + IReplicationConnection, tcp_inbound_commands_counter, tcp_outbound_commands_counter, ) @@ -62,7 +67,8 @@ def __set__(self, obj: Optional[T], value: V): pass -class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): +@implementer(IReplicationConnection) +class RedisSubscriber(txredisapi.SubscriberProtocol): """Connection to redis subscribed to replication stream. This class fulfils two functions: @@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): connection, parsing *incoming* messages into replication commands, and passing them to `ReplicationCommandHandler` - (b) it implements the AbstractConnection API, where it sends *outgoing* commands + (b) it implements the IReplicationConnection API, where it sends *outgoing* commands onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom @@ -253,6 +259,37 @@ async def _send_ping(self): except Exception: logger.warning("Failed to send ping to a redis connection") + # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but + # it's rubbish. We add our own here. + + def startedConnecting(self, connector: IConnector): + logger.info( + "Connecting to redis server %s", format_address(connector.getDestination()) + ) + super().startedConnecting(connector) + + def clientConnectionFailed(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s failed: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s lost: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionLost(connector, reason) + + +def format_address(address: IAddress) -> str: + if isinstance(address, (IPv4Address, IPv6Address)): + return "%s:%i" % (address.host, address.port) + return str(address) + class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately @@ -328,6 +365,6 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host, port, factory, 30) + reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) return factory.handler diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index e09234c644bd..7681e55b5832 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -15,10 +15,9 @@ import re -import twisted.web.server - -import synapse.api.auth +from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.http.site import SynapseRequest from synapse.types import UserID @@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"): return patterns -async def assert_requester_is_admin( - auth: synapse.api.auth.Auth, request: twisted.web.server.Request -) -> None: +async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None: """Verify that the requester is an admin user Args: - auth: api.auth.Auth singleton + auth: Auth singleton request: incoming request Raises: @@ -53,11 +50,11 @@ async def assert_requester_is_admin( await assert_user_is_admin(auth, requester.user) -async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: +async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """Verify that the given user is an admin user Args: - auth: api.auth.Auth singleton + auth: Auth singleton user_id: user to check Raises: diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 511c859f646b..7fcc48a9d70f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,10 +17,9 @@ import logging from typing import TYPE_CHECKING, Tuple -from twisted.web.server import Request - from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, @@ -50,7 +49,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -75,7 +76,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -103,7 +106,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() async def on_POST( - self, request: Request, server_name: str, media_id: str + self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -127,7 +130,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -148,7 +153,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -166,7 +173,7 @@ def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -189,7 +196,7 @@ def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() async def on_DELETE( - self, request: Request, server_name: str, media_id: str + self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -218,7 +225,9 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, server_name: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py index 8b7bb6d44ebe..49966ee3e044 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py @@ -12,13 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Tuple + from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer class PurgeRoomServlet(RestServlet): @@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet): PATTERNS = admin_patterns("/purge_room$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f2c42a0f30a4..263d8ec07619 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -685,7 +685,10 @@ async def on_GET( results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_aggregations=False, ) return 200, results diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 375d0554455b..f495666f4a71 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -12,17 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Optional, Tuple + from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class SendServerNoticeServlet(RestServlet): @@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet): } """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) self.snm = hs.get_server_notices_manager() - def register(self, json_resource): + def register(self, json_resource: HttpServer): PATTERN = "/send_server_notice" json_resource.register_paths( "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ @@ -66,7 +69,9 @@ def register(self, json_resource): self.__class__.__name__, ) - async def on_POST(self, request, txn_id=None): + async def on_POST( + self, request: SynapseRequest, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ("user_id", "content")) @@ -90,7 +95,7 @@ async def on_POST(self, request, txn_id=None): return 200, {"event_id": event.event_id} - def on_PUT(self, request, txn_id): + def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 267a993430eb..2c89b62e254a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -269,7 +269,10 @@ async def on_PUT( target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: - if "password" not in body: + if ( + "password" not in body + and self.hs.config.password_localdb_enabled + ): raise SynapseError( 400, "Must provide a password to re-activate an account." ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 925edfc40239..e4c352f572a2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider from synapse.http import get_request_uri @@ -94,11 +96,21 @@ def on_GET(self, request: SynapseRequest): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + sso_flow = { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp( + idp, + ) + for idp in self._sso_handler.get_identity_providers().values() + ], + } # type: JsonDict if self._msc2858_enabled: + # backwards-compatibility support for clients which don't + # support the stable API yet sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp) + _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) for idp in self._sso_handler.get_identity_providers().values() ] @@ -219,6 +231,7 @@ async def _complete_login( callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, + auth_provider_id: Optional[str] = None, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -234,6 +247,8 @@ async def _complete_login( create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: result: Dictionary of account information after successful login. @@ -256,7 +271,7 @@ async def _complete_login( device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name + user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id ) result = { @@ -283,12 +298,13 @@ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: """ token = login_submission["token"] auth_handler = self.auth_handler - user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( - token - ) + res = await auth_handler.validate_short_term_login_token(token) return await self._complete_login( - user_id, login_submission, self.auth_handler._sso_login_callback + res.user_id, + login_submission, + self.auth_handler._sso_login_callback, + auth_provider_id=res.auth_provider_id, ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: @@ -327,22 +343,38 @@ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: return result -def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: +def _get_auth_flow_dict_for_idp( + idp: SsoIdentityProvider, use_unstable_brands: bool = False +) -> JsonDict: """Return an entry for the login flow dict Returns an entry suitable for inclusion in "identity_providers" in the response to GET /_matrix/client/r0/login + + Args: + idp: the identity provider to describe + use_unstable_brands: whether we should use brand identifiers suitable + for the unstable API """ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: e["brand"] = idp.idp_brand + # use the stable brand identifier if the unstable identifier isn't defined. + if use_unstable_brands and idp.unstable_idp_brand: + e["brand"] = idp.unstable_idp_brand return e class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) + PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + re.compile( + "^" + + CLIENT_API_PREFIX + + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" + ) + ] def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -360,7 +392,8 @@ def __init__(self, hs: "HomeServer"): def register(self, http_server: HttpServer) -> None: super().register(http_server) if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support + # expose additional endpoint for MSC2858 support: backwards-compat support + # for clients which don't yet support the stable endpoints. http_server.register_paths( "GET", client_patterns( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9a1df30c2999..5884daea6da2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -671,7 +671,10 @@ async def on_GET(self, request, room_id, event_id): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_aggregations=False, ) return 200, results diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 7aea4cebf576..5901432fadce 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -32,6 +32,7 @@ assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.types import GroupID, JsonDict from ._base import client_patterns @@ -70,7 +71,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -81,7 +84,9 @@ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: return 200, group_description @_validate_group_id - async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_POST( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -111,7 +116,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -144,7 +151,11 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, category_id: Optional[str], room_id: str + self, + request: SynapseRequest, + group_id: str, + category_id: Optional[str], + room_id: str, ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -176,7 +187,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, category_id: str, room_id: str + self, request: SynapseRequest, group_id: str, category_id: str, room_id: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -206,7 +217,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_GET( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,7 +230,7 @@ async def on_GET( @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -247,7 +258,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, category_id: str + self, request: SynapseRequest, group_id: str, category_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -274,7 +285,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -298,7 +311,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_GET( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -311,7 +324,7 @@ async def on_GET( @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -339,7 +352,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, role_id: str + self, request: SynapseRequest, group_id: str, role_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -366,7 +379,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -399,7 +414,11 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, role_id: Optional[str], user_id: str + self, + request: SynapseRequest, + group_id: str, + role_id: Optional[str], + user_id: str, ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -431,7 +450,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, role_id: str, user_id: str + self, request: SynapseRequest, group_id: str, role_id: str, user_id: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -458,7 +477,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -481,7 +502,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -504,7 +527,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -526,7 +551,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -554,7 +581,7 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -598,7 +625,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, room_id: str + self, request: SynapseRequest, group_id: str, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -615,7 +642,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( - self, request: Request, group_id: str, room_id: str + self, request: SynapseRequest, group_id: str, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -646,7 +673,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: Request, group_id: str, room_id: str, config_key: str + self, request: SynapseRequest, group_id: str, room_id: str, config_key: str ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -678,7 +705,9 @@ def __init__(self, hs: "HomeServer"): self.is_mine_id = hs.is_mine_id @_validate_group_id - async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -708,7 +737,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -735,7 +766,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -762,7 +795,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -789,7 +824,9 @@ def __init__(self, hs: "HomeServer"): self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -816,7 +853,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @_validate_group_id - async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -839,7 +878,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) result = await self.groups_handler.get_publicised_groups_for_user(user_id) @@ -859,7 +900,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -881,7 +922,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 9039662f7e54..1eff98ef14c2 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -20,6 +20,7 @@ from twisted.web.server import Request from synapse.http.server import DirectServeJsonResource, respond_with_json +from synapse.http.site import SynapseRequest if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -35,7 +36,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0641924f1887..8b4841ed5d6a 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -35,6 +35,7 @@ from synapse.config._base import ConfigError from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -145,7 +146,7 @@ async def create_content( upload_name: Optional[str], content: IO, content_length: int, - auth_user: str, + auth_user: UserID, ) -> str: """Store uploaded content for a local user and return the mxc URL diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a074e807dc0d..b8895aeaa93e 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -39,6 +39,7 @@ respond_with_json_bytes, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers @@ -185,7 +186,7 @@ async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 07903e401761..988f52c78f73 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -96,9 +96,14 @@ def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which - # looks awful - if self.image.mode in ["1", "P"]: - self.image = self.image.convert("RGB") + # looks awful. + # + # If the image has transparency, use RGBA instead. + if self.image.mode in ["1", "L", "P"]: + mode = "RGB" + if self.image.info.get("transparency", None) is not None: + mode = "RGBA" + self.image = self.image.convert(mode) return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 5e104fac40fe..ae5aef2f7f9c 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -22,6 +22,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: @@ -49,7 +50,7 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py index f6668fb5e3bf..4dfadf1bfb07 100644 --- a/synapse/rest/synapse/client/saml2/response_resource.py +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -14,24 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import DirectServeHtmlResource +if TYPE_CHECKING: + from synapse.server import HomeServer + class SAML2ResponseResource(DirectServeHtmlResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._saml_handler = hs.get_saml_handler() + self._sso_handler = hs.get_sso_handler() async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - self._saml_handler._render_error( + self._sso_handler.render_error( request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) diff --git a/synapse/server.py b/synapse/server.py index afd7cd72e7e3..48ac87a124e2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -36,7 +36,6 @@ cast, ) -import twisted.internet.base import twisted.internet.tcp from twisted.internet import defer from twisted.mail.smtp import sendmail @@ -130,7 +129,7 @@ from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import Databases, DataStore, Storage from synapse.streams.events import EventSources -from synapse.types import DomainSpecificString +from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.ratelimitutils import FederationRateLimiter @@ -291,7 +290,7 @@ def setup_background_tasks(self) -> None: for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: getattr(self, "get_" + i + "_handler")() - def get_reactor(self) -> twisted.internet.base.ReactorBase: + def get_reactor(self) -> ISynapseReactor: """ Fetch the Twisted reactor in use by this HomeServer. """ @@ -352,11 +351,9 @@ def get_auth(self) -> Auth: @cache_in_self def get_http_client_context_factory(self) -> IPolicyForHTTPS: - return ( - InsecureInterceptableContextFactory() - if self.config.use_insecure_ssl_client_just_for_testing_do_not_use - else RegularPolicyForHTTPS() - ) + if self.config.use_insecure_ssl_client_just_for_testing_do_not_use: + return InsecureInterceptableContextFactory() + return RegularPolicyForHTTPS() @cache_in_self def get_simple_http_client(self) -> SimpleHttpClient: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 18ddb92fcca5..332193ad1c93 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -54,11 +54,12 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) # type: LruCache[str, List[Tuple[str, int]]] async def get_auth_chain( - self, event_ids: Collection[str], include_given: bool = False + self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result @@ -66,24 +67,44 @@ async def get_auth_chain( list of events """ event_ids = await self.get_auth_chain_ids( - event_ids, include_given=include_given + room_id, event_ids, include_given=include_given ) return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( self, + room_id: str, event_ids: Collection[str], include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result Returns: - An awaitable which resolve to a list of event_ids + list of event_ids """ + + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_ids_chains", + self._get_auth_chain_ids_using_cover_index_txn, + room_id, + event_ids, + include_given, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, @@ -91,9 +112,130 @@ async def get_auth_chain_ids( include_given, ) + def _get_auth_chain_ids_using_cover_index_txn( + self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + ) -> List[str]: + """Calculates the auth chain IDs using the chain index.""" + + # First we look up the chain ID/sequence numbers for the given events. + + initial_events = set(event_ids) + + # All the events that we've found that are reachable from the events. + seen_events = set() # type: Set[str] + + # A map from chain ID to max sequence number of the given events. + event_chains = {} # type: Dict[int, int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + seen_events.add(event_id) + event_chains[chain_id] = max( + sequence_number, event_chains.get(chain_id, 0) + ) + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(seen_events) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Now we look up all links for the chains we have, adding chains that + # are reachable from any event. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # A map from chain ID to max sequence number *reachable* from any event ID. + chains = {} # type: Dict[int, int] + + # Add all linked chains reachable from initial set of chains. + for batch in batch_iter(event_chains, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= event_chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, + chains.get(target_chain_id, 0), + ) + + # Add the initial set of chains, excluding the sequence corresponding to + # initial event. + for chain_id, seq_no in event_chains.items(): + chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0)) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* event ID. Events with a sequence less than that are in the + # auth chain. + if include_given: + results = initial_events + else: + results = set() + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ + + rows = txn.execute_values(sql, chains.items()) + results.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND sequence_number <= ? + """ + for chain_id, max_no in chains.items(): + txn.execute(sql, (chain_id, max_no)) + results.update(r for r, in txn) + + return list(results) + def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool ) -> List[str]: + """Calculates the auth chain IDs. + + This is used when we don't have a cover index for the room. + """ if include_given: results = set(event_ids) else: diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index cb6b1f8a0c17..78367ea58ded 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -135,6 +135,11 @@ def __init__(self, database: DatabasePool, db_conn, hs): self._chain_cover_index, ) + self.db_pool.updates.register_background_update_handler( + "purged_chain_cover", + self._purged_chain_cover_index, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -932,3 +937,77 @@ def _calculate_chain_cover_txn( processed_count=count, finished_room_map=finished_rooms, ) + + async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int: + """ + A background updates that iterates over the chain cover and deletes the + chain cover for events that have been purged. + + This may be due to fully purging a room or via setting a retention policy. + """ + current_event_id = progress.get("current_event_id", "") + + def purged_chain_cover_txn(txn) -> int: + # The event ID from events will be null if the chain ID / sequence + # number points to a purged event. + sql = """ + SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL + FROM event_auth_chains + LEFT JOIN events AS e USING (event_id) + WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ? + """ + txn.execute(sql, (current_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + # The event IDs and chain IDs / sequence numbers where the event has + # been purged. + unreferenced_event_ids = [] + unreferenced_chain_id_tuples = [] + event_id = "" + for event_id, chain_id, sequence_number, has_event in rows: + if not has_event: + unreferenced_event_ids.append((event_id,)) + unreferenced_chain_id_tuples.append((chain_id, sequence_number)) + + # Delete the unreferenced auth chains from event_auth_chain_links and + # event_auth_chains. + txn.executemany( + """ + DELETE FROM event_auth_chains WHERE event_id = ? + """, + unreferenced_event_ids, + ) + # We should also delete matching target_*, but there is no index on + # target_chain_id. Hopefully any purged events are due to a room + # being fully purged and they will be removed from the origin_* + # searches. + txn.executemany( + """ + DELETE FROM event_auth_chain_links WHERE + origin_chain_id = ? AND origin_sequence_number = ? + """, + unreferenced_chain_id_tuples, + ) + + progress = { + "current_event_id": event_id, + } + + self.db_pool.updates._background_update_progress_txn( + txn, "purged_chain_cover", progress + ) + + return len(rows) + + result = await self.db_pool.runInteraction( + "_purged_chain_cover_index", + purged_chain_cover_txn, + ) + + if not result: + await self.db_pool.updates._end_background_update("purged_chain_cover") + + return result diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 36df56ba2c33..bf7a9a845ca6 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import threading from collections import namedtuple @@ -1147,7 +1147,8 @@ async def have_seen_events(self, event_ids): Returns: set[str]: The events we have already seen. """ - results = set() + # if the event cache contains the event, obviously we've seen it. + results = {x for x in event_ids if self._get_event_cache.contains(x)} def have_seen_events_txn(txn, chunk): sql = "SELECT event_id FROM events as e WHERE " @@ -1155,12 +1156,9 @@ def have_seen_events_txn(txn, chunk): txn.database_engine, "e.event_id", chunk ) txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) + results.update(row[0] for row in txn) - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + for chunk in batch_iter((x for x in event_ids if x not in results), 100): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 0836e4af4934..41f4fe7f95cd 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -331,13 +331,9 @@ def _purge_room_txn(self, txn, room_id: str) -> List[int]: txn.executemany( """ DELETE FROM event_auth_chain_links WHERE - (origin_chain_id = ? AND origin_sequence_number = ?) OR - (target_chain_id = ? AND target_sequence_number = ?) + origin_chain_id = ? AND origin_sequence_number = ? """, - ( - (chain_id, seq_num, chain_id, seq_num) - for (chain_id, seq_num) in referenced_chain_id_tuples - ), + referenced_chain_id_tuples, ) # Now we delete tables which lack an index on room_id but have one on event_id diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 61a7556e5689..eba66ff3527a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr @@ -1510,7 +1510,7 @@ def f(txn): async def user_delete_access_tokens( self, user_id: str, - except_token_id: Optional[str] = None, + except_token_id: Optional[int] = None, device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ @@ -1533,7 +1533,7 @@ def f(txn): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] + values = [v for _, v in items] # type: List[Union[str, int]] if except_token_id: where_clause += " AND id != ?" values.append(except_token_id) diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql new file mode 100644 index 000000000000..87cb1f3cfdeb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5910, 'purged_chain_cover', '{}'); diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b921d63d3051..030966184140 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -350,11 +350,11 @@ def _store_destination_rooms_entries_txn( self.db_pool.simple_upsert_many_txn( txn, - "destination_rooms", - ["destination", "room_id"], - rows, - ["stream_ordering"], - [(stream_ordering,)] * len(rows), + table="destination_rooms", + key_names=("destination", "room_id"), + key_values=rows, + value_names=["stream_ordering"], + value_values=[(stream_ordering,)] * len(rows), ) async def get_destination_last_successful_stream_ordering( diff --git a/synapse/types.py b/synapse/types.py index 721343f0b5ea..b08ce9014028 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -35,6 +35,14 @@ import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorTCP, + IReactorTime, +) from synapse.api.errors import Codes, SynapseError from synapse.util.stringutils import parse_and_validate_server_name @@ -67,33 +75,40 @@ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore JsonDict = Dict[str, Any] -class Requester( - namedtuple( - "Requester", - [ - "user", - "access_token_id", - "is_guest", - "shadow_banned", - "device_id", - "app_service", - "authenticated_entity", - ], - ) +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface ): + """The interfaces necessary for Synapse to function.""" + + +@attr.s(frozen=True, slots=True) +class Requester: """ Represents the user making a request Attributes: - user (UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request has been shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. """ + user = attr.ib(type="UserID") + access_token_id = attr.ib(type=Optional[int]) + is_guest = attr.ib(type=bool) + shadow_banned = attr.ib(type=bool) + device_id = attr.ib(type=Optional[str]) + app_service = attr.ib(type=Optional["ApplicationService"]) + authenticated_entity = attr.ib(type=str) + def serialize(self): """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -141,23 +156,23 @@ def deserialize(store, input): def create_requester( user_id: Union[str, "UserID"], access_token_id: Optional[int] = None, - is_guest: Optional[bool] = False, - shadow_banned: Optional[bool] = False, + is_guest: bool = False, + shadow_banned: bool = False, device_id: Optional[str] = None, app_service: Optional["ApplicationService"] = None, authenticated_entity: Optional[str] = None, -): +) -> Requester: """ Create a new ``Requester`` object Args: - user_id (str|UserID): id of the user making the request - access_token_id (int|None): *ID* of the access token used for this + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar - is_guest (bool): True if the user making this request is a guest user - shadow_banned (bool): True if the user making this request is shadow-banned. - device_id (str|None): device_id which was set at authentication time - app_service (ApplicationService|None): the AS requesting on behalf of the user + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user authenticated_entity: The entity that authenticated when making the request. This is different to the user_id when an admin user or the server is "puppeting" the user. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 719e35b78d72..f33c115844db 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -76,11 +76,16 @@ def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): def callback(r): object.__setattr__(self, "_result", (True, r)) while self._observers: + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass + observer.callback(r) + except Exception as e: + logger.exception( + "%r threw an exception on .callback(%r), ignoring...", + observer, + r, + exc_info=e, + ) return r def errback(f): @@ -90,11 +95,16 @@ def errback(f): # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass + observer.errback(f) + except Exception as e: + logger.exception( + "%r threw an exception on .errback(%r), ignoring...", + observer, + f, + exc_info=e, + ) if consumeErrors: return None diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 32228f42ee59..46ea8e09644c 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,17 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.util import Clock from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - logger = logging.getLogger(__name__) T = TypeVar("T") @@ -37,11 +35,11 @@ class ResponseCache(Generic[T]): used rather than trying to compute a new response. """ - def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): # Requests that haven't finished yet. self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] - self.clock = hs.get_clock() + self.clock = clock self.timeout_sec = timeout_ms / 1000.0 self._name = name diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py new file mode 100644 index 000000000000..12cdd53327cd --- /dev/null +++ b/synapse/util/macaroons.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for manipulating macaroons""" + +from typing import Callable, Optional + +import pymacaroons +from pymacaroons.exceptions import MacaroonVerificationFailedException + + +def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Checks that there is exactly one caveat of the form "key = " in the macaroon, + and returns the extracted value. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + MacaroonVerificationFailedException: if there are conflicting values for the + caveat in the macaroon, or if the caveat was not found in the macaroon. + """ + prefix = key + " = " + result = None # type: Optional[str] + for caveat in macaroon.caveats: + if not caveat.caveat_id.startswith(prefix): + continue + + val = caveat.caveat_id[len(prefix) :] + + if result is None: + # first time we found this caveat: record the value + result = val + elif val != result: + # on subsequent occurrences, raise if the value is different. + raise MacaroonVerificationFailedException( + "Conflicting values for caveat " + key + ) + + if result is not None: + return result + + # If the caveat is not there, we raise a MacaroonVerificationFailedException. + # Note that it is insecure to generate a macaroon without all the caveats you + # might need (because there is nothing stopping people from adding extra caveats), + # so if the caveat isn't there, something odd must be going on. + raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,)) + + +def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None: + """Make a macaroon verifier which accepts 'time' caveats + + Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to + the given macaroon verifier. + + Args: + v: the macaroon verifier + get_time_ms: a callable which will return the timestamp after which the caveat + should be considered expired. Normally the current time. + """ + + def verify_expiry_caveat(caveat: str): + time_msec = get_time_ms() + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + return time_msec < expiry + + v.satisfy_general(verify_expiry_caveat) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 1a3ccb263dae..6f96cd79403e 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,6 +7,7 @@ from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable from tests.unittest import FederatingHomeserverTestCase, override_config @@ -49,7 +50,7 @@ async def record_transaction(self, txn, json_cb): else: data = json_cb() self.failed_pdus.extend(data["pdus"]) - raise IOError("Failed to connect because this is a test!") + raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8 new file mode 100644 index 000000000000..bb929763331f --- /dev/null +++ b/tests/handlers/oidc_test_key.p8 @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp +Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA +P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW +-----END PRIVATE KEY----- diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem new file mode 100644 index 000000000000..176d4a4b4bf7 --- /dev/null +++ b/tests/handlers/oidc_test_key.pub.pem @@ -0,0 +1,4 @@ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi +QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg== +-----END PUBLIC KEY----- diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 0e42013bb901..c9f889b5117d 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -68,38 +68,45 @@ def verify_nonce(caveat): v.verify(macaroon, self.hs.config.macaroon_secret_key) def test_short_term_login_token_gives_user_id(self): - token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = self.get_success( - self.auth_handler.validate_short_term_login_token_and_get_user_id(token) + token = self.macaroon_generator.generate_short_term_login_token( + "a_user", "", 5000 ) - self.assertEqual("a_user", user_id) + res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + self.assertEqual("a_user", res.user_id) + self.assertEqual("", res.auth_provider_id) # when we advance the clock, the token should be rejected self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token_and_get_user_id(token), + self.auth_handler.validate_short_term_login_token(token), AuthError, ) + def test_short_term_login_token_gives_auth_provider(self): + token = self.macaroon_generator.generate_short_term_login_token( + "a_user", auth_provider_id="my_idp" + ) + res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + self.assertEqual("a_user", res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + def test_short_term_login_token_cannot_replace_user_id(self): - token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) + token = self.macaroon_generator.generate_short_term_login_token( + "a_user", "", 5000 + ) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = self.get_success( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ) + res = self.get_success( + self.auth_handler.validate_short_term_login_token(macaroon.serialize()) ) - self.assertEqual("a_user", user_id) + self.assertEqual("a_user", res.user_id) # add another "user_id" caveat, which might allow us to override the # user_id. macaroon.add_first_party_caveat("user_id = b_user") self.get_failure( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ), + self.auth_handler.validate_short_term_login_token(macaroon.serialize()), AuthError, ) @@ -113,7 +120,7 @@ def test_mau_limits_disabled(self): ) self.get_success( - self.auth_handler.validate_short_term_login_token_and_get_user_id( + self.auth_handler.validate_short_term_login_token( self._get_macaroon().serialize() ) ) @@ -135,7 +142,7 @@ def test_mau_limits_exceeded_large(self): return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( - self.auth_handler.validate_short_term_login_token_and_get_user_id( + self.auth_handler.validate_short_term_login_token( self._get_macaroon().serialize() ), ResourceLimitError, @@ -159,7 +166,7 @@ def test_mau_limits_parity(self): ResourceLimitError, ) self.get_failure( - self.auth_handler.validate_short_term_login_token_and_get_user_id( + self.auth_handler.validate_short_term_login_token( self._get_macaroon().serialize() ), ResourceLimitError, @@ -175,7 +182,7 @@ def test_mau_limits_parity(self): ) ) self.get_success( - self.auth_handler.validate_short_term_login_token_and_get_user_id( + self.auth_handler.validate_short_term_login_token( self._get_macaroon().serialize() ) ) @@ -197,11 +204,13 @@ def test_mau_limits_not_exceeded(self): return_value=make_awaitable(self.small_number_of_users) ) self.get_success( - self.auth_handler.validate_short_term_login_token_and_get_user_id( + self.auth_handler.validate_short_term_login_token( self._get_macaroon().serialize() ) ) def _get_macaroon(self): - token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000) + token = self.macaroon_generator.generate_short_term_login_token( + "user_a", "", 5000 + ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 6f992291b8ba..7975af243c7c 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,7 +66,7 @@ def test_map_cas_user_to_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=True + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -89,7 +89,7 @@ def test_map_cas_user_to_existing_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=False + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -98,7 +98,7 @@ def test_map_cas_user_to_existing_user(self): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=False + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -116,7 +116,7 @@ def test_map_cas_user_to_invalid_localpart(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True + "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True ) @override_config( @@ -160,7 +160,7 @@ def test_required_attributes(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=True + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index cf1de28fa9b2..5e9c9c2e88eb 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Optional +import os from urllib.parse import parse_qs, urlparse from mock import ANY, Mock, patch @@ -23,6 +23,7 @@ from synapse.handlers.sso import MappingException from synapse.server import HomeServer from synapse.types import UserID +from synapse.util.macaroons import get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -50,7 +51,18 @@ JWKS_URI = ISSUER + ".well-known/jwks.json" # config for common cases -COMMON_CONFIG = { +DEFAULT_CONFIG = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, +} + +# extends the default config with explicit OAuth2 endpoints instead of using discovery +EXPLICIT_ENDPOINT_CONFIG = { + **DEFAULT_CONFIG, "discover": False, "authorization_endpoint": AUTHORIZATION_ENDPOINT, "token_endpoint": TOKEN_ENDPOINT, @@ -107,6 +119,32 @@ async def get_json(url): return {"keys": []} +def _key_file_path() -> str: + """path to a file containing the private half of a test key""" + + # this key was generated with: + # openssl ecparam -name prime256v1 -genkey -noout | + # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8 + # + # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because + # that's what Apple use, and we want to be sure that we work with Apple's keys. + # + # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing + # keys using ASN.1. Both are then typically formatted using PEM, which says: use the + # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't + # really need to care about any of that.) + return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8") + + +def _public_key_file_path() -> str: + """path to a file containing the public half of a test key""" + # this was generated with: + # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem + # + # See above about where oidc_test_key.p8 came from + return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem") + + class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" @@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() config["public_baseurl"] = BASE_URL - oidc_config = { - "enabled": True, - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "issuer": ISSUER, - "scopes": SCOPES, - "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - oidc_config.update(config.get("oidc_config", {})) - config["oidc_config"] = oidc_config - return config def make_homeserver(self, reactor, clock): @@ -170,13 +194,14 @@ def assertRenderedError(self, error, error_description=None): self.render_error.reset_mock() return args + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) - @override_config({"oidc_config": {"discover": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid @@ -195,13 +220,13 @@ def test_discovery(self): self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() - @override_config({"oidc_config": COMMON_CONFIG}) + @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() - @override_config({"oidc_config": COMMON_CONFIG}) + @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) @@ -236,6 +261,7 @@ async def patched_load_metadata(): self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" h = self.provider @@ -318,13 +344,14 @@ async def force_load(): # Shouldn't raise with a valid userinfo, even without jwks force_load_metadata() - @override_config({"oidc_config": {"skip_verification": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) def test_skip_verification(self): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) @@ -360,20 +387,15 @@ def test_redirect_request(self): self.assertEqual(name, b"oidc_session") macaroon = pymacaroons.Macaroon.deserialize(cookie) - state = self.handler._token_generator._get_value_from_macaroon( - macaroon, "state" - ) - nonce = self.handler._token_generator._get_value_from_macaroon( - macaroon, "nonce" - ) - redirect = self.handler._token_generator._get_value_from_macaroon( - macaroon, "client_redirect_url" - ) + state = get_value_from_macaroon(macaroon, "state") + nonce = get_value_from_macaroon(macaroon, "nonce") + redirect = get_value_from_macaroon(macaroon, "client_redirect_url") self.assertEqual(params["state"], [state]) self.assertEqual(params["nonce"], [nonce]) self.assertEqual(redirect, "http://client/redirect") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) @@ -385,6 +407,7 @@ def test_callback_error(self): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "some description") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback(self): """Code callback works and display errors if something went wrong. @@ -434,7 +457,7 @@ def test_callback(self): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, new_user=True + expected_user_id, "oidc", request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -465,7 +488,7 @@ def test_callback(self): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, new_user=False + expected_user_id, "oidc", request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() @@ -486,6 +509,7 @@ def test_callback(self): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self): """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -528,7 +552,9 @@ def test_callback_session(self): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") - @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @override_config( + {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} + ) def test_exchange_code(self): """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} @@ -613,9 +639,105 @@ def test_exchange_code(self): @override_config( { "oidc_config": { + "enabled": True, + "client_id": CLIENT_ID, + "issuer": ISSUER, + "client_auth_method": "client_secret_post", + "client_secret_jwt_key": { + "key_file": _key_file_path(), + "jwt_header": {"alg": "ES256", "kid": "ABC789"}, + "jwt_payload": {"iss": "DEFGHI"}, + }, + } + } + ) + def test_exchange_code_jwt_key(self): + """Test that code exchange works with a JWK client secret.""" + from authlib.jose import jwt + + token = {"type": "bearer"} + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") + ) + ) + code = "code" + + # advance the clock a bit before we start, so we aren't working with zero + # timestamps. + self.reactor.advance(1000) + start_time = self.reactor.seconds() + ret = self.get_success(self.provider._exchange_code(code)) + + self.assertEqual(ret, token) + + # the request should have hit the token endpoint + kwargs = self.http_client.request.call_args[1] + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + # the client secret provided to the should be a jwt which can be checked with + # the public key + args = parse_qs(kwargs["data"].decode("utf-8")) + secret = args["client_secret"][0] + with open(_public_key_file_path()) as f: + key = f.read() + claims = jwt.decode(secret, key) + self.assertEqual(claims.header["kid"], "ABC789") + self.assertEqual(claims["aud"], ISSUER) + self.assertEqual(claims["iss"], "DEFGHI") + self.assertEqual(claims["sub"], CLIENT_ID) + self.assertEqual(claims["iat"], start_time) + self.assertGreater(claims["exp"], start_time) + + # check the rest of the POSTed data + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + @override_config( + { + "oidc_config": { + "enabled": True, + "client_id": CLIENT_ID, + "issuer": ISSUER, + "client_auth_method": "none", + } + } + ) + def test_exchange_code_no_auth(self): + """Test that code exchange works with no client secret.""" + token = {"type": "bearer"} + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") + ) + ) + code = "code" + ret = self.get_success(self.provider._exchange_code(code)) + + self.assertEqual(ret, token) + + # the request should have hit the token endpoint + kwargs = self.http_client.request.call_args[1] + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + # check the POSTed data + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderExtra" - } + }, } } ) @@ -651,12 +773,14 @@ def test_extra_attributes(self): auth_handler.complete_sso_login.assert_called_once_with( "@foo:test", + "oidc", request, client_redirect_url, {"phone": "1234567"}, new_user=True, ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() @@ -668,7 +792,7 @@ def test_map_userinfo_to_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, None, new_user=True + "@test_user:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -679,7 +803,7 @@ def test_map_userinfo_to_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, None, new_user=True + "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -697,7 +821,7 @@ def test_map_userinfo_to_user(self): "Mapping provider does not support de-duplicating Matrix IDs", ) - @override_config({"oidc_config": {"allow_existing_users": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastore() @@ -716,14 +840,14 @@ def test_map_userinfo_to_existing_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, new_user=False + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, new_user=False + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -738,7 +862,7 @@ def test_map_userinfo_to_existing_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, new_user=False + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -774,9 +898,10 @@ def test_map_userinfo_to_existing_user(self): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, None, new_user=False + "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( @@ -787,9 +912,10 @@ def test_map_userinfo_to_invalid_localpart(self): @override_config( { "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderFailures" - } + }, } } ) @@ -810,7 +936,7 @@ def test_map_userinfo_to_user_retries(self): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, None, new_user=True + "@test_user1:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -834,6 +960,7 @@ def test_map_userinfo_to_user_retries(self): "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_empty_localpart(self): """Attempts to map onto an empty localpart should be rejected.""" userinfo = { @@ -846,9 +973,10 @@ def test_empty_localpart(self): @override_config( { "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "config": {"localpart_template": "{{ user.username }}"} - } + }, } } ) @@ -866,7 +994,7 @@ def _generate_oidc_session_token( state: str, nonce: str, client_redirect_url: str, - ui_auth_session_id: Optional[str] = None, + ui_auth_session_id: str = "", ) -> str: from synapse.handlers.oidc_handler import OidcSessionData @@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo( idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + ui_auth_session_id="", ), ) request = _build_callback_request("code", state, session) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index bdf3d0a8a298..94b69035945b 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -517,6 +517,37 @@ def check_registration_for_spam( self.assertTrue(requester.shadow_banned) + def test_spam_checker_receives_sso_type(self): + """Test rejecting registration based on SSO type""" + + class BanBadIdPUser: + def check_registration_for_spam( + self, email_threepid, username, request_info, auth_provider_id=None + ): + # Reject any user coming from CAS and whose username contains profanity + if auth_provider_id == "cas" and "flimflob" in username: + return RegistrationBehaviour.DENY + return RegistrationBehaviour.ALLOW + + # Configure a spam checker that denies a certain user on a specific IdP + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanBadIdPUser()] + + f = self.get_failure( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), + SynapseError, + ) + exception = f.value + + # We return 429 from the spam checker for denied registrations + self.assertIsInstance(exception, SynapseError) + self.assertEqual(exception.code, 429) + + # Check the same username can register using SAML + self.get_success( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") + ) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 029af2853e17..30efd43b4060 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ def test_map_saml_response_to_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=True + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ def test_map_saml_response_to_existing_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None, new_user=False + "@test_user:test", "saml", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ def test_map_saml_response_to_existing_user(self): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None, new_user=False + "@test_user:test", "saml", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ def test_map_saml_response_to_user_retries(self): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "", None, new_user=True + "@test_user1:test", "saml", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -310,7 +310,7 @@ def test_attribute_requirements(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None, new_user=True + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 21ecb81c9926..0ce181a51e94 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -16,12 +16,23 @@ from mock import Mock +from netaddr import IPSet + +from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure -from twisted.web.client import ResponseDone +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web.client import Agent, ResponseDone from twisted.web.iweb import UNKNOWN_LENGTH -from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size +from synapse.api.errors import SynapseError +from synapse.http.client import ( + BlacklistingAgentWrapper, + BlacklistingReactorWrapper, + BodyExceededMaxSize, + read_body_with_max_size, +) +from tests.server import FakeTransport, get_clock from tests.unittest import TestCase @@ -119,3 +130,114 @@ def test_content_length(self): # The data is never consumed. self.assertEqual(result.getvalue(), b"") + + +class BlacklistingAgentTest(TestCase): + def setUp(self): + self.reactor, self.clock = get_clock() + + self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" + self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" + self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" + + # Configure the reactor's DNS resolver. + for (domain, ip) in ( + (self.safe_domain, self.safe_ip), + (self.unsafe_domain, self.unsafe_ip), + (self.allowed_domain, self.allowed_ip), + ): + self.reactor.lookups[domain.decode()] = ip.decode() + self.reactor.lookups[ip.decode()] = ip.decode() + + self.ip_whitelist = IPSet([self.allowed_ip.decode()]) + self.ip_blacklist = IPSet(["5.0.0.0/8"]) + + def test_reactor(self): + """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" + agent = Agent( + BlacklistingReactorWrapper( + self.reactor, + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ), + ) + + # The unsafe domains and IPs should be rejected. + for domain in (self.unsafe_domain, self.unsafe_ip): + self.failureResultOf( + agent.request(b"GET", b"http://" + domain), DNSLookupError + ) + + # The safe domains IPs should be accepted. + for domain in ( + self.safe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) + + def test_agent(self): + """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" + agent = BlacklistingAgentWrapper( + Agent(self.reactor), + ip_whitelist=self.ip_whitelist, + ip_blacklist=self.ip_blacklist, + ) + + # The unsafe IPs should be rejected. + self.failureResultOf( + agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError + ) + + # The safe and unsafe domains and safe IPs should be accepted. + for domain in ( + self.safe_domain, + self.unsafe_domain, + self.allowed_domain, + self.safe_ip, + self.allowed_ip, + ): + d = agent.request(b"GET", b"http://" + domain) + + # Grab the latest TCP connection. + ( + host, + port, + client_factory, + _timeout, + _bindAddress, + ) = self.reactor.tcpClients[-1] + + # Make the connection and pump data through it. + client = client_factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" + ) + + response = self.successResultOf(d) + self.assertEqual(response.code, 200) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a81607ed044d..f136c00ea7ba 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple - -import attr +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel from twisted.web.resource import Resource +from twisted.web.server import Request, Site from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, @@ -32,7 +31,10 @@ from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.replication.tcp.resource import ( + ReplicationStreamProtocolFactory, + ServerReplicationStreamProtocol, +) from synapse.server import HomeServer from synapse.util import Clock @@ -59,7 +61,9 @@ def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() - self.server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol( + None + ) # type: ServerReplicationStreamProtocol # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" @@ -152,12 +156,8 @@ def handle_http_replication_attempt(self) -> SynapseRequest: # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor) - channel.requestFactory = request_factory - channel.site = self.site + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -179,7 +179,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest: server_to_client_transport.loseConnection() client_to_server_transport.loseConnection() - return request_factory.request + return channel.request def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str @@ -188,8 +188,9 @@ def assert_request_is_get_repl_stream_updates( fetching updates for given stream. """ + path = request.path # type: bytes # type: ignore self.assertRegex( - request.path, + path, br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" % (stream_name.encode("ascii"),), ) @@ -232,7 +233,7 @@ def setUp(self): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - "localhost", + b"localhost", 6379, self.connect_any_redis_attempts, ) @@ -387,12 +388,8 @@ def _handle_http_replication_attempt(self, hs, repl_port): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor) - channel.requestFactory = request_factory - channel.site = self._hs_to_site[hs] + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -418,7 +415,7 @@ def connect_any_redis_attempts(self): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") + self.assertEqual(host, b"localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) @@ -450,21 +447,6 @@ async def on_rdata(self, stream_name, instance_name, token, rows): self.received_rdata_rows.append((stream_name, token, r)) -@attr.s() -class OneShotRequestFactory: - """A simple request factory that generates a single `SynapseRequest` and - stores it for future use. Can only be used once. - """ - - request = attr.ib(default=None) - - def __call__(self, *args, **kwargs): - assert self.request is None - - self.request = SynapseRequest(*args, **kwargs) - return self.request - - class _PushHTTPChannel(HTTPChannel): """A HTTPChannel that wraps pull producers to push producers. @@ -475,9 +457,13 @@ class _PushHTTPChannel(HTTPChannel): makes it very hard to test. """ - def __init__(self, reactor: IReactorTime): + def __init__( + self, reactor: IReactorTime, request_factory: Type[Request], site: Site + ): super().__init__() self.reactor = reactor + self.requestFactory = request_factory + self.site = site self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] @@ -503,6 +489,11 @@ def checkPersistence(self, request, version): request.responseHeaders.setRawHeaders(b"connection", [b"close"]) return False + def requestDone(self, request): + # Store the request for inspection. + self.request = request + super().requestDone(request) + class _PullToPushProducer: """A push producer that wraps a pull producer.""" @@ -590,6 +581,8 @@ def buildProtocol(self, addr): class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" + transport = None # type: Optional[FakeTransport] + def __init__(self, server: FakeRedisPubSubServer): self._server = server self._reader = hiredis.Reader() @@ -638,6 +631,8 @@ def handle_command(self, command, *args): def send(self, msg): """Send a message back to the client.""" + assert self.transport is not None + raw = self.encode(msg).encode("utf-8") self.transport.write(raw) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index f235f1bd83e2..0d9e3bb11dba 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -17,7 +17,7 @@ from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams.federation import FederationStream from tests.unittest import HomeserverTestCase @@ -51,8 +51,10 @@ def test_federation_ack_sent(self): """ rch = self.hs.get_tcp_replication() - # wire up the ReplicationCommandHandler to a mock connection - mock_connection = mock.Mock(spec=AbstractConnection) + # wire up the ReplicationCommandHandler to a mock connection, which needs + # to implement IReplicationConnection. (Note that Mock doesn't understand + # interfaces, but casing an interface to a list gives the attributes.) + mock_connection = mock.Mock(spec=list(IReplicationConnection)) rch.new_connection(mock_connection) # tell it it received an RDATA row diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 20af3285bd9f..988821b16f4d 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -437,14 +437,16 @@ def test_get_login_flows(self): channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.sso"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - self.assertCountEqual(channel.json_body["flows"], expected_flows) + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_get_msc2858_login_flows(self): @@ -636,22 +638,25 @@ def test_multi_sso_redirect_to_unknown(self): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(True, "xxx") + channel = self._make_sso_redirect_request(False, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] @@ -660,6 +665,12 @@ def test_client_idp_redirect_to_oidc(self): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def _make_sso_redirect_request( self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None ): diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 36d1e6bc4a4c..9f77125fd445 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -105,7 +105,7 @@ def test_ensure_media_is_in_local_cache(self): self.assertEqual(test_body, body) -@attr.s +@attr.s(slots=True, frozen=True) class _TestImage: """An image for testing thumbnailing with the expected results @@ -117,13 +117,15 @@ class _TestImage: test should just check for success. expected_scaled: The expected bytes from scaled thumbnailing, or None if test should just check for a valid image returned. + expected_found: True if the file should exist on the server, or False if + a 404 is expected. """ data = attr.ib(type=bytes) content_type = attr.ib(type=bytes) extension = attr.ib(type=bytes) - expected_cropped = attr.ib(type=Optional[bytes]) - expected_scaled = attr.ib(type=Optional[bytes]) + expected_cropped = attr.ib(type=Optional[bytes], default=None) + expected_scaled = attr.ib(type=Optional[bytes], default=None) expected_found = attr.ib(default=True, type=bool) @@ -153,6 +155,21 @@ class _TestImage: ), ), ), + # small png with transparency. + ( + _TestImage( + unhexlify( + b"89504e470d0a1a0a0000000d49484452000000010000000101000" + b"00000376ef9240000000274524e5300010194fdae0000000a4944" + b"4154789c636800000082008177cd72b60000000049454e44ae426" + b"082" + ), + b"image/png", + b".png", + # Note that we don't check the output since it varies across + # different versions of Pillow. + ), + ), # small lossless webp ( _TestImage( @@ -162,8 +179,6 @@ class _TestImage: ), b"image/webp", b".webp", - None, - None, ), ), # an empty file @@ -172,9 +187,7 @@ class _TestImage: b"", b"image/gif", b".gif", - None, - None, - False, + expected_found=False, ), ), ], diff --git a/tests/server.py b/tests/server.py index 939a0008ca2e..2287d200763b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,6 +16,7 @@ IReactorPluggableNameResolver, IReactorTCP, IResolverSimple, + ITransport, ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock @@ -188,7 +189,7 @@ def getResourceFor(self, request): def make_request( reactor, - site: Site, + site: Union[Site, FakeSite], method, path, content=b"", @@ -467,6 +468,7 @@ def get_clock(): return clock, hs_clock +@implementer(ITransport) @attr.s(cmp=False) class FakeTransport: """ diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 06000f81a63d..d597d712d675 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -118,8 +118,7 @@ def insert_event(txn, i, room_id): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -165,7 +164,7 @@ def test_auth_difference(self, use_chain_cover_index: bool): "j": 1, } - # Mark the room as not having a cover index + # Mark the room as maybe having a cover index. def store_room(txn): self.store.db_pool.simple_insert_txn( @@ -222,6 +221,77 @@ def insert_event(txn): ) ) + return room_id + + @parameterized.expand([(True,), (False,)]) + def test_auth_chain_ids(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + + # a and b have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["a", "b"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + # d and e have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) + self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) + self.assertEqual(auth_chain_ids, ["k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) + self.assertEqual(auth_chain_ids, ["j"]) + + # j and k have no parents. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) + self.assertEqual(auth_chain_ids, []) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) + self.assertEqual(auth_chain_ids, []) + + # More complex input sequences. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["h", "i"]) + ) + self.assertCountEqual(auth_chain_ids, ["k", "j"]) + + # e gets returned even though include_given is false, but it is in the + # auth chain of b. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "e"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + # Test include_given. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) + ) + self.assertCountEqual(auth_chain_ids, ["i", "j"]) + + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + # Now actually test that various combinations give the right result: difference = self.get_success( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index a06ad2c03e9b..41af8c484784 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from synapse.api.errors import NotFoundError +from synapse.api.errors import NotFoundError, SynapseError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -33,9 +31,12 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - def test_purge(self): + self.store = hs.get_datastore() + self.storage = self.hs.get_storage() + + def test_purge_history(self): """ - Purging a room will delete everything before the topological point. + Purging a room history will delete everything before the topological point. """ # Send four messages to the room first = self.helper.send(self.room_id, body="test1") @@ -43,30 +44,27 @@ def test_purge(self): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - store = self.hs.get_datastore() - storage = self.hs.get_storage() - # Get the topological token token = self.get_success( - store.get_topological_token_for_event(last["event_id"]) + self.store.get_topological_token_for_event(last["event_id"]) ) token_str = self.get_success(token.to_string(self.hs.get_datastore())) # Purge everything before this topological token self.get_success( - storage.purge_events.purge_history(self.room_id, token_str, True) + self.storage.purge_events.purge_history(self.room_id, token_str, True) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. - self.get_failure(store.get_event(first["event_id"]), NotFoundError) - self.get_failure(store.get_event(second["event_id"]), NotFoundError) - self.get_failure(store.get_event(third["event_id"]), NotFoundError) - self.get_success(store.get_event(last["event_id"])) + self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) + self.get_failure(self.store.get_event(second["event_id"]), NotFoundError) + self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) + self.get_success(self.store.get_event(last["event_id"])) - def test_purge_wont_delete_extrems(self): + def test_purge_history_wont_delete_extrems(self): """ - Purging a room will delete everything before the topological point. + Purging a room history will delete everything before the topological point. """ # Send four messages to the room first = self.helper.send(self.room_id, body="test1") @@ -74,22 +72,43 @@ def test_purge_wont_delete_extrems(self): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - storage = self.hs.get_datastore() - # Set the topological token higher than it should be token = self.get_success( - storage.get_topological_token_for_event(last["event_id"]) + self.store.get_topological_token_for_event(last["event_id"]) ) event = "t{}-{}".format(token.topological + 1, token.stream + 1) # Purge everything before this topological token - purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) - self.pump() - f = self.failureResultOf(purge) + f = self.get_failure( + self.storage.purge_events.purge_history(self.room_id, event, True), + SynapseError, + ) self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - self.get_success(storage.get_event(first["event_id"])) - self.get_success(storage.get_event(second["event_id"])) - self.get_success(storage.get_event(third["event_id"])) - self.get_success(storage.get_event(last["event_id"])) + self.get_success(self.store.get_event(first["event_id"])) + self.get_success(self.store.get_event(second["event_id"])) + self.get_success(self.store.get_event(third["event_id"])) + self.get_success(self.store.get_event(last["event_id"])) + + def test_purge_room(self): + """ + Purging a room will delete everything about it. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + + # Get the current room state. + state_handler = self.hs.get_state_handler() + create_event = self.get_success( + state_handler.get_current_state(self.room_id, "m.room.create", "") + ) + self.assertIsNotNone(create_event) + + # Purge everything before this topological token + self.get_success(self.storage.purge_events.purge_room(self.room_id)) + + # The events aren't found. + self.store._invalidate_get_event_cache(create_event.event_id) + self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) + self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 52ae5c57137f..74568b34f8c8 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") - self.tx_log.emit( + self.tx_log.emit( # type: ignore twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py new file mode 100644 index 000000000000..f9a187b8defc --- /dev/null +++ b/tests/util/caches/test_responsecache.py @@ -0,0 +1,131 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches.response_cache import ResponseCache + +from tests.server import get_clock +from tests.unittest import TestCase + + +class DeferredCacheTestCase(TestCase): + """ + A TestCase class for ResponseCache. + + The test-case function naming has some logic to it in it's parts, here's some notes about it: + wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available + (Generally these just use .delayed_return instead of .instant_return in it's wrapped call.) + expire: Denotes tests that test expiry after assured existence. + (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) + """ + + def setUp(self): + self.reactor, self.clock = get_clock() + + def with_cache(self, name: str, ms: int = 0) -> ResponseCache: + return ResponseCache(self.clock, name, timeout_ms=ms) + + @staticmethod + async def instant_return(o: str) -> str: + return o + + async def delayed_return(self, o: str) -> str: + await self.clock.sleep(1) + return o + + def test_cache_hit(self): + cache = self.with_cache("keeping_cache", ms=9001) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual( + expected_result, + self.successResultOf(wrap_d), + "initial wrap result should be the same", + ) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should have the result", + ) + + def test_cache_miss(self): + cache = self.with_cache("trashing_cache", ms=0) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual( + expected_result, + self.successResultOf(wrap_d), + "initial wrap result should be the same", + ) + self.assertIsNone(cache.get(0), "cache should not have the result now") + + def test_cache_expire(self): + cache = self.with_cache("short_cache", ms=1000) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.instant_return, expected_result) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + + # cache eviction timer is handled + self.reactor.pump((2,)) + + self.assertIsNone(cache.get(0), "cache should not have the result now") + + def test_cache_wait_hit(self): + cache = self.with_cache("neutral_cache") + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.delayed_return, expected_result) + self.assertNoResult(wrap_d) + + # function wakes up, returns result + self.reactor.pump((2,)) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + + def test_cache_wait_expire(self): + cache = self.with_cache("medium_cache", ms=3000) + + expected_result = "howdy" + + wrap_d = cache.wrap(0, self.delayed_return, expected_result) + self.assertNoResult(wrap_d) + + # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2 + self.reactor.pump((1, 1)) + + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + + # (1 + 1 + 2) > 3.0, cache eviction timer is handled + self.reactor.pump((2,)) + + self.assertIsNone(cache.get(0), "cache should not have the result now") diff --git a/tox.ini b/tox.ini index a6d10537ae54..9ff70fe312c9 100644 --- a/tox.ini +++ b/tox.ini @@ -189,7 +189,5 @@ commands= [testenv:mypy] deps = {[base]deps} - # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513 - twisted==20.3.0 extras = all,mypy commands = mypy