diff --git a/CHANGES.md b/CHANGES.md index f794c585b72b..bee4d6baba1b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,44 @@ +Next version +============ + +* A new template (`sso_auth_confirm.html`) was added to Synapse. If your Synapse + is configured to use SSO and a custom `sso_redirect_confirm_template_dir` + configuration then this template will need to be duplicated into that + directory. + +Synapse 1.12.3 (2020-04-03) +=========================== + +- Remove the the pin to Pillow 7.0 which was introduced in Synapse 1.12.2, and +correctly fix the issue with building the Debian packages. ([\#7212](https://github.com/matrix-org/synapse/issues/7212)) + +Synapse 1.12.2 (2020-04-02) +=========================== + +This release works around [an +issue](https://github.com/matrix-org/synapse/issues/7208) with building the +debian packages. + +No other significant changes since 1.12.1. + +>>>>>>> master + +Synapse 1.12.1 (2020-04-02) +=========================== + +No significant changes since 1.12.1rc1. + + +Synapse 1.12.1rc1 (2020-03-31) +============================== + +Bugfixes +-------- + +- Fix starting workers when federation sending not split out. ([\#7133](https://github.com/matrix-org/synapse/issues/7133)). Introduced in v1.12.0. +- Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. ([\#7155](https://github.com/matrix-org/synapse/issues/7155)). Introduced in v1.12.0rc1. +- Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. ([\#7177](https://github.com/matrix-org/synapse/issues/7177)). Introduced in v1.11.0. + Synapse 1.12.0 (2020-03-23) =========================== diff --git a/INSTALL.md b/INSTALL.md index f9e13b4cf672..b8f8a6732904 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -2,7 +2,6 @@ - [Installing Synapse](#installing-synapse) - [Installing from source](#installing-from-source) - [Platform-Specific Instructions](#platform-specific-instructions) - - [Troubleshooting Installation](#troubleshooting-installation) - [Prebuilt packages](#prebuilt-packages) - [Setting up Synapse](#setting-up-synapse) - [TLS certificates](#tls-certificates) @@ -10,6 +9,7 @@ - [Registering a user](#registering-a-user) - [Setting up a TURN server](#setting-up-a-turn-server) - [URL previews](#url-previews) +- [Troubleshooting Installation](#troubleshooting-installation) # Choosing your server name @@ -36,7 +36,7 @@ that your email address is probably `user@example.com` rather than System requirements: - POSIX-compliant system (tested on Linux & OS X) -- Python 3.5, 3.6, 3.7 or 3.8. +- Python 3.5.2 or later, up to Python 3.8. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org Synapse is written in Python but some of the libraries it uses are written in @@ -70,7 +70,7 @@ pip install -U matrix-synapse ``` Before you can start Synapse, you will need to generate a configuration -file. To do this, run (in your virtualenv, as before):: +file. To do this, run (in your virtualenv, as before): ``` cd ~/synapse @@ -84,22 +84,24 @@ python -m synapse.app.homeserver \ ... substituting an appropriate value for `--server-name`. This command will generate you a config file that you can then customise, but it will -also generate a set of keys for you. These keys will allow your Home Server to -identify itself to other Home Servers, so don't lose or delete them. It would be +also generate a set of keys for you. These keys will allow your homeserver to +identify itself to other homeserver, so don't lose or delete them. It would be wise to back them up somewhere safe. (If, for whatever reason, you do need to -change your Home Server's keys, you may find that other Home Servers have the +change your homeserver's keys, you may find that other homeserver have the old key cached. If you update the signing key, you should change the name of the key in the `.signing.key` file (the second word) to something different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) -for more information on key management.) +for more information on key management). To actually run your new homeserver, pick a working directory for Synapse to -run (e.g. `~/synapse`), and:: +run (e.g. `~/synapse`), and: - cd ~/synapse - source env/bin/activate - synctl start +``` +cd ~/synapse +source env/bin/activate +synctl start +``` ### Platform-Specific Instructions @@ -110,7 +112,7 @@ Installing prerequisites on Ubuntu or Debian: ``` sudo apt-get install build-essential python3-dev libffi-dev \ python3-pip python3-setuptools sqlite3 \ - libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev + libssl-dev virtualenv libjpeg-dev libxslt1-dev ``` #### ArchLinux @@ -188,7 +190,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \ There is currently no port for OpenBSD. Additionally, OpenBSD's security settings require a slightly more difficult installation process. -XXX: I suspect this is out of date. +(XXX: I suspect this is out of date) 1. Create a new directory in `/usr/local` called `_synapse`. Also, create a new user called `_synapse` and set that directory as the new user's home. @@ -196,7 +198,7 @@ XXX: I suspect this is out of date. write and execute permissions on the same memory space to be run from `/usr/local`. 2. `su` to the new `_synapse` user and change to their home directory. -3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse` +3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse` 4. Source the virtualenv configuration located at `/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by using the `.` command, rather than `bash`'s `source`. @@ -217,45 +219,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server for Windows Server. -### Troubleshooting Installation - -XXX a bunch of this is no longer relevant. - -Synapse requires pip 8 or later, so if your OS provides too old a version you -may need to manually upgrade it:: - - sudo pip install --upgrade pip - -Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`. -You can fix this by manually upgrading pip and virtualenv:: - - sudo pip install --upgrade virtualenv - -You can next rerun `virtualenv -p python3 synapse` to update the virtual env. - -Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.` -You can fix this by manually installing ndg-httpsclient:: - - pip install --upgrade ndg-httpsclient - -Installing may fail with `mock requires setuptools>=17.1. Aborting installation`. -You can fix this by upgrading setuptools:: - - pip install --upgrade setuptools - -If pip crashes mid-installation for reason (e.g. lost terminal), pip may -refuse to run until you remove the temporary installation directory it -created. To reset the installation:: - - rm -rf /tmp/pip_install_matrix - -pip seems to leak *lots* of memory during installation. For instance, a Linux -host with 512MB of RAM may run out of memory whilst installing Twisted. If this -happens, you will have to individually install the dependencies which are -failing, e.g.:: - - pip install twisted - ## Prebuilt packages As an alternative to installing from source, prebuilt packages are available @@ -314,7 +277,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and it should be possible to install it with simply: ``` - sudo apt install matrix-synapse +sudo apt install matrix-synapse ``` There is also a version of `matrix-synapse` in `stretch-backports`. Please see @@ -375,8 +338,10 @@ sudo pip install py-bcrypt Synapse can be found in the void repositories as 'synapse': - xbps-install -Su - xbps-install -S synapse +``` +xbps-install -Su +xbps-install -S synapse +``` ### FreeBSD @@ -420,6 +385,7 @@ so, you will need to edit `homeserver.yaml`, as follows: resources: - names: [client, federation] ``` + * You will also need to uncomment the `tls_certificate_path` and `tls_private_key_path` lines under the `TLS` section. You can either point these settings at an existing certificate and key, or you can @@ -427,15 +393,15 @@ so, you will need to edit `homeserver.yaml`, as follows: for having Synapse automatically provision and renew federation certificates through ACME can be found at [ACME.md](docs/ACME.md). Note that, as pointed out in that document, this feature will not - work with installs set up after November 2019. - + work with installs set up after November 2019. + If you are using your own certificate, be sure to use a `.pem` file that includes the full certificate chain including any intermediate certificates (for instance, if using certbot, use `fullchain.pem` as your certificate, not `cert.pem`). For a more detailed guide to configuring your server for federation, see -[federate.md](docs/federate.md) +[federate.md](docs/federate.md). ## Email @@ -482,7 +448,7 @@ on your server even if `enable_registration` is `false`. ## Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure -a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. +a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. ## URL previews @@ -491,10 +457,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter and explicitly specify the IP ranges that Synapse is not allowed to spider for previewing in the `url_preview_ip_range_blacklist` configuration parameter. This is critical from a security perspective to stop arbitrary Matrix users -spidering 'internal' URLs on your network. At the very least we recommend that +spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. -This also requires the optional lxml and netaddr python dependencies to be -installed. This in turn requires the libxml2 library to be available - on +This also requires the optional `lxml` and `netaddr` python dependencies to be +installed. This in turn requires the `libxml2` library to be available - on Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for your OS. + +# Troubleshooting Installation + +`pip` seems to leak *lots* of memory during installation. For instance, a Linux +host with 512MB of RAM may run out of memory whilst installing Twisted. If this +happens, you will have to individually install the dependencies which are +failing, e.g.: + +``` +pip install twisted +``` + +If you have any other problems, feel free to ask in +[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org). diff --git a/changelog.d/6573.bugfix b/changelog.d/6573.bugfix new file mode 100644 index 000000000000..1bb8014db795 --- /dev/null +++ b/changelog.d/6573.bugfix @@ -0,0 +1 @@ +Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak. diff --git a/changelog.d/6639.bugfix b/changelog.d/6639.bugfix new file mode 100644 index 000000000000..c7593a6e8443 --- /dev/null +++ b/changelog.d/6639.bugfix @@ -0,0 +1 @@ +Fix missing field `default` when fetching user-defined push rules. diff --git a/changelog.d/6892.doc b/changelog.d/6892.doc new file mode 100644 index 000000000000..0d04cf0bdb5d --- /dev/null +++ b/changelog.d/6892.doc @@ -0,0 +1 @@ +Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`. \ No newline at end of file diff --git a/changelog.d/6946.bugfix b/changelog.d/6946.bugfix new file mode 100644 index 000000000000..a238c83a18f7 --- /dev/null +++ b/changelog.d/6946.bugfix @@ -0,0 +1 @@ +Transfer alias mappings on room upgrade. \ No newline at end of file diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc new file mode 100644 index 000000000000..676f285377f5 --- /dev/null +++ b/changelog.d/7024.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/changelog.d/7051.feature b/changelog.d/7051.feature new file mode 100644 index 000000000000..3e36a3f65e40 --- /dev/null +++ b/changelog.d/7051.feature @@ -0,0 +1 @@ +Admin API `POST /_synapse/admin/v1/join/` to join users to a room like `auto_join_rooms` for creation of users. \ No newline at end of file diff --git a/changelog.d/7068.bugfix b/changelog.d/7068.bugfix new file mode 100644 index 000000000000..d1693a7f2248 --- /dev/null +++ b/changelog.d/7068.bugfix @@ -0,0 +1 @@ +Ensure that a user inteactive authentication session is tied to a single request. diff --git a/changelog.d/7096.feature b/changelog.d/7096.feature new file mode 100644 index 000000000000..00f47b2a14a5 --- /dev/null +++ b/changelog.d/7096.feature @@ -0,0 +1 @@ +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/changelog.d/7102.feature b/changelog.d/7102.feature new file mode 100644 index 000000000000..01057aa396ba --- /dev/null +++ b/changelog.d/7102.feature @@ -0,0 +1 @@ +Support SSO in the user interactive authentication workflow. diff --git a/changelog.d/7118.feature b/changelog.d/7118.feature new file mode 100644 index 000000000000..5cbfd981607b --- /dev/null +++ b/changelog.d/7118.feature @@ -0,0 +1 @@ +Allow server admins to define and enforce a password policy (MSC2000). \ No newline at end of file diff --git a/changelog.d/7119.doc b/changelog.d/7119.doc new file mode 100644 index 000000000000..05192966c350 --- /dev/null +++ b/changelog.d/7119.doc @@ -0,0 +1 @@ +Update postgres docs with login troubleshooting information. \ No newline at end of file diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc new file mode 100644 index 000000000000..5703f6d2ecde --- /dev/null +++ b/changelog.d/7128.misc @@ -0,0 +1 @@ +Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. diff --git a/changelog.d/7136.misc b/changelog.d/7136.misc new file mode 100644 index 000000000000..3f666d25fdea --- /dev/null +++ b/changelog.d/7136.misc @@ -0,0 +1 @@ +Refactored the CAS authentication logic to a separate class. diff --git a/changelog.d/7137.removal b/changelog.d/7137.removal new file mode 100644 index 000000000000..75266a06bb3d --- /dev/null +++ b/changelog.d/7137.removal @@ -0,0 +1 @@ +Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`. \ No newline at end of file diff --git a/changelog.d/7141.doc b/changelog.d/7141.doc new file mode 100644 index 000000000000..2fcbd666c29f --- /dev/null +++ b/changelog.d/7141.doc @@ -0,0 +1 @@ +Clean up INSTALL.md a bit. \ No newline at end of file diff --git a/changelog.d/7147.doc b/changelog.d/7147.doc new file mode 100644 index 000000000000..2c855ff5f7b3 --- /dev/null +++ b/changelog.d/7147.doc @@ -0,0 +1 @@ +Add documentation for running a local CAS server for testing. diff --git a/changelog.d/7150.bugfix b/changelog.d/7150.bugfix new file mode 100644 index 000000000000..1feb294799a9 --- /dev/null +++ b/changelog.d/7150.bugfix @@ -0,0 +1 @@ +Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param. \ No newline at end of file diff --git a/changelog.d/7151.bugfix b/changelog.d/7151.bugfix new file mode 100644 index 000000000000..8aaa2dc65971 --- /dev/null +++ b/changelog.d/7151.bugfix @@ -0,0 +1 @@ +Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response. diff --git a/changelog.d/7152.feature b/changelog.d/7152.feature new file mode 100644 index 000000000000..fafa79c7e7f5 --- /dev/null +++ b/changelog.d/7152.feature @@ -0,0 +1 @@ +Improve the support for SSO authentication on the login fallback page. diff --git a/changelog.d/7153.feature b/changelog.d/7153.feature new file mode 100644 index 000000000000..414ebe1f6978 --- /dev/null +++ b/changelog.d/7153.feature @@ -0,0 +1 @@ +Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. diff --git a/changelog.d/7155.bugfix b/changelog.d/7155.bugfix new file mode 100644 index 000000000000..0bf51e7aba34 --- /dev/null +++ b/changelog.d/7155.bugfix @@ -0,0 +1 @@ +Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. diff --git a/changelog.d/7157.misc b/changelog.d/7157.misc new file mode 100644 index 000000000000..0eb1128c7a42 --- /dev/null +++ b/changelog.d/7157.misc @@ -0,0 +1 @@ +Add tests for outbound device pokes. diff --git a/changelog.d/7158.misc b/changelog.d/7158.misc new file mode 100644 index 000000000000..269b8daeb086 --- /dev/null +++ b/changelog.d/7158.misc @@ -0,0 +1 @@ +Fix device list update stream ids going backward. diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix new file mode 100644 index 000000000000..1b341b127b0c --- /dev/null +++ b/changelog.d/7159.bugfix @@ -0,0 +1 @@ +Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. diff --git a/changelog.d/7160.feature b/changelog.d/7160.feature new file mode 100644 index 000000000000..c1205969a18f --- /dev/null +++ b/changelog.d/7160.feature @@ -0,0 +1 @@ +Always send users their own device updates. diff --git a/changelog.d/7167.doc b/changelog.d/7167.doc new file mode 100644 index 000000000000..a7e7ba9b51f5 --- /dev/null +++ b/changelog.d/7167.doc @@ -0,0 +1 @@ +Improve README.md by being explicit about public IP recommendation for TURN relaying. diff --git a/changelog.d/7171.doc b/changelog.d/7171.doc new file mode 100644 index 000000000000..25a3bd8ac6dc --- /dev/null +++ b/changelog.d/7171.doc @@ -0,0 +1 @@ +Fix a small typo in the `metrics_flags` config option. \ No newline at end of file diff --git a/changelog.d/7177.bugfix b/changelog.d/7177.bugfix new file mode 100644 index 000000000000..329a96cb0b65 --- /dev/null +++ b/changelog.d/7177.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. \ No newline at end of file diff --git a/changelog.d/7178.bugfix b/changelog.d/7178.bugfix new file mode 100644 index 000000000000..35ea645d7596 --- /dev/null +++ b/changelog.d/7178.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause incorrect 'cyclic dependency' error. diff --git a/changelog.d/7181.misc b/changelog.d/7181.misc new file mode 100644 index 000000000000..731f4dcb52e4 --- /dev/null +++ b/changelog.d/7181.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. diff --git a/changelog.d/7183.misc b/changelog.d/7183.misc new file mode 100644 index 000000000000..731f4dcb52e4 --- /dev/null +++ b/changelog.d/7183.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. diff --git a/changelog.d/7184.misc b/changelog.d/7184.misc new file mode 100644 index 000000000000..fac5bc04032c --- /dev/null +++ b/changelog.d/7184.misc @@ -0,0 +1 @@ +Convert some of synapse.rest.media to async/await. diff --git a/changelog.d/7190.misc b/changelog.d/7190.misc new file mode 100644 index 000000000000..34348873f171 --- /dev/null +++ b/changelog.d/7190.misc @@ -0,0 +1 @@ +Only run one background database update at a time. diff --git a/changelog.d/7191.feature b/changelog.d/7191.feature new file mode 100644 index 000000000000..83d5685bb2c4 --- /dev/null +++ b/changelog.d/7191.feature @@ -0,0 +1 @@ +Admin users are no longer required to be in a room to create an alias for it. diff --git a/changelog.d/7195.misc b/changelog.d/7195.misc new file mode 100644 index 000000000000..676f285377f5 --- /dev/null +++ b/changelog.d/7195.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/changelog.d/7203.bugfix b/changelog.d/7203.bugfix new file mode 100644 index 000000000000..8b383952e53c --- /dev/null +++ b/changelog.d/7203.bugfix @@ -0,0 +1 @@ +Fix some worker-mode replication handling not being correctly recorded in CPU usage stats. diff --git a/debian/changelog b/debian/changelog index 39ec9da7abc2..642115fc5af4 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,26 @@ +matrix-synapse-py3 (1.12.3) stable; urgency=medium + + [ Richard van der Hoff ] + * Update the Debian build scripts to handle the new installation paths + for the support libraries introduced by Pillow 7.1.1. + + [ Synapse Packaging team ] + * New synapse release 1.12.3. + + -- Synapse Packaging team Fri, 03 Apr 2020 10:55:03 +0100 + +matrix-synapse-py3 (1.12.2) stable; urgency=medium + + * New synapse release 1.12.2. + + -- Synapse Packaging team Mon, 02 Apr 2020 19:02:17 +0000 + +matrix-synapse-py3 (1.12.1) stable; urgency=medium + + * New synapse release 1.12.1. + + -- Synapse Packaging team Mon, 02 Apr 2020 11:30:47 +0000 + matrix-synapse-py3 (1.12.0) stable; urgency=medium * New synapse release 1.12.0. diff --git a/debian/rules b/debian/rules index a4d2ce2ba4cd..c744060a57ae 100755 --- a/debian/rules +++ b/debian/rules @@ -15,17 +15,38 @@ override_dh_installinit: # we don't really want to strip the symbols from our object files. override_dh_strip: +# dh_shlibdeps calls dpkg-shlibdeps, which finds all the binary files +# (executables and shared libs) in the package, and looks for the shared +# libraries that they depend on. It then adds a dependency on the package that +# contains that library to the package. +# +# We make two modifications to that process... +# override_dh_shlibdeps: - # make the postgres package's dependencies a recommendation - # rather than a hard dependency. + # Firstly, postgres is not a hard dependency for us, so we want to make + # the things that psycopg2 depends on (such as libpq) be + # recommendations rather than hard dependencies. We do so by + # running dpkg-shlibdeps manually on psycopg2's libs. + # find debian/$(PACKAGE_NAME)/ -path '*/site-packages/psycopg2/*.so' | \ xargs dpkg-shlibdeps -Tdebian/$(PACKAGE_NAME).substvars \ -pshlibs1 -dRecommends - # all the other dependencies can be normal 'Depends' requirements, - # except for PIL's, which is self-contained and which confuses - # dpkg-shlibdeps. - dh_shlibdeps -X site-packages/PIL/.libs -X site-packages/psycopg2 + # secondly, we exclude PIL's libraries from the process. They are known + # to be self-contained, but they have interdependencies and + # dpkg-shlibdeps doesn't know how to resolve them. + # + # As of Pillow 7.1.0, these libraries are in + # site-packages/Pillow.libs. Previously, they were in + # site-packages/PIL/.libs. + # + # (we also need to exclude psycopg2, of course, since we've already + # dealt with that.) + # + dh_shlibdeps \ + -X site-packages/PIL/.libs \ + -X site-packages/Pillow.libs \ + -X site-packages/psycopg2 override_dh_virtualenv: ./debian/build_virtualenv diff --git a/docs/admin_api/room_membership.md b/docs/admin_api/room_membership.md new file mode 100644 index 000000000000..16736d3d37c7 --- /dev/null +++ b/docs/admin_api/room_membership.md @@ -0,0 +1,34 @@ +# Edit Room Membership API + +This API allows an administrator to join an user account with a given `user_id` +to a room with a given `room_id_or_alias`. You can only modify the membership of +local users. The server administrator must be in the room and have permission to +invite users. + +## Parameters + +The following parameters are available: + +* `user_id` - Fully qualified user: for example, `@user:server.com`. +* `room_id_or_alias` - The room identifier or alias to join: for example, + `!636q39766251:server.com`. + +## Usage + +``` +POST /_synapse/admin/v1/join/ + +{ + "user_id": "@user:server.com" +} +``` + +Including an `access_token` of a server admin. + +Response: + +``` +{ + "room_id": "!636q39766251:server.com" +} +``` diff --git a/docs/dev/cas.md b/docs/dev/cas.md new file mode 100644 index 000000000000..f8d02cc82ca9 --- /dev/null +++ b/docs/dev/cas.md @@ -0,0 +1,64 @@ +# How to test CAS as a developer without a server + +The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an +easy to run CAS implementation built on top of Django. + +## Prerequisites + +1. Create a new virtualenv: `python3 -m venv ` +2. Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate` +3. Install Django and django-mama-cas: + ``` + python -m pip install "django<3" "django-mama-cas==2.4.0" + ``` +4. Create a Django project in the current directory: + ``` + django-admin startproject cas_test . + ``` +5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas +6. Setup the SQLite database: `python manage.py migrate` +7. Create a user: + ``` + python manage.py createsuperuser + ``` + 1. Use whatever you want as the username and password. + 2. Leave the other fields blank. +8. Use the built-in Django test server to serve the CAS endpoints on port 8000: + ``` + python manage.py runserver + ``` + +You should now have a Django project configured to serve CAS authentication with +a single user created. + +## Configure Synapse (and Riot) to use CAS + +1. Modify your `homeserver.yaml` to enable CAS and point it to your locally + running Django test server: + ```yaml + cas_config: + enabled: true + server_url: "http://localhost:8000" + service_url: "http://localhost:8081" + #displayname_attribute: name + #required_attributes: + # name: value + ``` +2. Restart Synapse. + +Note that the above configuration assumes the homeserver is running on port 8081 +and that the CAS server is on port 8000, both on localhost. + +## Testing the configuration + +Then in Riot: + +1. Visit the login page with a Riot pointing at your homeserver. +2. Click the Single Sign-On button. +3. Login using the credentials created with `createsuperuser`. +4. You should be logged in. + +If you want to repeat this process you'll need to manually logout first: + +1. http://localhost:8000/admin/ +2. Click "logout" in the top right. diff --git a/docs/dev/saml.md b/docs/dev/saml.md index f41aadce477c..a9bfd2dc05d6 100644 --- a/docs/dev/saml.md +++ b/docs/dev/saml.md @@ -18,9 +18,13 @@ To make Synapse (and therefore Riot) use it: metadata: local: ["samling.xml"] ``` -5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure +5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`: + ```yaml + public_baseurl: http://localhost:8080/ + ``` +6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure the dependencies are installed and ready to go. -6. Restart Synapse. +7. Restart Synapse. Then in Riot: diff --git a/docs/postgres.md b/docs/postgres.md index 04aa7460515d..70fe29cdccaf 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -61,7 +61,33 @@ Note that the PostgreSQL database *must* have the correct encoding set You may need to enable password authentication so `synapse_user` can connect to the database. See -. +. + +If you get an error along the lines of `FATAL: Ident authentication failed for +user "synapse_user"`, you may need to use an authentication method other than +`ident`: + +* If the `synapse_user` user has a password, add the password to the `database:` + section of `homeserver.yaml`. Then add the following to `pg_hba.conf`: + + ``` + host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that + ``` + +* If the `synapse_user` user does not have a password, then a password doesn't + have to be added to `homeserver.yaml`. But the following does need to be added + to `pg_hba.conf`: + + ``` + host synapse synapse_user ::1/128 trust + ``` + +Note that line order matters in `pg_hba.conf`, so make sure that if you do add a +new line, it is inserted before: + +``` +host all all ::1/128 ident +``` ### Fixing incorrect `COLLATE` or `CTYPE` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 276e43b732f8..6a770508f93f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -872,10 +872,6 @@ media_store_path: "DATADIR/media_store" # #enable_registration_captcha: false -# A secret key used to bypass the captcha test entirely. -# -#captcha_bypass_secret: "YOUR_SECRET_HERE" - # The API endpoint to use for verifying m.login.recaptcha responses. # #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" @@ -1090,6 +1086,29 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process +# Whether users are allowed to change their displayname after it has +# been initially set. Useful when provisioning users based on the +# contents of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_displayname: false + +# Whether users are allowed to change their avatar after it has been +# initially set. Useful when provisioning users based on the contents +# of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_avatar_url: false + +# Whether users can change the 3PIDs associated with their accounts +# (email address and msisdn). +# +# Defaults to 'true' +# +#enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # @@ -1125,7 +1144,7 @@ account_threepid_delegates: # enabled by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. # @@ -1425,6 +1444,10 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: @@ -1486,6 +1509,41 @@ password_config: # #pepper: "EVEN_MORE_SECRET" + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + # Configuration for sending emails from Synapse. # diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e3a4634b1407..3be8e50c4c6f 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and '<' worker to master flows): > SERVER example.com - < REPLICATE events 53 + < REPLICATE + > POSITION events 53 > RDATA events 54 ["$foo1:bar.com", ...] > RDATA events 55 ["$foo4:bar.com", ...] -The example shows the server accepting a new connection and sending its -identity with the `SERVER` command, followed by the client asking to -subscribe to the `events` stream from the token `53`. The server then -periodically sends `RDATA` commands which have the format -`RDATA `, where the format of `` is -defined by the individual streams. +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client server to respond with the +position of all streams. The server then periodically sends `RDATA` commands +which have the format `RDATA `, where the format of +`` is defined by the individual streams. Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually connect to the server using a tool like netcat. A few things should be noted when manually using the protocol: -- When subscribing to a stream using `REPLICATE`, the special token - `NOW` can be used to get all future updates. The special stream name - `ALL` can be used with `NOW` to subscribe to all available streams. - The federation stream is only available if federation sending has been disabled on the main process. - The server will only time connections out that have sent a `PING` @@ -91,9 +88,7 @@ The client: - Sends a `NAME` command, allowing the server to associate a human friendly name with the connection. This is optional. - Sends a `PING` as above -- For each stream the client wishes to subscribe to it sends a - `REPLICATE` with the `stream_name` and token it wants to subscribe - from. +- Sends a `REPLICATE` to get the current position of all streams. - On receipt of a `SERVER` command, checks that the server name matches the expected server name. @@ -140,9 +135,7 @@ the wire: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -181,9 +174,9 @@ client (C): #### POSITION (S) - The position of the stream has been updated. Sent to the client - after all missing updates for a stream have been sent to the client - and they're now up to date. + On receipt of a POSITION command clients should check if they have missed any + updates, and if so then fetch them out of band. Sent in response to a + REPLICATE command (but can happen at any time). #### ERROR (S, C) @@ -199,24 +192,17 @@ client (C): #### REPLICATE (C) -Asks the server to replicate a given stream. The syntax is: +Asks the server for the current position of all streams. -``` - REPLICATE -``` +#### USER_SYNC (C) -Where `` may be either: - * a numeric stream_id to stream updates since (exclusive) - * `NOW` to stream all subsequent updates. + A user has started or stopped syncing -The `` is the name of a replication stream to subscribe -to (see [here](../synapse/replication/tcp/streams/_base.py) for a list -of streams). It can also be `ALL` to subscribe to all known streams, -in which case the `` must be set to `NOW`. +#### CLEAR_USER_SYNC (C) -#### USER_SYNC (C) + The server should clear all associated user sync data from the worker. - A user has started or stopped syncing + This is used when a worker is shutting down. #### FEDERATION_ACK (C) diff --git a/docs/turn-howto.md b/docs/turn-howto.md index 1bd3943f54b4..b26e41f19e14 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -11,6 +11,13 @@ TURN server. The following sections describe how to install [coturn]() (which implements the TURN REST API) and integrate it with synapse. +## Requirements + +For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint with a public IP. + +Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues +and to often not work. + ## `coturn` Setup ### Initial installation diff --git a/synapse/__init__.py b/synapse/__init__.py index 5b8600894599..3bf2d0245051 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ except ImportError: pass -__version__ = "1.12.0" +__version__ = "1.12.3" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index cc8577552b16..fda2c2e5bbf8 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -61,6 +61,7 @@ class LoginType(object): MSISDN = "m.login.msisdn" RECAPTCHA = "m.login.recaptcha" TERMS = "m.login.terms" + SSO = "org.matrix.login.sso" DUMMY = "m.login.dummy" # Only for C/S API v1 diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 616942b057b0..11da016ac590 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -64,6 +64,13 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT" + PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT" + PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE" + PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE" + PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL" + PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY" + WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" BAD_ALIAS = "M_BAD_ALIAS" @@ -439,6 +446,20 @@ def error_dict(self): return cs_error(self.msg, self.errcode, room_version=self._room_version) +class PasswordRefusedError(SynapseError): + """A password has been refused, either during password reset/change or registration. + """ + + def __init__( + self, + msg="This password doesn't comply with the server's policy", + errcode=Codes.WEAK_PASSWORD, + ): + super(PasswordRefusedError, self).__init__( + code=400, msg=msg, errcode=errcode, + ) + + class RequestSendFailed(RuntimeError): """Sending a HTTP request over federation failed due to not being able to talk to the remote server for some reason. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bd1733573bc3..174bef360f38 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -42,7 +42,7 @@ from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background +from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ @@ -65,6 +65,7 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -124,7 +125,6 @@ from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -233,6 +233,7 @@ def __init__(self, hs): self.user_to_num_current_syncs = {} self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -245,13 +246,24 @@ def __init__(self, hs): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): if self.hs.config.use_presence: self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) def mark_as_coming_online(self, user_id): @@ -401,6 +413,9 @@ def process_replication_rows(self, token, rows): self._room_serials[row.room_id] = token self._room_typing[row.room_id] = row.user_ids + def get_current_token(self) -> int: + return self._latest_room_serial + class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly @@ -620,7 +635,7 @@ async def on_rdata(self, stream_name, token, rows): await super(GenericWorkerReplicationHandler, self).on_rdata( stream_name, token, rows ) - run_in_background(self.process_and_notify, stream_name, token, rows) + await self.process_and_notify(stream_name, token, rows) def get_streams_to_replicate(self): args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() @@ -635,7 +650,9 @@ def get_currently_syncing_users(self): async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: - self.send_handler.process_replication_rows(stream_name, token, rows) + await self.send_handler.process_replication_rows( + stream_name, token, rows + ) if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so @@ -767,12 +784,12 @@ def wake_destination(self, server: str): def stream_positions(self): return {"federation": self.federation_position} - def process_replication_rows(self, stream_name, token, rows): + async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": send_queue.process_rows_for_federation(self.federation_sender, rows) - run_in_background(self.update_token, token) + await self.update_token(token) # We also need to poke the federation sender when new events happen elif stream_name == "events": @@ -780,9 +797,7 @@ def process_replication_rows(self, stream_name, token, rows): # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: - run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows - ) + await self._on_new_receipts(rows) # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index f0171bb5b230..56c87fa296cb 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,6 @@ def read_config(self, config, **kwargs): self.enable_registration_captcha = config.get( "enable_registration_captcha", False ) - self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config.get( "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", @@ -49,10 +48,6 @@ def generate_config_section(self, **kwargs): # #enable_registration_captcha: false - # A secret key used to bypass the captcha test entirely. - # - #captcha_bypass_secret: "YOUR_SECRET_HERE" - # The API endpoint to use for verifying m.login.recaptcha responses. # #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" diff --git a/synapse/config/database.py b/synapse/config/database.py index b8ab2f86ac3f..c27fef157bff 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -20,6 +20,11 @@ logger = logging.getLogger(__name__) +NON_SQLITE_DATABASE_PATH_WARNING = """\ +Ignoring 'database_path' setting: not using a sqlite3 database. +-------------------------------------------------------------------------------- +""" + DEFAULT_CONFIG = """\ ## Database ## @@ -105,6 +110,11 @@ def __init__(self, name: str, db_config: dict): class DatabaseConfig(Config): section = "database" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.databases = [] + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) @@ -125,12 +135,13 @@ def read_config(self, config, **kwargs): multi_database_config = config.get("databases") database_config = config.get("database") + database_path = config.get("database_path") if multi_database_config and database_config: raise ConfigError("Can't specify both 'database' and 'datbases' in config") if multi_database_config: - if config.get("database_path"): + if database_path: raise ConfigError("Can't specify 'database_path' with 'databases'") self.databases = [ @@ -138,13 +149,17 @@ def read_config(self, config, **kwargs): for name, db_conf in multi_database_config.items() ] - else: - if database_config is None: - database_config = {"name": "sqlite3", "args": {}} - + if database_config: self.databases = [DatabaseConnectionConfig("master", database_config)] - self.set_databasepath(config.get("database_path")) + if database_path: + if self.databases and self.databases[0].name != "sqlite3": + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) + return + + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(database_path) def generate_config_section(self, data_dir_path, **kwargs): return DEFAULT_CONFIG % { @@ -152,27 +167,37 @@ def generate_config_section(self, data_dir_path, **kwargs): } def read_arguments(self, args): - self.set_databasepath(args.database_path) + """ + Cases for the cli input: + - If no databases are configured and no database_path is set, raise. + - No databases and only database_path available ==> sqlite3 db. + - If there are multiple databases and a database_path raise an error. + - If the database set in the config file is sqlite then + overwrite with the command line argument. + """ - def set_databasepath(self, database_path): - if database_path is None: + if args.database_path is None: + if not self.databases: + raise ConfigError("No database config provided") return - if database_path != ":memory:": - database_path = self.abspath(database_path) + if len(self.databases) == 0: + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(args.database_path) + return + + if self.get_single_database().name == "sqlite3": + self.set_databasepath(args.database_path) + else: + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) - # We only support setting a database path if we have a single sqlite3 - # database. - if len(self.databases) != 1: - raise ConfigError("Cannot specify 'database_path' with multiple databases") + def set_databasepath(self, database_path): - database = self.get_single_database() - if database.config["name"] != "sqlite3": - # We don't raise here as we haven't done so before for this case. - logger.warn("Ignoring 'database_path' for non-sqlite3 database") - return + if database_path != ":memory:": + database_path = self.abspath(database_path) - database.config["args"]["database"] = database_path + self.databases[0].config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -187,7 +212,7 @@ def add_arguments(parser): def get_single_database(self) -> DatabaseConnectionConfig: """Returns the database if there is only one, useful for e.g. tests """ - if len(self.databases) != 1: + if not self.databases: raise Exception("More than one database exists") return self.databases[0] diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 22538153e1e0..6f517a71d092 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -86,7 +86,7 @@ def generate_config_section(self, report_stats=None, **kwargs): # enabled by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. # diff --git a/synapse/config/password.py b/synapse/config/password.py index 2a634ac7516e..9c0ea8c30a02 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -31,6 +31,10 @@ def read_config(self, config, **kwargs): self.password_localdb_enabled = password_config.get("localdb_enabled", True) self.password_pepper = password_config.get("pepper", "") + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ password_config: @@ -48,4 +52,39 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # DO NOT CHANGE THIS AFTER INITIAL SETUP! # #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc1a..e7ea3a01cb87 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,10 @@ def read_config(self, config, **kwargs): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +334,29 @@ def generate_config_section(self, generate_secrets=False, **kwargs): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # Whether users are allowed to change their displayname after it has + # been initially set. Useful when provisioning users based on the + # contents of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_displayname: false + + # Whether users are allowed to change their avatar after it has been + # initially set. Useful when provisioning users based on the contents + # of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_avatar_url: false + + # Whether users can change the 3PIDs associated with their accounts + # (email address and msisdn). + # + # Defaults to 'true' + # + #enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 95762689bc77..ec3dca9efce0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,6 +39,17 @@ def read_config(self, config, **kwargs): self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. @@ -54,6 +65,10 @@ def generate_config_section(self, **kwargs): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 233cb33daf94..a477578e445f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -499,4 +499,13 @@ def wake_destination(self, destination: str): self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self) -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. return 0 + + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. + return [] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7860f9625e5e..7c09d15a7245 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -53,6 +53,31 @@ logger = logging.getLogger(__name__) +SUCCESS_TEMPLATE = """ + + +Success! + + + + + +
+

Thank you

+

You may now close this window and return to the application

+
+ + +""" + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -91,6 +116,7 @@ def __init__(self, hs): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled + self._saml2_enabled = hs.config.saml2_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -106,6 +132,13 @@ def __init__(self, hs): if t not in login_types: login_types.append(t) self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not + # necessarily identical. Login types have SSO (and other login types) + # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. + ui_auth_types = login_types.copy() + if self._saml2_enabled: + ui_auth_types.append(LoginType.SSO) + self._supported_ui_auth_types = ui_auth_types # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -113,10 +146,21 @@ def __init__(self, hs): self._clock = self.hs.get_clock() - # Load the SSO redirect confirmation page HTML template + # Load the SSO HTML templates. + + # The following template is shown to the user during a client login via SSO, + # after the SSO completes and before redirecting them back to their client. + # It notifies the user they are about to give access to their matrix account + # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], )[0] + # The following template is shown during user interactive authentication + # in the fallback auth scenario. It notifies the user that they are + # authenticating for an operation to occur on their account. + self._sso_auth_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + )[0] self._server_name = hs.config.server_name @@ -125,7 +169,12 @@ def __init__(self, hs): @defer.inlineCallbacks def validate_user_via_ui_auth( - self, requester: Requester, request_body: Dict[str, Any], clientip: str + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, + description: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -137,10 +186,15 @@ def validate_user_via_ui_auth( Args: requester: The user, as given by the access token + request: The request sent by the client. + request_body: The body of the request sent by the client clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict]: the parameters for this request (which may have been given only in a previous call). @@ -169,10 +223,12 @@ def validate_user_via_ui_auth( ) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_login_types] + flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth(flows, request_body, clientip) + result, params, _ = yield self.check_auth( + flows, request, request_body, clientip, description + ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( @@ -185,7 +241,7 @@ def validate_user_via_ui_auth( raise # find the completed login type - for login_type in self._supported_login_types: + for login_type in self._supported_ui_auth_types: if login_type not in result: continue @@ -211,7 +267,12 @@ def get_enabled_auth_types(self): @defer.inlineCallbacks def check_auth( - self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, + description: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -231,11 +292,16 @@ def check_auth( strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + request: The request sent by the client. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict, dict, str]: a deferred tuple of (creds, params, session_id). @@ -270,13 +336,33 @@ def check_auth( # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects # on a homeserver. - # Revisit: Assumimg the REST APIs do sensible validation, the data + # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbintrary. session["clientdict"] = clientdict self._save_session(session) elif "clientdict" in session: clientdict = session["clientdict"] + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if "ui_auth" not in session: + session["ui_auth"] = comparator + self._save_session(session) + elif session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) + + # Add a human readable description to the session. + if "description" not in session: + session["description"] = description + self._save_session(session) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -322,6 +408,7 @@ def check_auth( creds, list(clientdict), ) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) @@ -962,6 +1049,56 @@ def _do_validate_hash(): else: return defer.succeed(False) + def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + """ + Get the HTML for the SSO redirect confirmation page. + + Args: + redirect_url: The URL to redirect to the SSO provider. + session_id: The user interactive authentication session ID. + + Returns: + The HTML to render. + """ + session = self._get_session_info(session_id) + # Get the human readable operation of what is occurring, falling back to + # a generic message if it isn't available for some reason. + description = session.get("description", "modify your account") + return self._sso_auth_confirm_template.render( + description=description, redirect_url=redirect_url, + ) + + def complete_sso_ui_auth( + self, registered_user_id: str, session_id: str, request: SynapseRequest, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Mark the stage of the authentication as successful. + sess = self._get_session_info(session_id) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] + + # Save the user who authenticated with SSO, this will be used to ensure + # that the account be modified is also the person who logged in. + creds[LoginType.SSO] = registered_user_id + self._save_session(sess) + + # Render the HTML and return. + html_bytes = SUCCESS_TEMPLATE.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + def complete_sso_login( self, registered_user_id: str, diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py new file mode 100644 index 000000000000..f8dc274b78be --- /dev/null +++ b/synapse/handlers/cas_handler.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import xml.etree.ElementTree as ET +from typing import AnyStr, Dict, Optional, Tuple + +from six.moves import urllib + +from twisted.web.client import PartialDownloadError + +from synapse.api.errors import Codes, LoginError +from synapse.http.site import SynapseRequest +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + + +class CasHandler: + """ + Utility class for to handle the response from a CAS SSO service. + + Args: + hs (synapse.server.HomeServer) + """ + + def __init__(self, hs): + self._hostname = hs.hostname + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._cas_displayname_attribute = hs.config.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas_required_attributes + + self._http_client = hs.get_proxied_http_client() + + def _build_service_param(self, client_redirect_url: AnyStr) -> str: + return "%s%s?%s" % ( + self._cas_service_url, + "/_matrix/client/r0/login/cas/ticket", + urllib.parse.urlencode({"redirectUrl": client_redirect_url}), + ) + + async def _handle_cas_response( + self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str + ) -> None: + """ + Retrieves the user and display name from the CAS response and continues with the authentication. + + Args: + request: The original client request. + cas_response_body: The response from the CAS server. + client_redirect_url: The URl to redirect the client to when + everything is done. + """ + user, attributes = self._parse_cas_response(cas_response_body) + displayname = attributes.pop(self._cas_displayname_attribute, None) + + for required_attribute, required_value in self._cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + await self._on_successful_auth(user, request, client_redirect_url, displayname) + + def _parse_cas_response( + self, cas_response_body: str + ) -> Tuple[str, Dict[str, Optional[str]]]: + """ + Retrieve the user and other parameters from the CAS response. + + Args: + cas_response_body: The response from the CAS query. + + Returns: + A tuple of the user and a mapping of other attributes. + """ + user = None + attributes = {} + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = root[0].tag.endswith("authenticationSuccess") + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + except Exception: + logger.exception("Error parsing CAS response") + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) + return user, attributes + + async def _on_successful_auth( + self, + username: str, + request: SynapseRequest, + client_redirect_url: str, + user_display_name: Optional[str] = None, + ) -> None: + """Called once the user has successfully authenticated with the SSO. + + Registers the user if necessary, and then returns a redirect (with + a login token) to the client. + + Args: + username: the remote user id. We'll map this onto + something sane for a MXID localpath. + + request: the incoming request from the browser. We'll + respond to it with a redirect. + + client_redirect_url: the redirect_url the client gave us when + it first started the process. + + user_display_name: if set, and we have to register a new user, + we will set their displayname to this. + """ + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) + + def handle_redirect_request(self, client_redirect_url: bytes) -> bytes: + """ + Generates a URL to the CAS server where the client should be redirected. + + Args: + client_redirect_url: The final URL the client should go to after the + user has negotiated SSO. + + Returns: + The URL to redirect to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(client_redirect_url)} + ) + + return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii") + + async def handle_ticket_request( + self, request: SynapseRequest, client_redirect_url: str, ticket: str + ) -> None: + """ + Validates a CAS ticket sent by the client for login/registration. + + On a successful request, writes a redirect to the request. + """ + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(client_redirect_url), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + await self._handle_cas_response(request, body, client_redirect_url) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a514c3071445..993499f446de 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -125,8 +125,14 @@ def get_user_ids_changed(self, user_id, from_token): users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) + + tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices + tracked_users.add(user_id) + changed = yield self.store.get_users_whose_devices_changed( - from_token.device_list_key, users_who_share_room + from_token.device_list_key, tracked_users ) # Then work out if any users have since joined @@ -456,7 +462,11 @@ def notify_device_update(self, user_id, device_ids): room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) + # specify the user ID too since the user should always get their own device list + # updates, even if they aren't in any rooms. + yield self.notifier.on_new_event( + "device_list_key", position, users=[user_id], rooms=room_ids + ) if hosts: logger.info( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1d842c369bed..53e5f585d90c 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -127,7 +127,11 @@ def create_association( errcode=Codes.EXCLUSIVE, ) else: - if self.require_membership and check_membership: + # Server admins are not subject to the same constraints as normal + # users when creating an alias (e.g. being in the room). + is_admin = yield self.auth.is_server_admin(requester.user) + + if (self.require_membership and check_membership) and not is_admin: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38ab6a8fc3e9..c7aa7acf3b61 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -69,10 +70,9 @@ from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) @@ -93,27 +93,6 @@ class _NewEventInfo: auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. - - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". - - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating - - Returns: - unicode - """ - - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py new file mode 100644 index 000000000000..d06b110269c2 --- /dev/null +++ b/synapse/handlers/password_policy.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.errors import Codes, PasswordRefusedError + +logger = logging.getLogger(__name__) + + +class PasswordPolicyHandler(object): + def __init__(self, hs): + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + # Regexps for the spec'd policy parameters. + self.regexp_digit = re.compile("[0-9]") + self.regexp_symbol = re.compile("[^a-zA-Z0-9]") + self.regexp_uppercase = re.compile("[A-Z]") + self.regexp_lowercase = re.compile("[a-z]") + + def validate_password(self, password): + """Checks whether a given password complies with the server's policy. + + Args: + password (str): The password to check against the server's policy. + + Raises: + PasswordRefusedError: The password doesn't comply with the server's policy. + """ + + if not self.enabled: + return + + minimum_accepted_length = self.policy.get("minimum_length", 0) + if len(password) < minimum_accepted_length: + raise PasswordRefusedError( + msg=( + "The password must be at least %d characters long" + % minimum_accepted_length + ), + errcode=Codes.PASSWORD_TOO_SHORT, + ) + + if ( + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one digit", + errcode=Codes.PASSWORD_NO_DIGIT, + ) + + if ( + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one symbol", + errcode=Codes.PASSWORD_NO_SYMBOL, + ) + + if ( + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one uppercase letter", + errcode=Codes.PASSWORD_NO_UPPERCASE, + ) + + if ( + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one lowercase letter", + errcode=Codes.PASSWORD_NO_LOWERCASE, + ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b9e..6aa1c0f5e019 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,15 @@ def set_displayname(self, target_user, requester, new_displayname, by_admin=Fals if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and not self.hs.config.enable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError( + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, + ) + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +227,13 @@ def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False) if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and not self.hs.config.enable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError( + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN + ) + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 426042636921..c3ee8db4f009 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -519,6 +519,9 @@ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): yield self.store.set_room_is_public(old_room_id, False) yield self.store.set_room_is_public(room_id, True) + # Transfer alias mappings in the room directory + yield self.store.update_aliases_for_room(old_room_id, room_id) + # Check if any groups we own contain the predecessor room local_group_ids = yield self.store.get_local_groups_for_room(old_room_id) for group_id in local_group_ids: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 72c109981bed..4741c82f6156 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import Tuple +from typing import Optional, Tuple import attr import saml2 @@ -26,6 +26,7 @@ from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.module_api import ModuleApi +from synapse.module_api.errors import RedirectException from synapse.types import ( UserID, map_username_to_mxid_localpart, @@ -43,11 +44,15 @@ class Saml2SessionData: # time the session was created, in milliseconds creation_time = attr.ib() + # The user interactive authentication session ID associated with this SAML + # session (or None if this SAML session is for an initial login). + ui_auth_session_id = attr.ib(type=Optional[str], default=None) class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -76,12 +81,14 @@ def __init__(self, hs): self._error_html_content = hs.config.saml2_error_html_content - def handle_redirect_request(self, client_redirect_url): + def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None): """Handle an incoming request to /login/sso/redirect Args: client_redirect_url (bytes): the URL that we should redirect the client to when everything is done + ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or + None if this is a login). Returns: bytes: URL to redirect to @@ -91,7 +98,9 @@ def handle_redirect_request(self, client_redirect_url): ) now = self._clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + self._outstanding_requests_dict[reqid] = Saml2SessionData( + creation_time=now, ui_auth_session_id=ui_auth_session_id, + ) for key, value in info["headers"]: if key == "Location": @@ -118,7 +127,12 @@ async def handle_saml_response(self, request): self.expire_sessions() try: - user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) + except RedirectException: + # Raise the exception as per the wishes of the SAML module response + raise except Exception as e: # If decoding the response or mapping it to a user failed, then log the # error and tell the user that something went wrong. @@ -133,9 +147,28 @@ async def handle_saml_response(self, request): finish_request(request) return - self._auth_handler.complete_sso_login(user_id, request, relay_state) + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, resp_bytes: str, client_redirect_url: str + ) -> Tuple[str, Optional[Saml2SessionData]]: + """ + Given a sample response, retrieve the cached session and user for it. - async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): + Args: + resp_bytes: The SAML response. + client_redirect_url: The redirect URL passed in by the client. + + Returns: + Tuple of the user ID and SAML session associated with this response. + """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -163,7 +196,9 @@ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + current_session = self._outstanding_requests_dict.pop( + saml2_auth.in_response_to, None + ) remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url @@ -184,7 +219,7 @@ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + return registered_user_id, current_session # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -209,7 +244,7 @@ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -256,7 +291,7 @@ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 12657ca69836..7d1263caf2cb 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -32,6 +32,7 @@ def __init__(self, hs): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() + self._password_policy_handler = hs.get_password_policy_handler() @defer.inlineCallbacks def set_password( @@ -44,6 +45,7 @@ def set_password( if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) + self._password_policy_handler.validate_password(new_password) password_hash = yield self._auth_handler.hash(new_password) try: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5746fdea1457..1f1cde2feb28 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1143,9 +1143,14 @@ async def _generate_sync_entry_for_device_list( user_id ) + tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices + tracked_users.add(user_id) + # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, users_who_share_room + since_token.device_list_key, tracked_users ) # Step 1b, check for newly joined rooms diff --git a/synapse/http/site.py b/synapse/http/site.py index e092193c9c09..32feb0d968db 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -193,6 +193,12 @@ def connectionLost(self, reason): self.finish_time = time.time() Request.connectionLost(self, reason) + if self.logcontext is None: + logger.info( + "Connection from %s lost before request headers were read", self.client + ) + return + # we only get here if the connection to the client drops before we send # the response. # @@ -236,13 +242,6 @@ def _started_processing(self, servlet_name): def _finished_processing(self): """Log the completion of this request and update the metrics """ - - if self.logcontext is None: - # this can happen if the connection closed before we read the - # headers (so render was never called). In that case we'll already - # have logged a warning, so just bail out. - return - usage = self.logcontext.get_resource_usage() if self._processing_finished_time is None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index a8eafb1c7ce9..a8f674d13da9 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -51,7 +51,7 @@ is_thread_resource_usage_supported = True - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return resource.getrusage(RUSAGE_THREAD) @@ -60,7 +60,7 @@ def get_thread_resource_usage(): # won't track resource usage. is_thread_resource_usage_supported = False - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return None @@ -201,10 +201,10 @@ def copy_to_twisted_log_entry(self, record): record["request"] = None record["scope"] = None - def start(self): + def start(self, rusage: "Optional[resource._RUsage]"): pass - def stop(self): + def stop(self, rusage: "Optional[resource._RUsage]"): pass def add_database_transaction(self, duration_sec): @@ -261,7 +261,7 @@ def __init__(self, name=None, parent_context=None, request=None) -> None: # The thread resource usage when the logcontext became active. None # if the context is not currently active. - self.usage_start = None + self.usage_start = None # type: Optional[resource._RUsage] self.main_thread = get_thread_id() self.request = None @@ -336,7 +336,17 @@ def copy_to_twisted_log_entry(self, record) -> None: record["request"] = self.request record["scope"] = self.scope - def start(self) -> None: + def start(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is currently running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching to this logcontext. May be None if this platform doesn't + support getrusuage. + """ if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) return @@ -349,36 +359,48 @@ def start(self) -> None: if self.usage_start: logger.warning("Re-starting already-active log context %s", self) else: - self.usage_start = get_thread_resource_usage() + self.usage_start = rusage - def stop(self) -> None: - if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return + def stop(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is no longer running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching away from this logcontext. May be None if this platform + doesn't support getrusuage. + """ - # When we stop, let's record the cpu used since we started - if not self.usage_start: - # Log a warning on platforms that support thread usage tracking - if is_thread_resource_usage_supported: + try: + if get_thread_id() != self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + if not rusage: + return + + # Record the cpu used since we started + if not self.usage_start: logger.warning( - "Called stop on logcontext %s without calling start", self + "Called stop on logcontext %s without recording a start rusage", + self, ) - return + return - utime_delta, stime_delta = self._get_cputime() - self._resource_usage.ru_utime += utime_delta - self._resource_usage.ru_stime += stime_delta + utime_delta, stime_delta = self._get_cputime(rusage) + self._resource_usage.ru_utime += utime_delta + self._resource_usage.ru_stime += stime_delta - self.usage_start = None + # if we have a parent, pass our CPU usage stats on + if self.parent_context: + self.parent_context._resource_usage += self._resource_usage - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - - # reset them in case we get entered again - self._resource_usage.reset() + # reset them in case we get entered again + self._resource_usage.reset() + finally: + self.usage_start = None def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. @@ -394,24 +416,24 @@ def get_resource_usage(self) -> ContextResourceUsage: # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread if self.usage_start and is_main_thread: - utime_delta, stime_delta = self._get_cputime() + rusage = get_thread_resource_usage() + assert rusage is not None + utime_delta, stime_delta = self._get_cputime(rusage) res.ru_utime += utime_delta res.ru_stime += stime_delta return res - def _get_cputime(self) -> Tuple[float, float]: - """Get the cpu usage time so far + def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]: + """Get the cpu usage time between start() and the given rusage + + Args: + rusage: the current resource usage Returns: Tuple[float, float]: seconds in user mode, seconds in system mode """ assert self.usage_start is not None - current = get_thread_resource_usage() - - # Indicate to mypy that we know that self.usage_start is None. - assert self.usage_start is not None - utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -539,12 +561,19 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe Returns: The context that was previously active """ + # everything blows up if we allow current_context to be set to None, so sanity-check + # that now. + if context is None: + raise TypeError("'context' argument may not be None") + current = current_context() if current is not context: - current.stop() + rusage = get_thread_resource_usage() + current.stop(rusage) _thread_local.current_context = context - context.start() + context.start(rusage) + return current diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 28dbc6fcbaf1..4613b2538ce8 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ membership, register, send_event, + streams, ) REPLICATION_PREFIX = "/_synapse/replication" @@ -38,3 +39,4 @@ def register_servlets(self, hs): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 000000000000..ffd4c6199378 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + + The API looks like: + + GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100 + + 200 OK + + { + updates: [ ... ], + upto_token: 10, + limited: False, + } + + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super().__init__(hs) + + # We pull the streams from the replication steamer (if we try and make + # them ourselves we end up in an import loop). + self.streams = hs.get_replication_streamer().get_streams() + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token, limit): + return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + limit = parse_integer(request, "limit", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + from_token, upto_token, limit + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f45cbd37a0f5..751c799d9432 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,8 +18,10 @@ import six -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import ( + CURRENT_STATE_CACHE_NAME, + CacheInvalidationWorkerStore, +) from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -35,7 +37,7 @@ def __func__(inp): return inp.__func__ -class BaseSlavedStore(SQLBaseStore): +class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): @@ -60,6 +62,12 @@ def stream_positions(self) -> Dict[str, int]: pos["caches"] = self._cache_id_gen.get_current_token() return pos + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": if self._cache_id_gen: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index f22c2d44a327..bce8a3d115ca 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -33,6 +33,9 @@ def stream_positions(self): result["pushers"] = self._pushers_id_gen.get_current_token() return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + def process_replication_rows(self, stream_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66eab7..e86d9805f11c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ def startedConnecting(self, connector): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): @@ -188,10 +189,12 @@ def send_federation_ack(self, token): """ self.send_command(FederationAckCommand(token)) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """Poke the master that a user has started/stopped syncing. """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) def send_remove_pusher(self, app_id, push_key, user_id): """Poke the master to remove a pusher for a user diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 451671412d15..e4eec643f7f9 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -136,8 +136,8 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. """ NAME = "POSITION" @@ -179,42 +179,24 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE - - Where may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. - - The can be "ALL" to subscribe to all known streams, in which - case the must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): @@ -225,30 +207,32 @@ class UserSyncCommand(Command): Format:: - USER_SYNC + USER_SYNC Where is either "start" or "stop" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -256,6 +240,30 @@ def to_line(self): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -416,6 +424,7 @@ class RemoteServerUpCommand(Command): InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -438,6 +447,7 @@ class RemoteServerUpCommand(Command): ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bbf2..dae246825fb0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -53,17 +51,15 @@ import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -82,11 +78,16 @@ SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -411,16 +412,6 @@ def __init__(self, server_name, clock, streamer): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -432,25 +423,17 @@ async def on_NAME(self, cmd): async def on_USER_SYNC(self, cmd): await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] + async def on_CLEAR_USER_SYNC(self, cmd): + await self.streamer.on_clear_user_syncs(cmd.instance_id) - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) + async def on_REPLICATE(self, cmd): + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -474,87 +457,12 @@ async def on_USER_IP(self, cmd): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -638,6 +546,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, @@ -645,41 +554,42 @@ def __init__( ): BaseReplicationStreamProtocol.__init__(self, clock) + self.instance_id = hs.get_instance_id() + self.client_name = client_name self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) currently_syncing = self.handler.get_currently_syncing_users() now = self.clock.time_msec() for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) + self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -697,7 +607,7 @@ async def on_RDATA(self, cmd): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -707,14 +617,55 @@ async def on_RDATA(self, cmd): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) @@ -722,22 +673,12 @@ async def on_SYNC(self, cmd): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand()) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6e2ebaf614d7..30021ee309df 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, List +from typing import Any, Dict, List from six import itervalues @@ -30,7 +30,7 @@ from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP +from .streams import STREAMS_MAP, Stream from .streams.federation import FederationStream stream_updates_counter = Counter( @@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.streamer = hs.get_replication_streamer() self.clock = hs.get_clock() self.server_name = hs.config.server_name @@ -99,22 +99,6 @@ def __init__(self, hs): self.streams_by_name = {stream.NAME: stream for stream in self.streams} - LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", - "", - ["stream_name"], - lambda: { - (stream_name,): len( - [ - conn - for conn in self.connections - if stream_name in conn.replication_streams - ] - ) - for stream_name in self.streams_by_name - }, - ) - self.federation_sender = None if not hs.config.send_federation: self.federation_sender = hs.get_federation_sender() @@ -133,6 +117,11 @@ def on_shutdown(self): for conn in self.connections: conn.send_error("server shutting down") + def get_streams(self) -> Dict[str, Stream]: + """Get a mapp from stream name to stream instance. + """ + return self.streams_by_name + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -190,7 +179,8 @@ async def _run_notifier_loop(self): stream.current_token(), ) try: - updates, current_token = await stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -226,8 +216,7 @@ async def _run_notifier_loop(self): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -235,7 +224,7 @@ async def get_stream_updates(self, stream_name, token): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -246,14 +235,19 @@ def federation_ack(self, token): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() await self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms + instance_id, user_id, is_syncing, last_sync_ms ) + async def on_clear_user_syncs(self, instance_id): + """A replication client wants us to drop all their UserSync data. + """ + await self.presence_handler.update_external_syncs_clear(instance_id) + @measure_func("repl.on_remove_pusher") async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher @@ -316,14 +310,6 @@ def lost_connection(self, connection): except ValueError: pass - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. - run_as_background_process( - "update_external_syncs_clear", - self.presence_handler.update_external_syncs_clear, - connection.conn_id, - ) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 29199f5b466b..37bcd3de6688 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,6 +24,9 @@ current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ + +from typing import Dict, Type + from synapse.replication.tcp.streams._base import ( AccountDataStream, BackfillStream, @@ -35,6 +38,7 @@ PushersStream, PushRulesStream, ReceiptsStream, + Stream, TagAccountDataStream, ToDeviceStream, TypingStream, @@ -63,10 +67,12 @@ GroupServerStream, UserSignatureStream, ) -} +} # type: Dict[str, Type[Stream]] + __all__ = [ "STREAMS_MAP", + "Stream", "BackfillStream", "PresenceStream", "TypingStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 32d9514883d1..c14dff6c6484 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -29,6 +29,15 @@ MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -56,6 +65,7 @@ def parse_row(cls, row): return cls.ROW_TYPE(*row) def __init__(self, hs): + # The token from which we last asked for updates self.last_token = self.current_token() @@ -65,61 +75,46 @@ def discard_updates_and_advance(self): """ self.last_token = self.current_token() - async def get_updates(self): + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - updates, current_token = await self.get_updates_since(self.last_token) + current_token = self.current_token() + updates, current_token, limited = await self.get_updates_since( + self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited async def get_updates_since( - self, from_token: int - ) -> Tuple[List[Tuple[int, JsonDict]], int]: + self, from_token: Token, upto_token: Token, limit: int = 100 + ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Like get_updates except allows specifying from when we should stream updates Returns: - Resolves to a pair `(updates, new_last_token)`, where `updates` is - a list of `(token, row)` entries and `new_last_token` is the new - position in stream. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.current_token() - - current_token = self.current_token() - from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, ) - - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - - updates = [(row[0], row[1:]) for row in rows] - - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) - - # The update function didn't hit the limit, so we must have got all - # the updates to `current_token`, and can return that as our new - # stream position. - return updates, current_token + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -141,6 +136,48 @@ def update_function(self, from_token, current_token, limit): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + +def make_http_update_function( + hs, stream_name: str +) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ + + client = ReplicationGetStreamUpdates.make_client(hs) + + async def update_function( + from_token: int, upto_token: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + return await client( + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -164,7 +201,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -190,8 +227,15 @@ def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore - self.update_function = presence_handler.get_all_presence_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -208,7 +252,12 @@ def __init__(self, hs): typing_handler = hs.get_typing_handler() self.current_token = typing_handler.get_current_token # type: ignore - self.update_function = typing_handler.get_all_typing_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs) @@ -232,7 +281,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -256,7 +305,13 @@ def current_token(self): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -275,7 +330,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -307,7 +362,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -333,7 +388,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -354,7 +409,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -372,7 +427,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,7 +447,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -412,10 +467,11 @@ def __init__(self, hs): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -442,7 +498,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -460,6 +516,6 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cde3..c6a595629f7b 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index f5f933643073..48c1d4571824 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,9 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream +from twisted.internet import defer + +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -33,11 +35,18 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow + _QUERY_MASTER = True def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore + else: + self.current_token = lambda: 0 # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html new file mode 100644 index 000000000000..0d9de9d46528 --- /dev/null +++ b/synapse/res/templates/sso_auth_confirm.html @@ -0,0 +1,14 @@ + + + Authentication + + +
+

+ A client is trying to {{ description | e }}. To confirm this action, + re-authenticate with single sign-on. + If you did not expect this, your account may be compromised! +

+
+ + diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 4a1fc2ec2bfe..46e458e95ba0 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -41,6 +41,7 @@ keys, notifications, openid, + password_policy, read_marker, receipts, register, @@ -118,6 +119,7 @@ def register_servlets(client_resource, hs): capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 42cc2b062a58..ed70d448a141 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,11 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ( + JoinRoomAliasServlet, + ListRoomRestServlet, + ShutdownRoomRestServlet, +) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -189,6 +193,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f9b8c0a4f0f3..659b8a10ee2d 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Optional -from synapse.api.constants import Membership -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -29,7 +30,7 @@ historical_admin_path_patterns, ) from synapse.storage.data_stores.main.room import RoomSortOrder -from synapse.types import create_requester +from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -237,3 +238,75 @@ async def on_GET(self, request): response["prev_batch"] = 0 return 200, response + + +class JoinRoomAliasServlet(RestServlet): + + PATTERNS = admin_patterns("/join/(?P[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.admin_handler = hs.get_handlers().admin_handler + self.state_handler = hs.get_state_handler() + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + + assert_params_in_dict(content, ["user_id"]) + target_user = UserID.from_string(content["user_id"]) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + fake_requester = create_requester(target_user) + + # send invite if room has "JoinRules.INVITE" + room_state = await self.state_handler.get_current_state(room_id) + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + await self.room_member_handler.update_membership( + requester=requester, + target=fake_requester.user, + room_id=room_id, + action="invite", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + await self.room_member_handler.update_membership( + requester=fake_requester, + target=fake_requester.user, + room_id=room_id, + action="join", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + return 200, {"room_id": room_id} diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 56d713462ace..59593cbf6e48 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,11 +14,6 @@ # limitations under the License. import logging -import xml.etree.ElementTree as ET - -from six.moves import urllib - -from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -28,9 +23,10 @@ parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn logger = logging.getLogger(__name__) @@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier): return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} -def build_service_param(cas_service_url, client_redirect_url): - return "%s%s?redirectUrl=%s" % ( - cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.quote(client_redirect_url, safe=""), - ) - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request): + def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" @@ -418,15 +406,15 @@ def on_GET(self, request): request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: - client_redirect_url (bytes): the URL that we should redirect the + client_redirect_url: the URL that we should redirect the client to when everything is done Returns: - bytes: URL to redirect to + URL to redirect to """ # to be implemented by subclasses raise NotImplementedError() @@ -434,16 +422,10 @@ def get_sso_url(self, client_redirect_url): class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): - super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url + self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url): - args = urllib.parse.urlencode( - {"service": build_service_param(self.cas_service_url, client_redirect_url)} - ) - - return "%s/login?%s" % (self.cas_server_url, args) + def get_sso_url(self, client_redirect_url: bytes) -> bytes: + return self._cas_handler.handle_redirect_request(client_redirect_url) class CasTicketServlet(RestServlet): @@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet): def __init__(self, hs): super(CasTicketServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url - self.cas_displayname_attribute = hs.config.cas_displayname_attribute - self.cas_required_attributes = hs.config.cas_required_attributes - self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_proxied_http_client() - - async def on_GET(self, request): - client_redirect_url = parse_string(request, "redirectUrl", required=True) - uri = self.cas_server_url + "/proxyValidate" - args = { - "ticket": parse_string(request, "ticket", required=True), - "service": build_service_param(self.cas_service_url, client_redirect_url), - } - try: - body = await self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response - result = await self.handle_cas_response(request, body, client_redirect_url) - return result + self._cas_handler = hs.get_cas_handler() - def handle_cas_response(self, request, cas_response_body, client_redirect_url): - user, attributes = self.parse_cas_response(cas_response_body) - displayname = attributes.pop(self.cas_displayname_attribute, None) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, displayname + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl", required=True) + ticket = parse_string(request, "ticket", required=True) + await self._cas_handler.handle_ticket_request( + request, client_redirect_url, ticket ) - def parse_cas_response(self, cas_response_body): - user = None - attributes = {} - try: - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise Exception("root of CAS response is not serviceResponse") - success = root[0].tag.endswith("authenticationSuccess") - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. - tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes[tag] = attribute.text - if user is None: - raise Exception("CAS response does not contain user") - except Exception: - logger.exception("Error parsing CAS response") - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not success: - raise LoginError( - 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED - ) - return user, attributes - class SAMLRedirectServlet(BaseSSORedirectServlet): PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class SSOAuthHandler(object): - """ - Utility class for Resources and Servlets which handle the response from a SSO - service - - Args: - hs (synapse.server.HomeServer) - """ - - def __init__(self, hs): - self._hostname = hs.hostname - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - self._macaroon_gen = hs.get_macaroon_generator() - - # cast to tuple for use with str.startswith - self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - - async def on_successful_auth( - self, username, request, client_redirect_url, user_display_name=None - ): - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. - - Args: - username (unicode|bytes): the remote user id. We'll map this onto - something sane for a MXID localpath. - - request (SynapseRequest): the incoming request from the browser. We'll - respond to it with a redirect. - - client_redirect_url (unicode): the redirect_url the client gave us when - it first started the process. - - user_display_name (unicode|None): if set, and we have to register a new user, - we will set their displayname to this. - - Returns: - Deferred[none]: Completes once we have handled the request. - """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) - - self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 631cc74cb42f..31435b1e1c1c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -234,13 +234,21 @@ async def on_POST(self, request): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: requester = None result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -308,7 +316,11 @@ async def on_POST(self, request): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -602,6 +614,11 @@ async def on_GET(self, request): return 200, {"threepids": threepids} async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -646,6 +663,11 @@ def __init__(self, hs): @interactive_auth_handler async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -656,7 +678,11 @@ async def on_POST(self, request): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( @@ -741,10 +767,16 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 85cf5a14c647..1787562b9080 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.handlers.auth import SUCCESS_TEMPLATE from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string @@ -89,30 +90,6 @@ """ -SUCCESS_TEMPLATE = """ - - -Success! - - - - - -
-

Thank you

-

You may now close this window and return to the application

-
- - -""" - class AuthRestServlet(RestServlet): """ @@ -130,6 +107,11 @@ def __init__(self, hs): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + # SSO configuration. + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -150,6 +132,15 @@ def on_GET(self, request, stagetype): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + + elif stagetype == LoginType.SSO and self._saml_enabled: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") @@ -210,6 +201,9 @@ async def on_POST(self, request, stagetype): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 94ff73f384e1..c0714fcfb105 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,11 @@ async def on_POST(self, request): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -127,7 +131,11 @@ async def on_DELETE(self, request, device_id): raise await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f7ed4daf90a7..8f41a3edbfcb 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,11 @@ async def on_POST(self, request): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 000000000000..968403cca455 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a09189b1b469..431ecf4f84e9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -373,6 +373,7 @@ def __init__(self, hs): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_flows = _calculate_registration_flows( @@ -420,6 +421,7 @@ async def on_POST(self, request): or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(body["password"]) desired_username = None if "username" in body: @@ -499,7 +501,11 @@ async def on_POST(self, request): ) auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, body, self.hs.get_ip_from_request(request) + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 38952a1d276a..59529707dfa2 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -188,7 +188,7 @@ async def on_GET(self, request, room_id, session_id): """ requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id diff --git a/synapse/server.py b/synapse/server.py index 1b980371de31..9228e1c8927c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -56,6 +56,7 @@ from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.cas_handler import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler @@ -66,6 +67,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler +from synapse.handlers.password_policy import PasswordPolicyHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler @@ -85,6 +87,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -100,6 +103,7 @@ from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -196,9 +200,12 @@ def build_DEPENDENCY(self) "sendmail", "registration_handler", "account_validity_handler", + "cas_handler", "saml_handler", "event_client_serializer", + "password_policy_handler", "storage", + "replication_streamer", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -224,6 +231,8 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar self._listening_services = [] self.start_time = None + self.instance_id = random_string(5) + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -236,6 +245,14 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self.instance_id + def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) @@ -525,6 +542,9 @@ def build_registration_handler(self): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_cas_handler(self): + return CasHandler(self) + def build_saml_handler(self): from synapse.handlers.saml_handler import SamlHandler @@ -533,9 +553,15 @@ def build_saml_handler(self): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_password_policy_handler(self): + return PasswordPolicyHandler(self) + def build_storage(self) -> Storage: return Storage(self, self.datastores) + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) @@ -557,24 +583,22 @@ def _get(hs): try: builder = getattr(hs, "build_%s" % (depname)) except AttributeError: - builder = None + raise NotImplementedError( + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) + ) - if builder: - # Prevent cyclic dependencies from deadlocking - if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % (depname,)) - hs._building[depname] = 1 + # Prevent cyclic dependencies from deadlocking + if depname in hs._building: + raise ValueError("Cyclic dependency while building %s" % (depname,)) + hs._building[depname] = 1 + try: dep = builder() setattr(hs, depname, dep) - + finally: del hs._building[depname] - return dep - - raise NotImplementedError( - "%s has no %s nor a builder for it" % (type(hs).__name__, depname) - ) + return dep setattr(HomeServer, "get_%s" % (depname), _get) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3844f0e12ff2..9d1dfa71e792 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -114,3 +114,5 @@ class HomeServer(object): pass def is_mine_id(self, domain_id: str) -> bool: pass + def get_instance_id(self) -> str: + pass diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index bcb6bc6bb743..712b0e398094 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -9,7 +9,7 @@

-

Log in with one of the following methods

+

diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 276c271bbeed..debe46437134 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -1,37 +1,41 @@ window.matrixLogin = { endpoint: location.origin + "/_matrix/client/r0/login", serverAcceptsPassword: false, - serverAcceptsCas: false, serverAcceptsSso: false, }; +var title_pre_auth = "Log in with one of the following methods"; +var title_post_auth = "Logging in..."; + var submitPassword = function(user, pwd) { console.log("Logging in with password..."); + set_title(title_post_auth); var data = { type: "m.login.password", user: user, password: pwd, }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var submitToken = function(loginToken) { console.log("Logging in with login token..."); + set_title(title_post_auth); var data = { type: "m.login.token", token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var errorFunc = function(err) { - show_login(); + // We want to show the error to the user rather than redirecting immediately to the + // SSO portal (if SSO is the only login option), so we inhibit the redirect. + show_login(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -45,26 +49,33 @@ var setFeedbackString = function(text) { $("#feedback").text(text); }; -var show_login = function() { - $("#loading").hide(); - +var show_login = function(inhibit_redirect) { var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - if (matrixLogin.serverAcceptsPassword) { - $("#password_flow").show(); + // If inhibit_redirect is false, and SSO is the only supported login method, we can + // redirect straight to the SSO page + if (matrixLogin.serverAcceptsSso) { + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + $("#sso_form").submit(); + return; + } + + // Otherwise, show the SSO form + $("#sso_form").show(); } - if (matrixLogin.serverAcceptsSso) { - $("#sso_flow").show(); - } else if (matrixLogin.serverAcceptsCas) { - $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect"); - $("#sso_flow").show(); + if (matrixLogin.serverAcceptsPassword) { + $("#password_flow").show(); } - if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) { + if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } + + set_title(title_pre_auth); + + $("#loading").hide(); }; var show_spinner = function() { @@ -74,17 +85,15 @@ var show_spinner = function() { $("#loading").show(); }; +var set_title = function(title) { + $("#title").text(title); +}; var fetch_info = function(cb) { $.get(matrixLogin.endpoint, function(response) { var serverAcceptsPassword = false; - var serverAcceptsCas = false; for (var i=0; i bool: """Check if all the background updates have completed Returns: - Deferred[bool]: True if all background updates have completed + True if all background updates have completed """ # if we've previously determined that there is nothing left to do, that # is easy if self._all_done: return True - # obviously, if we have things in our queue, we're not done. - if self._background_update_queue: + # obviously, if we are currently processing an update, we're not done. + if self._current_background_update: return False # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self.db.simple_select_onecol( + updates = await self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -153,11 +154,10 @@ def has_completed_background_updates(self): async def has_completed_background_update(self, update_name) -> bool: """Check if the given background update has finished running. """ - if self._all_done: return True - if update_name in self._background_update_queue: + if update_name == self._current_background_update: return False update_exists = await self.db.simple_select_one_onecol( @@ -170,9 +170,7 @@ async def has_completed_background_update(self, update_name) -> bool: return not update_exists - async def do_next_background_update( - self, desired_duration_ms: float - ) -> Optional[int]: + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. @@ -181,33 +179,51 @@ async def do_next_background_update( desired_duration_ms(float): How long we want to spend updating. Returns: - None if there is no more work to do, otherwise an int + True if we have finished running all the background updates, otherwise False """ - if not self._background_update_queue: - updates = await self.db.simple_select_list( - "background_updates", - keyvalues=None, - retcols=("update_name", "depends_on"), + + def get_background_updates_txn(txn): + txn.execute( + """ + SELECT update_name, depends_on FROM background_updates + ORDER BY ordering, update_name + """ ) - in_flight = {update["update_name"] for update in updates} - for update in updates: - if update["depends_on"] not in in_flight: - self._background_update_queue.append(update["update_name"]) + return self.db.cursor_to_dict(txn) - if not self._background_update_queue: - # no work left to do - return None + if not self._current_background_update: + all_pending_updates = await self.db.runInteraction( + "background_updates", get_background_updates_txn, + ) + if not all_pending_updates: + # no work left to do + return True + + # find the first update which isn't dependent on another one in the queue. + pending = {update["update_name"] for update in all_pending_updates} + for upd in all_pending_updates: + depends_on = upd["depends_on"] + if not depends_on or depends_on not in pending: + break + logger.info( + "Not starting on bg update %s until %s is done", + upd["update_name"], + depends_on, + ) + else: + # if we get to the end of that for loop, there is a problem + raise Exception( + "Unable to find a background update which doesn't depend on " + "another: dependency cycle?" + ) - # pop from the front, and add back to the back - update_name = self._background_update_queue.pop(0) - self._background_update_queue.append(update_name) + self._current_background_update = upd["update_name"] - res = await self._do_background_update(update_name, desired_duration_ms) - return res + await self._do_background_update(desired_duration_ms) + return False - async def _do_background_update( - self, update_name: str, desired_duration_ms: float - ) -> int: + async def _do_background_update(self, desired_duration_ms: float) -> int: + update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) update_handler = self._background_update_handlers[update_name] @@ -400,27 +416,6 @@ def updater(progress, batch_size): self.register_background_update_handler(update_name, updater) - def start_background_update(self, update_name, progress): - """Starts a background update running. - - Args: - update_name: The update to set running. - progress: The initial state of the progress of the update. - - Returns: - A deferred that completes once the task has been added to the - queue. - """ - # Clear the background update queue so that we will pick up the new - # task on the next iteration of do_background_update. - self._background_update_queue = [] - progress_json = json.dumps(progress) - - return self.db.simple_insert( - "background_updates", - {"update_name": update_name, "progress_json": progress_json}, - ) - def _end_background_update(self, update_name): """Removes a completed background update task from the queue. @@ -429,9 +424,12 @@ def _end_background_update(self, update_name): Returns: A deferred that completes once the task is removed. """ - self._background_update_queue = [ - name for name in self._background_update_queue if name != update_name - ] + if update_name != self._current_background_update: + raise Exception( + "Cannot end background update %s which isn't currently running" + % update_name + ) + self._current_background_update = None return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc7586..4dc5da3fe8b6 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ def _send_invalidation_to_replication( }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a8a..9a1178fb3947 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ def delete_messages_for_remote_destination_txn(txn): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ def _add_messages_to_local_device_inbox_txn( rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 2d47cfd13161..dd3561e9b20d 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -41,6 +41,7 @@ cachedList, ) from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -164,7 +165,6 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit): # the max stream_id across each set of duplicate entries # # maps (user_id, device_id) -> (stream_id, opentracing_context) - # as long as their stream_id does not match that of the last row # # opentracing_context contains the opentracing metadata for the request # that created the poke @@ -269,7 +269,14 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m prev_id = yield self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) - for device_id, device in iteritems(user_devices): + + # make sure we go through the devices in stream order + device_ids = sorted( + user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + ) + + for device_id in device_ids: + device = user_devices[device_id] stream_id, opentracing_context = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -285,14 +292,16 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) else: result["deleted"] = True @@ -493,14 +502,16 @@ def _get_devices_with_keys_by_user_txn(self, txn, user_id): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) results.append(result) @@ -1092,18 +1103,47 @@ def _add_device_outbound_poke_to_stream_txn( ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? AND count(*) > 1 """ @@ -1114,14 +1154,29 @@ def _prune_txn(txn): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. @@ -1131,7 +1186,7 @@ def _prune_txn(txn): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index c9e7de7d1248..e1d1bc3e0586 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import namedtuple +from typing import Optional from twisted.internet import defer @@ -159,10 +160,29 @@ def _delete_room_alias_txn(self, txn, room_alias): return room_id - def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def update_aliases_for_room( + self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + ): + """Repoint all of the aliases for a given room, to a different room. + + Args: + old_room_id: + new_room_id: + creator: The user to record as the creator of the new mapping. + If None, the creator will be left unchanged. + """ + def _update_aliases_for_room_txn(txn): - sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id)) + update_creator_sql = "" + sql_params = (new_room_id, old_room_id) + if creator: + update_creator_sql = ", creator = ?" + sql_params = (new_room_id, creator, old_room_id) + + sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % ( + update_creator_sql, + ) + txn.execute(sql, sql_params) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (old_room_id,) ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 84594cf0a9bc..23f4570c4b3e 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -146,7 +146,8 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], - "is_verified": row["is_verified"], + # is_verified must be returned to the client as a boolean + "is_verified": bool(row["is_verified"]), "session_data": json.loads(row["session_data"]), } diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8a5..e71c23541d09 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ def _count(txn): ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ def _get_event_ordering(self, event_id): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 3013f49d32c5..16ea8948b119 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -963,3 +963,117 @@ def get_room_complexity(self, room_id): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 80ca36dedfa9..cf195f8aa61a 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -340,7 +340,7 @@ def _get_expired_url_cache_txn(txn): "get_expired_url_cache", _get_expired_url_cache_txn ) - def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids): if len(media_ids) == 0: return @@ -349,7 +349,7 @@ def delete_url_cache(self, media_ids): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 62ac88d9f285..46f9bda773eb 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map): rule = dict(rawrule) rule["conditions"] = json.loads(rawrule["conditions"]) rule["actions"] = json.loads(rawrule["actions"]) + rule["default"] = False ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c631676..aaebe427d3ac 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ def _quarantine_media_txn( return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ def add_event_report( def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 2bfeefd54ed2..3bc2e8b9863f 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,14 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sqlite3 import struct import threading +import typing from synapse.storage.engines import BaseDatabaseEngine +if typing.TYPE_CHECKING: + import sqlite3 # noqa: F401 -class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): + +class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def __init__(self, database_module, database_config): super().__init__(database_module, database_config) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6cb7d4b9229d..1712932f319d 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,7 +29,7 @@ # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 57 +SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql new file mode 100644 index 000000000000..02dae587cc59 --- /dev/null +++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* add an "ordering" column to background_updates, which can be used to sort them + to achieve some level of consistency. */ + +ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0; diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2c0dcb5208bd..6899bcb788bf 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import random import re import string +from collections import Iterable import six from six import PY2, PY3 @@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) + + +def shortstr(iterable: Iterable, maxitems: int = 5) -> str: + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable: iterable to truncate + maxitems: number of items to return before truncating + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index d3feafa1b7b0..be20a89682fd 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,8 +27,8 @@ def make_homeserver(self, reactor, clock): return hs - def default_config(self, name="test"): - c = super().default_config(name) + def default_config(self): + c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" return c diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 89fcc3889a72..7364f9f1ec15 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,8 +29,8 @@ def make_homeserver(self, reactor, clock): ) return hs - def default_config(self, name="test"): - conf = super().default_config(name) + def default_config(self): + conf = super().default_config() # we're using FederationReaderServer, which uses a SlavedStore, so we # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 24fa8dbb4508..94980733c4bd 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self, name="test"): - config = super().default_config(name=name) + def default_config(self): + config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d456267b87ba..33105576af24 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,19 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from mock import Mock +from signedjson import key, sign +from signedjson.types import BaseKey, SigningKey + from twisted.internet import defer -from synapse.types import ReadReceipt +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.types import JsonDict, ReadReceipt from tests.unittest import HomeserverTestCase, override_config -class FederationSenderTestCases(HomeserverTestCase): +class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - return super(FederationSenderTestCases, self).setup_test_homeserver( + return self.setup_test_homeserver( state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -147,3 +153,392 @@ def test_send_receipts_with_backoff(self): } ], ) + + +class FederationSenderDevicesTestCases(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def default_config(self): + c = super().default_config() + c["send_federation"] = True + return c + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + mock_state_handler = hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + # stub out get_users_who_share_room_with_user so that it claims that + # `@user2:host2` is in the room + def get_users_who_share_room_with_user(user_id): + return defer.succeed({"@user2:host2"}) + + hs.get_datastore().get_users_who_share_room_with_user = ( + get_users_who_share_room_with_user + ) + + # whenever send_transaction is called, record the edu data + self.edus = [] + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + def record_transaction(self, txn, json_cb): + data = json_cb() + self.edus.extend(data["edus"]) + return defer.succeed({}) + + def test_send_device_updates(self): + """Basic case: each device update should result in an EDU""" + # create a device + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + + # expect one edu + self.assertEqual(len(self.edus), 1) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # a second call should produce no new device EDUs + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + self.assertEqual(self.edus, []) + + # a second device + self.login("user", "pass", device_id="D2") + + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + def test_upload_signatures(self): + """Uploading signatures on some devices should produce updates for that user""" + + e2e_handler = self.hs.get_e2e_keys_handler() + + # register two devices + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + self.login(u1, "pass", device_id="D2") + + # expect two edus + self.assertEqual(len(self.edus), 2) + stream_id = None + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload signing keys for each device + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + + # expect two more edus + self.assertEqual(len(self.edus), 2) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload master key and self-signing key + master_signing_key = generate_self_id_key() + master_key = { + "user_id": u1, + "usage": ["master"], + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, + } + + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_signing_key = generate_self_id_key() + selfsigning_key = { + "user_id": u1, + "usage": ["self_signing"], + "keys": { + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) + }, + } + sign.sign_json(selfsigning_key, u1, master_signing_key) + + cross_signing_keys = { + "master_key": master_key, + "self_signing_key": selfsigning_key, + } + + self.get_success( + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) + ) + + # expect signing key update edu + self.assertEqual(len(self.edus), 1) + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + + # sign the devices + d1_json = build_device_dict(u1, "D1", device1_signing_key) + sign.sign_json(d1_json, u1, selfsigning_signing_key) + d2_json = build_device_dict(u1, "D2", device2_signing_key) + sign.sign_json(d2_json, u1, selfsigning_signing_key) + + ret = self.get_success( + e2e_handler.upload_signatures_for_device_keys( + u1, {u1: {"D1": d1_json, "D2": d2_json}}, + ) + ) + self.assertEqual(ret["failures"], {}) + + # expect two edus, in one or two transactions. We don't know what order the + # devices will be updated. + self.assertEqual(len(self.edus), 2) + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + if stream_id is not None: + self.assertEqual(c["prev_id"], [stream_id]) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2"}, devices) + + def test_delete_devices(self): + """If devices are deleted, that should result in EDUs too""" + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # expect three edus + self.assertEqual(len(self.edus), 3) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + # expect three edus, in an unknown order + self.assertEqual(len(self.edus), 3) + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertGreaterEqual( + c.items(), + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), + ) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_unreachable_server(self): + """If the destination server is unreachable, all the updates should get sent on + recovery + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # for each device, there should be a single update + self.assertEqual(len(self.edus), 3) + stream_id = None + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + if stream_id is not None: + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def check_device_update_edu( + self, + edu: JsonDict, + user_id: str, + device_id: str, + prev_stream_id: Optional[int], + ) -> int: + """Check that the given EDU is an update for the given device + Returns the stream_id. + """ + self.assertEqual(edu["edu_type"], "m.device_list_update") + content = edu["content"] + + expected = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], + } + + self.assertLessEqual(expected.items(), content.items()) + if prev_stream_id is not None: + self.assertGreaterEqual(content["stream_id"], prev_stream_id) + return content["stream_id"] + + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: + """Check that the txn has an EDU with a signing key update. + """ + edus = txn["edus"] + self.assertEqual(len(edus), 1) + + def generate_and_upload_device_signing_key( + self, user_id: str, device_id: str + ) -> SigningKey: + """Generate a signing keypair for the given device, and upload it""" + sk = key.generate_signing_key(device_id) + + device_dict = build_device_dict(user_id, device_id, sk) + + self.get_success( + self.hs.get_e2e_keys_handler().upload_keys_for_user( + user_id, device_id, {"device_keys": device_dict}, + ) + ) + return sk + + +def generate_self_id_key() -> SigningKey: + """generate a signing key whose version is its public key + + ... as used by the cross-signing-keys. + """ + k = key.generate_signing_key("x") + k.version = encode_pubkey(k) + return k + + +def key_id(k: BaseKey) -> str: + return "%s:%s" % (k.alg, k.version) + + +def encode_pubkey(sk: SigningKey) -> str: + """Encode the public key corresponding to the given signing key as base64""" + return key.encode_verify_key_base64(key.get_verify_key(sk)) + + +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): + """Build a dict representing the given device""" + return { + "user_id": user_id, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:" + device_id: "curve25519+key", + key_id(sk): encode_pubkey(sk), + }, + } diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5e40adba525c..00bb77627183 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -102,6 +102,68 @@ def test_incoming_fed_query(self): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + class TestDeleteAlias(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec34..be665262c601 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ def register_query_handler(query_type, handler): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,33 @@ def test_set_my_name(self): "Frank Jr.", ) + # Set displayname again + yield self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +175,38 @@ def test_set_my_avatar(self): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e2915eb7b174..e7b638dbfe49 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index e96ad4ca4e48..a755fe28794f 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from mock import Mock from synapse.replication.tcp.commands import ReplicateCommand @@ -29,19 +30,37 @@ def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol(None) - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory + self.test_handler = Mock(wraps=TestReplicationClientHandler()) self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler + hs, "client", "test", clock, self.test_handler, ) - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) + self._client_transport = None + self._server_transport = None + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() def replicate(self): """Tell the master side of replication that something has happened, and then @@ -50,19 +69,24 @@ def replicate(self): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self, stream, token="NOW"): + def replicate_stream(self): """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) + self.client.send_command(ReplicateCommand()) class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" def __init__(self): - self.received_rdata_rows = [] + self.streams = set() + self._received_rdata_rows = [] def get_streams_to_replicate(self): - return {} + positions = {s: 0 for s in self.streams} + for stream, token, _ in self._received_rdata_rows: + if stream in self.streams: + positions[stream] = max(token, positions.get(stream, 0)) + return positions def get_currently_syncing_users(self): return [] @@ -73,6 +97,9 @@ def update_connection(self, connection): def finished_connecting(self): pass + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + async def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) + self._received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index fa2493cad61f..0ec0825a0e62 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -17,30 +17,64 @@ from tests.replication.tcp.streams._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): + self.reconnect() + # make the client subscribe to the receipts stream - self.replicate_stream("receipts", "NOW") + self.replicate_stream() + self.test_handler.streams.add("receipts") # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py new file mode 100644 index 000000000000..672cc3eac521 --- /dev/null +++ b/tests/rest/admin/test_room.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. + + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e57..aed8853d6e9d 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -350,7 +350,14 @@ def test_cas_redirect_confirm(self): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb3a..45a9d445f823 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -325,3 +326,304 @@ def deactivate(self, user_id, tok): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b6df1396ad66..624bf5ada23f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -104,7 +104,7 @@ def test_fallback_captcha(self): ) self.render(request) - # Now we should have fufilled a complete auth flow, including + # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. request, channel = self.make_request( @@ -115,3 +115,69 @@ def test_fallback_captcha(self): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_cannot_change_operation(self): + """ + The initial requested operation cannot be modified during the user interactive authentication session. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) + request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 200) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + # also complete the dummy auth + request, channel = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) + self.render(request) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step. Make the initial request again, but + # with a different password. This causes the request to fail since the + # operaiton was modified during the ui auth session. + request, channel = self.make_request( + "POST", + "register", + { + "username": "user", + "type": "m.login.password", + "password": "foo", # Note this doesn't match the original request. + "auth": {"session": session}, + }, + ) + self.render(request) + self.assertEqual(channel.code, 403) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 000000000000..c57072f50c6d --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d0c997e385bc..b6ed06e02ded 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -36,8 +36,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = b"/_matrix/client/r0/register" - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config["allow_guest_access"] = True return config diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6776a56cadfd..99eb47714983 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self, *args, **kwargs): - config = super().default_config(*args, **kwargs) + def default_config(self): + config = super().default_config() # replace the signing key with our own self.hs_signing_key = signedjson.key.generate_signing_key("kssk") diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index eb540e34f65d..0d27b92a86bb 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -28,7 +28,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() hs_config["server_notices"] = { "system_mxid_localpart": "server", "system_mxid_display_name": "test display name", diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index ae14fb407d2f..940b16612997 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater # the base test class should have run the real bg updates for us - self.assertTrue(self.updates.has_completed_background_updates()) + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() self.updates.register_background_update_handler( @@ -25,12 +27,20 @@ def test_do_background_update(self): # the target runtime for each bg update target_background_update_duration_ms = 50000 + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.hs.get_datastore().db.runInteraction( + yield store.db.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -39,10 +49,6 @@ def update(progress, count): return count self.update_handler.side_effect = update - - self.get_success( - self.updates.start_background_update("test_update", {"my_key": 1}) - ) self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update( @@ -50,7 +56,7 @@ def update(progress, count): ), by=0.1, ) - self.assertIsNotNone(res) + self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( @@ -73,7 +79,7 @@ def update(progress, count): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNotNone(result) + self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more @@ -81,5 +87,5 @@ def update(progress, count): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNone(result) + self.assertTrue(result) self.assertFalse(self.update_handler.called) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358fd..5c2817cf28a2 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,8 +28,8 @@ class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config.update( { "public_baseurl": "https://example.org/", @@ -53,7 +53,8 @@ def prepare(self, reactor, clock, hs): def test_ui_auth(self): # Do a UI auth request - request, channel = self.make_request(b"POST", self.url, b"{}") + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/unittest.py b/tests/unittest.py index 439174dbfc5f..27af5228feb9 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -40,6 +40,7 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( SENTINEL_CONTEXT, + LoggingContext, current_context, set_current_context, ) @@ -315,14 +316,11 @@ def create_test_json_resource(self): return resource - def default_config(self, name="test"): + def default_config(self): """ Get a default HomeServer config dict. - - Args: - name (str): The homeserver name/domain. """ - config = default_config(name) + config = default_config("test") # apply any additional config which was specified via the override_config # decorator. @@ -422,15 +420,17 @@ def setup_test_homeserver(self, *args, **kwargs): config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj + async def run_bg_updates(): + with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": - while not self.get_success( - stor.db.updates.has_completed_background_updates() - ): - self.get_success(stor.db.updates.do_next_background_update(1)) + self.get_success(run_bg_updates()) return hs @@ -497,6 +497,7 @@ def register_user(self, username, password, admin=False): "password": password, "admin": admin, "mac": want_mac, + "inhibit_login": True, } ) request, channel = self.make_request( diff --git a/tox.ini b/tox.ini index 8e3f09e63878..a79fc93b57e2 100644 --- a/tox.ini +++ b/tox.ini @@ -186,6 +186,7 @@ commands = mypy \ synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/auth.py \ + synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ synapse/handlers/presence.py \ synapse/handlers/sync.py \