From 5efeae8ee159206bf2517a12a03a2f18d61e70e3 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sat, 19 Aug 2023 12:17:34 +0200 Subject: [PATCH 01/12] Add Retry-After header --- synapse/api/errors.py | 3 +++ tests/rest/client/test_login.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 7ffd72c42cd4..f4370433f052 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -16,6 +16,7 @@ """Contains exceptions and error codes.""" import logging +import math import typing from enum import Enum from http import HTTPStatus @@ -512,6 +513,8 @@ def __init__( ): super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms + if self.retry_after_ms: + headers = {"Retry-After": math.ceil(retry_after_ms / 1000)} def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ffbc13bb8df3..de15a957ccbc 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -189,12 +189,14 @@ def test_POST_ratelimiting_per_address(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = int(channel.headers.getRawHeaders("Retry-After")) else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) + self.assertTrue(retry_header < 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -234,12 +236,14 @@ def test_POST_ratelimiting_per_account(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = int(channel.headers.getRawHeaders("Retry-After")) else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) + self.assertTrue(retry_header < 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -279,12 +283,14 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = int(channel.headers.getRawHeaders("Retry-After")) else: self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) + self.assertTrue(retry_header < 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) From 9d26b2f93619d29c166d1b9d84325a2fc2b7fec1 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sat, 19 Aug 2023 12:19:42 +0200 Subject: [PATCH 02/12] changelog --- changelog.d/16136.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/16136.feature diff --git a/changelog.d/16136.feature b/changelog.d/16136.feature new file mode 100644 index 000000000000..4ad98a88c309 --- /dev/null +++ b/changelog.d/16136.feature @@ -0,0 +1 @@ +Return a `Retry-After` with `M_LIMIT_EXCEEDED` error responses. From 7656743449047714a42078653b408d921f27bc88 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sat, 19 Aug 2023 15:28:19 +0200 Subject: [PATCH 03/12] Tidy tidy --- synapse/api/errors.py | 9 ++++++--- tests/rest/client/test_login.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index f4370433f052..973655d49fe4 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -511,10 +511,13 @@ def __init__( retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super().__init__(code, msg, errcode) + headers = ( + None + if retry_after_ms is None + else {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + ) + super().__init__(code, msg, errcode, None, headers) self.retry_after_ms = retry_after_ms - if self.retry_after_ms: - headers = {"Retry-After": math.ceil(retry_after_ms / 1000)} def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index de15a957ccbc..76c5bfaa2532 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -189,14 +189,15 @@ def test_POST_ratelimiting_per_address(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = int(channel.headers.getRawHeaders("Retry-After")) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) - self.assertTrue(retry_header < 6) + assert retry_header + self.assertTrue(int(retry_header[0]) < 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -236,14 +237,15 @@ def test_POST_ratelimiting_per_account(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = int(channel.headers.getRawHeaders("Retry-After")) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) - self.assertTrue(retry_header < 6) + assert retry_header + self.assertTrue(int(retry_header[0]) < 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -283,14 +285,15 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) - retry_header = int(channel.headers.getRawHeaders("Retry-After")) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. self.assertTrue(retry_after_ms < 6000) - self.assertTrue(retry_header < 6) + assert retry_header + self.assertTrue(int(retry_header[0]) < 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) From ab6f26fb28456f7467d6cddecc9197e18f153f1e Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 20 Aug 2023 09:46:59 +0200 Subject: [PATCH 04/12] Retry-After may be ceil'd --- tests/rest/client/test_login.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 76c5bfaa2532..883cddffc39f 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -197,7 +197,7 @@ def test_POST_ratelimiting_per_address(self) -> None: # than 1min. self.assertTrue(retry_after_ms < 6000) assert retry_header - self.assertTrue(int(retry_header[0]) < 6) + self.assertTrue(int(retry_header[0]) <= 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -245,7 +245,7 @@ def test_POST_ratelimiting_per_account(self) -> None: # than 1min. self.assertTrue(retry_after_ms < 6000) assert retry_header - self.assertTrue(int(retry_header[0]) < 6) + self.assertTrue(int(retry_header[0]) <= 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -293,7 +293,7 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: # than 1min. self.assertTrue(retry_after_ms < 6000) assert retry_header - self.assertTrue(int(retry_header[0]) < 6) + self.assertTrue(int(retry_header[0]) <= 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) From a16504a95888bd28878c165a626b83791ddd8a86 Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Mon, 21 Aug 2023 15:23:27 +0100 Subject: [PATCH 05/12] Fixes --- synapse/api/errors.py | 2 +- tests/rest/client/test_login.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 973655d49fe4..4badfcbb9c4c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -516,7 +516,7 @@ def __init__( if retry_after_ms is None else {"Retry-After": str(math.ceil(retry_after_ms / 1000))} ) - super().__init__(code, msg, errcode, None, headers) + super().__init__(code, msg, errcode, headers=headers) self.retry_after_ms = retry_after_ms def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 883cddffc39f..553a4895b9b1 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -195,9 +195,9 @@ def test_POST_ratelimiting_per_address(self) -> None: # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) assert retry_header - self.assertTrue(int(retry_header[0]) <= 6) + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -243,9 +243,9 @@ def test_POST_ratelimiting_per_account(self) -> None: # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) assert retry_header - self.assertTrue(int(retry_header[0]) <= 6) + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -291,9 +291,9 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) assert retry_header - self.assertTrue(int(retry_header[0]) <= 6) + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) From 0ccce88603c12e6222bf2b0524e364289d1f0a86 Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Mon, 21 Aug 2023 16:01:45 +0100 Subject: [PATCH 06/12] Add tests for exceeded rounding. --- tests/api/test_errors.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/api/test_errors.py diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py new file mode 100644 index 000000000000..58dfd6d5ed10 --- /dev/null +++ b/tests/api/test_errors.py @@ -0,0 +1,30 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import LimitExceededError + +from tests import unittest + + +class ErrorsTestCase(unittest.TestCase): + def test_limit_exceeded_header(self) -> None: + err = LimitExceededError(retry_after_ms=100) + self.assertEqual(err.error_dict({}).get("retry_after_ms"), 100) + self.assertEqual(err.headers.get("Retry-After"), "1") + + def test_limit_exceeded_rounding(self) -> None: + err = LimitExceededError(retry_after_ms=3001) + self.assertEqual(err.error_dict({}).get("retry_after_ms"), 3001) + self.assertEqual(err.headers.get("Retry-After"), "4") From b26a203506346357231c4eb02eea3060937823fb Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Mon, 21 Aug 2023 16:50:12 +0100 Subject: [PATCH 07/12] Fixup types --- tests/api/test_errors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 58dfd6d5ed10..df0ab14833ae 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -21,10 +21,12 @@ class ErrorsTestCase(unittest.TestCase): def test_limit_exceeded_header(self) -> None: err = LimitExceededError(retry_after_ms=100) - self.assertEqual(err.error_dict({}).get("retry_after_ms"), 100) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) + assert err.headers self.assertEqual(err.headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: err = LimitExceededError(retry_after_ms=3001) - self.assertEqual(err.error_dict({}).get("retry_after_ms"), 3001) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) + assert err.headers self.assertEqual(err.headers.get("Retry-After"), "4") From f04efbfe7136c253979f6310d190dad3ed04d143 Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Tue, 22 Aug 2023 16:56:22 +0100 Subject: [PATCH 08/12] Allowed headers on errors to be conditional based on config setting. --- synapse/api/errors.py | 38 +++++++++++++++++++++------------- synapse/config/experimental.py | 3 +++ synapse/http/server.py | 10 +++++---- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 4badfcbb9c4c..41501d9defb2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -127,14 +127,12 @@ class CodeMessageException(RuntimeError): Attributes: code: HTTP error code msg: string describing the error - headers: optional response headers to send """ def __init__( self, code: Union[int, HTTPStatus], msg: str, - headers: Optional[Dict[str, str]] = None, ): super().__init__("%d: %s" % (code, msg)) @@ -146,7 +144,11 @@ def __init__( # To eliminate this behaviour, we convert them to their integer equivalents here. self.code = int(code) self.msg = msg - self.headers = headers + + def headers_dict( + self, config: Optional["HomeServerConfig"] + ) -> Optional[Dict[str, str]]: + return None class RedirectException(CodeMessageException): @@ -192,7 +194,6 @@ def __init__( msg: str, errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, - headers: Optional[Dict[str, str]] = None, ): """Constructs a synapse error. @@ -201,7 +202,7 @@ def __init__( msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super().__init__(code, msg, headers) + super().__init__(code, msg) self.errcode = errcode if additional_fields is None: self._additional_fields: Dict = {} @@ -360,11 +361,14 @@ def __init__( self, required_scopes: List[str], ): - headers = { + self.required_scopes = required_scopes + super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None) + + def headers_dict(self, config: Optional["HomeServerConfig"]) -> Dict[str, str]: + return { "WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"' - % (" ".join(required_scopes)) + % (" ".join(self.required_scopes)) } - super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None, headers) class UnstableSpecAuthError(AuthError): @@ -511,17 +515,23 @@ def __init__( retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - headers = ( - None - if retry_after_ms is None - else {"Retry-After": str(math.ceil(retry_after_ms / 1000))} - ) - super().__init__(code, msg, errcode, headers=headers) + super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) + def headers_dict( + self, config: Optional["HomeServerConfig"] + ) -> Optional[Dict[str, str]]: + if ( + self.retry_after_ms is not None + and config + and config.experimental.msc4041_enabled + ): + return {"Retry-After": str(math.ceil(self.retry_after_ms / 1000))} + return None + class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index ac9449b18f70..508133c70bc0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -398,3 +398,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC4041: Use http header Retry-After to enable library-assisted retry handling + self.msc4041_enabled: bool = experimental.get("msc4041_enabled", False) diff --git a/synapse/http/server.py b/synapse/http/server.py index 5109cec983c9..c4af2eca44a2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -112,8 +112,9 @@ def return_json_error( exc: SynapseError = f.value error_code = exc.code error_dict = exc.error_dict(config) - if exc.headers is not None: - for header, value in exc.headers.items(): + headers_dict = exc.headers_dict(config) + if headers_dict is not None: + for header, value in headers_dict.items(): request.setHeader(header, value) logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): @@ -176,8 +177,9 @@ def return_html_error( cme: CodeMessageException = f.value code = cme.code msg = cme.msg - if cme.headers is not None: - for header, value in cme.headers.items(): + headers = cme.headers_dict(None) + if headers is not None: + for header, value in headers.items(): request.setHeader(header, value) if isinstance(cme, RedirectException): From 793daa71cf702e19e2f7f6087bcf1e5d81f3e564 Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Tue, 22 Aug 2023 16:56:31 +0100 Subject: [PATCH 09/12] Update tests --- tests/api/test_errors.py | 25 ++++++++++++++++++++----- tests/handlers/test_oauth_delegation.py | 4 +++- tests/rest/client/test_login.py | 9 ++++++--- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index df0ab14833ae..4fcd50d2e4a4 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -14,19 +14,34 @@ # limitations under the License. from synapse.api.errors import LimitExceededError +from synapse.config.homeserver import HomeServerConfig from tests import unittest +from tests.utils import default_config class ErrorsTestCase(unittest.TestCase): + def setUp(self) -> None: + self.config = HomeServerConfig() + self.config.parse_config_dict( + { + **default_config("test"), + "experimental_features": {"msc4041_enabled": True}, + }, + "", + "", + ) + def test_limit_exceeded_header(self) -> None: err = LimitExceededError(retry_after_ms=100) - self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) - assert err.headers - self.assertEqual(err.headers.get("Retry-After"), "1") + self.assertEqual(err.error_dict(self.config).get("retry_after_ms"), 100) + headers = err.headers_dict(self.config) + assert headers is not None + self.assertEqual(headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: err = LimitExceededError(retry_after_ms=3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) - assert err.headers - self.assertEqual(err.headers.get("Retry-After"), "4") + headers = err.headers_dict(self.config) + assert headers is not None + self.assertEqual(headers.get("Retry-After"), "4") diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 1456b675a7ec..e44092917375 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -455,8 +455,10 @@ def test_active_guest_not_allowed(self) -> None: method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY ) self._assertParams() + headers = error.value.headers_dict(None) + assert headers is not None self.assertEqual( - getattr(error.value, "headers", {})["WWW-Authenticate"], + headers.get("WWW-Authenticate"), 'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"', ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 553a4895b9b1..62c32cae5e3b 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -169,7 +169,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "account": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_address(self) -> None: @@ -220,7 +221,8 @@ def test_POST_ratelimiting_per_address(self) -> None: # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account(self) -> None: @@ -268,7 +270,8 @@ def test_POST_ratelimiting_per_account(self) -> None: # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: From efa247c6a7ff446311552d6ed58afbc508a1bd5e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Aug 2023 08:47:42 -0400 Subject: [PATCH 10/12] Revert "Allowed headers on errors to be conditional based on config setting." This reverts commit f04efbfe7136c253979f6310d190dad3ed04d143. --- synapse/api/errors.py | 38 ++++++++++++++------------------------ synapse/http/server.py | 10 ++++------ 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 41501d9defb2..4badfcbb9c4c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -127,12 +127,14 @@ class CodeMessageException(RuntimeError): Attributes: code: HTTP error code msg: string describing the error + headers: optional response headers to send """ def __init__( self, code: Union[int, HTTPStatus], msg: str, + headers: Optional[Dict[str, str]] = None, ): super().__init__("%d: %s" % (code, msg)) @@ -144,11 +146,7 @@ def __init__( # To eliminate this behaviour, we convert them to their integer equivalents here. self.code = int(code) self.msg = msg - - def headers_dict( - self, config: Optional["HomeServerConfig"] - ) -> Optional[Dict[str, str]]: - return None + self.headers = headers class RedirectException(CodeMessageException): @@ -194,6 +192,7 @@ def __init__( msg: str, errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, + headers: Optional[Dict[str, str]] = None, ): """Constructs a synapse error. @@ -202,7 +201,7 @@ def __init__( msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super().__init__(code, msg) + super().__init__(code, msg, headers) self.errcode = errcode if additional_fields is None: self._additional_fields: Dict = {} @@ -361,14 +360,11 @@ def __init__( self, required_scopes: List[str], ): - self.required_scopes = required_scopes - super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None) - - def headers_dict(self, config: Optional["HomeServerConfig"]) -> Dict[str, str]: - return { + headers = { "WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"' - % (" ".join(self.required_scopes)) + % (" ".join(required_scopes)) } + super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None, headers) class UnstableSpecAuthError(AuthError): @@ -515,23 +511,17 @@ def __init__( retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super().__init__(code, msg, errcode) + headers = ( + None + if retry_after_ms is None + else {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + ) + super().__init__(code, msg, errcode, headers=headers) self.retry_after_ms = retry_after_ms def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) - def headers_dict( - self, config: Optional["HomeServerConfig"] - ) -> Optional[Dict[str, str]]: - if ( - self.retry_after_ms is not None - and config - and config.experimental.msc4041_enabled - ): - return {"Retry-After": str(math.ceil(self.retry_after_ms / 1000))} - return None - class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/http/server.py b/synapse/http/server.py index c4af2eca44a2..5109cec983c9 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -112,9 +112,8 @@ def return_json_error( exc: SynapseError = f.value error_code = exc.code error_dict = exc.error_dict(config) - headers_dict = exc.headers_dict(config) - if headers_dict is not None: - for header, value in headers_dict.items(): + if exc.headers is not None: + for header, value in exc.headers.items(): request.setHeader(header, value) logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): @@ -177,9 +176,8 @@ def return_html_error( cme: CodeMessageException = f.value code = cme.code msg = cme.msg - headers = cme.headers_dict(None) - if headers is not None: - for header, value in headers.items(): + if cme.headers is not None: + for header, value in cme.headers.items(): request.setHeader(header, value) if isinstance(cme, RedirectException): From 16a1e097f7281888ad625c55428745e3b723f1bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Aug 2023 08:54:53 -0400 Subject: [PATCH 11/12] Revert "Update tests" This reverts commit 793daa71cf702e19e2f7f6087bcf1e5d81f3e564. --- tests/api/test_errors.py | 25 +++++-------------------- tests/handlers/test_oauth_delegation.py | 4 +--- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 4fcd50d2e4a4..df0ab14833ae 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -14,34 +14,19 @@ # limitations under the License. from synapse.api.errors import LimitExceededError -from synapse.config.homeserver import HomeServerConfig from tests import unittest -from tests.utils import default_config class ErrorsTestCase(unittest.TestCase): - def setUp(self) -> None: - self.config = HomeServerConfig() - self.config.parse_config_dict( - { - **default_config("test"), - "experimental_features": {"msc4041_enabled": True}, - }, - "", - "", - ) - def test_limit_exceeded_header(self) -> None: err = LimitExceededError(retry_after_ms=100) - self.assertEqual(err.error_dict(self.config).get("retry_after_ms"), 100) - headers = err.headers_dict(self.config) - assert headers is not None - self.assertEqual(headers.get("Retry-After"), "1") + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) + assert err.headers + self.assertEqual(err.headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: err = LimitExceededError(retry_after_ms=3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) - headers = err.headers_dict(self.config) - assert headers is not None - self.assertEqual(headers.get("Retry-After"), "4") + assert err.headers + self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 997bc8774ff9..b891e8469041 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -455,10 +455,8 @@ def test_active_guest_not_allowed(self) -> None: method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY ) self._assertParams() - headers = error.value.headers_dict(None) - assert headers is not None self.assertEqual( - headers.get("WWW-Authenticate"), + getattr(error.value, "headers", {})["WWW-Authenticate"], 'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"', ) From 4302d2de6cad31d35d2c066c67283df043c94547 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Aug 2023 09:15:34 -0400 Subject: [PATCH 12/12] Use a class-property to control headers. --- synapse/api/errors.py | 8 +++++--- synapse/config/experimental.py | 10 ++++++++-- tests/api/test_errors.py | 12 ++++++++---- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 4badfcbb9c4c..578e798773ff 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -504,6 +504,8 @@ def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled.""" + include_retry_after_header = False + def __init__( self, code: int = 429, @@ -512,9 +514,9 @@ def __init__( errcode: str = Codes.LIMIT_EXCEEDED, ): headers = ( - None - if retry_after_ms is None - else {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + if self.include_retry_after_header and retry_after_ms is not None + else None ) super().__init__(code, msg, errcode, headers=headers) self.retry_after_ms = retry_after_ms diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1cfc6f7adf2..5a1bfb67e15b 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -18,6 +18,7 @@ import attr import attr.validators +from synapse.api.errors import LimitExceededError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError from synapse.config._base import Config, RootConfig @@ -412,5 +413,10 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: "msc4010_push_rules_account_data", False ) - # MSC4041: Use http header Retry-After to enable library-assisted retry handling - self.msc4041_enabled: bool = experimental.get("msc4041_enabled", False) + # MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling + # + # This is a bit hacky, but the most reasonable way to *alway* include the + # headers. + LimitExceededError.include_retry_after_header = experimental.get( + "msc4041_enabled", False + ) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index df0ab14833ae..319abfe63dc1 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -19,14 +19,18 @@ class ErrorsTestCase(unittest.TestCase): + # Create a sub-class to avoid mutating the class-level property. + class LimitExceededErrorHeaders(LimitExceededError): + include_retry_after_header = True + def test_limit_exceeded_header(self) -> None: - err = LimitExceededError(retry_after_ms=100) + err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) - assert err.headers + assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: - err = LimitExceededError(retry_after_ms=3001) + err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) - assert err.headers + assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "4")