From f299c5414c2dd300103b0e11e7114123d8eb58a1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 8 Aug 2019 15:30:04 +0100 Subject: [PATCH 1/9] Refactor MatrixFederationAgent to retry SRV. This refactors MatrixFederationAgent to move the SRV lookup into the endpoint code, this has two benefits: 1. Its easier to retry different host/ports in the same way as HostnameEndpoint. 2. We avoid SRV lookups if we have a free connection in the pool --- .../federation/matrix_federation_agent.py | 356 +++++++++--------- synapse/http/federation/srv_resolver.py | 35 +- .../test_matrix_federation_agent.py | 63 +++- tests/http/federation/test_srv_resolver.py | 8 +- 4 files changed, 268 insertions(+), 194 deletions(-) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 71a15f434d6d..c20818579139 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,21 +14,21 @@ # limitations under the License. import logging +import urllib -import attr -from netaddr import IPAddress +from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IAgentEndpointFactory -from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.util import Clock logger = logging.getLogger(__name__) @@ -36,8 +36,9 @@ @implementer(IAgent) class MatrixFederationAgent(object): - """An Agent-like thing which provides a `request` method which will look up a matrix - server and send an HTTP request to it. + """An Agent-like thing which provides a `request` method which correctly + handles resolving matrix server names when using matrix://. Handles standard + https URIs as normal. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -65,23 +66,25 @@ def __init__( ): self._reactor = reactor self._clock = Clock(reactor) - - self._tls_client_options_factory = tls_client_options_factory - if _srv_resolver is None: - _srv_resolver = SrvResolver() - self._srv_resolver = _srv_resolver - self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 + self._agent = Agent.usingEndpointFactory( + self._reactor, + MatrixHostnameEndpointFactory( + reactor, tls_client_options_factory, _srv_resolver + ), + pool=self._pool, + ) + self._well_known_resolver = WellKnownResolver( self._reactor, agent=Agent( self._reactor, - pool=self._pool, contextFactory=tls_client_options_factory, + pool=self._pool, ), well_known_cache=_well_known_cache, ) @@ -91,19 +94,15 @@ def request(self, method, uri, headers=None, bodyProducer=None): """ Args: method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): HTTP headers to send with the request, or None to send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have no body. - Returns: Deferred[twisted.web.iweb.IResponse]: fires when the header of the response has been received (regardless of the @@ -111,210 +110,195 @@ def request(self, method, uri, headers=None, bodyProducer=None): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + # We use urlparse as that will set `port` to None if there is no + # explicit port. + parsed_uri = urllib.parse.urlparse(uri) - # set up the TLS connection params + # If this is a matrix:// URI check if the server has delegated matrix + # traffic using well-known delegation. # - # XXX disabling TLS is really only supported here for the benefit of the - # unit tests. We should make the UTs cope with TLS rather than having to make - # the code support the unit tests. - if self._tls_client_options_factory is None: - tls_options = None - else: - tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + # We have to do this here and not in the endpoint as we need to rewrite + # the host header with the delegated server name. + delegated_server = None + if ( + parsed_uri.scheme == b"matrix" + and not _is_ip_literal(parsed_uri.hostname) + and not parsed_uri.port + ): + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.hostname + ) + delegated_server = well_known_result.delegated_server + + if delegated_server: + # Ok, the server has delegated matrix traffic to somewhere else, so + # lets rewrite the URL to replace the server with the delegated + # server name. + uri = urllib.parse.urlunparse( + ( + parsed_uri.scheme, + delegated_server, + parsed_uri.path, + parsed_uri.params, + parsed_uri.query, + parsed_uri.fragment, + ) ) + parsed_uri = urllib.parse.urlparse(uri) - # make sure that the Host header is set correctly + # We need to make sure the host header is set to the netloc of the + # server. if headers is None: headers = Headers() else: headers = headers.copy() if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) + headers.addRawHeader(b"host", parsed_uri.netloc) - class EndpointFactory(object): - @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port - ) - if tls_options is not None: - ep = wrapClientTLS(tls_options, ep) - return ep + with PreserveLoggingContext(): + res = yield self._agent.request(method, uri, headers, bodyProducer) - agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) - res = yield make_deferred_yieldable( - agent.request(method, uri, headers, bodyProducer) - ) return res - @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): - """Helper for `request`: determine the routing for a Matrix URI - Args: - parsed_uri (twisted.web.client.URI): uri to route. Note that it should be - parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 - if there is no explicit port given. +@implementer(IAgentEndpointFactory) +class MatrixHostnameEndpointFactory(object): + """Factory for MatrixHostnameEndpoint for parsing to an Agent. + """ - lookup_well_known (bool): True if we should look up the .well-known file if - there is no SRV record. + def __init__(self, reactor, tls_client_options_factory, srv_resolver): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory - Returns: - Deferred[_RoutingResult] - """ - # check for an IP literal - try: - ip_address = IPAddress(parsed_uri.host.decode("ascii")) - except Exception: - # not an IP address - ip_address = None - - if ip_address: - port = parsed_uri.port - if port == -1: - port = 8448 - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) + if srv_resolver is None: + srv_resolver = SrvResolver() - if parsed_uri.port != -1: - # there is an explicit port - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) + self._srv_resolver = srv_resolver - if lookup_well_known: - # try a .well-known lookup - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.host - ) - well_known_server = well_known_result.delegated_server - - if well_known_server: - # if we found a .well-known, start again, but don't do another - # .well-known lookup. - - # parse the server name in the .well-known response into host/port. - # (This code is lifted from twisted.web.client.URI.fromBytes). - if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) - try: - well_known_port = int(well_known_port) - except ValueError: - # the part after the colon could not be parsed as an int - # - we assume it is an IPv6 literal with no port (the closing - # ']' stops it being parsed as an int) - well_known_host, well_known_port = well_known_server, -1 - else: - well_known_host, well_known_port = well_known_server, -1 - - new_uri = URI( - scheme=parsed_uri.scheme, - netloc=well_known_server, - host=well_known_host, - port=well_known_port, - path=parsed_uri.path, - params=parsed_uri.params, - query=parsed_uri.query, - fragment=parsed_uri.fragment, - ) + def endpointForURI(self, parsed_uri): + return MatrixHostnameEndpoint( + self._reactor, + self._tls_client_options_factory, + self._srv_resolver, + parsed_uri, + ) - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - return res - - # try a SRV lookup - service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield self._srv_resolver.resolve_service(service_name) - - if not server_list: - target_host = parsed_uri.host - port = 8448 - logger.debug( - "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), - target_host.decode("ascii"), - port, - ) + +@implementer(IStreamClientEndpoint) +class MatrixHostnameEndpoint(object): + """An endpoint that resolves matrix:// URLs using Matrix server name + resolution (i.e. via SRV). Does not check for well-known delegation. + """ + + def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + self._reactor = reactor + + # We reparse the URI so that defaultPort is -1 rather than 80 + self._parsed_uri = parsed_uri + + # set up the TLS connection params + # + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + + if tls_client_options_factory is None: + self._tls_options = None else: - target_host, port = pick_server_from_list(server_list) - logger.debug( - "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), - port, - parsed_uri.host.decode("ascii"), + self._tls_options = tls_client_options_factory.get_options( + self._parsed_uri.host.decode("ascii") ) - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - ) + self._srv_resolver = srv_resolver + @defer.inlineCallbacks + def connect(self, protocol_factory): + """Implements IStreamClientEndpoint interface + """ -@implementer(IStreamClientEndpoint) -class LoggingHostnameEndpoint(object): - """A wrapper for HostnameEndpint which logs when it connects""" + first_exception = None - def __init__(self, reactor, host, port, *args, **kwargs): - self.host = host - self.port = port - self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) + server_list = yield self._resolve_server() - def connect(self, protocol_factory): - logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) - return self.ep.connect(protocol_factory) + for server in server_list: + host = server.host + port = server.port + try: + logger.info("Connecting to %s:%i", host.decode("ascii"), port) + endpoint = HostnameEndpoint(self._reactor, host, port) + if self._tls_options: + endpoint = wrapClientTLS(self._tls_options, endpoint) + result = yield make_deferred_yieldable( + endpoint.connect(protocol_factory) + ) -@attr.s -class _RoutingResult(object): - """The result returned by `_route_matrix_uri`. + return result + except Exception as e: + logger.info( + "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e + ) + if not first_exception: + first_exception = e - Contains the parameters needed to direct a federation connection to a particular - server. + # We return the first failure because that's probably the most interesting. + if first_exception: + raise first_exception - Where a SRV record points to several servers, this object contains a single server - chosen from the list. - """ + # This shouldn't happen as we should always have at least one host/port + # to try and if that doesn't work then we'll have an exception. + raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - host_header = attr.ib() - """ - The value we should assign to the Host header (host:port from the matrix - URI, or .well-known). + @defer.inlineCallbacks + def _resolve_server(self): + """Resolves the server name to a list of hosts and ports to attempt to + connect to. - :type: bytes - """ + Returns: + Deferred[list[Server]] + """ - tls_server_name = attr.ib() - """ - The server name we should set in the SNI (typically host, without port, from the - matrix URI or .well-known) + if self._parsed_uri.scheme != b"matrix": + return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] - :type: bytes - """ + # Note: We don't do well-known lookup as that needs to have happened + # before now, due to needing to rewrite the Host header of the HTTP + # request. - target_host = attr.ib() - """ - The hostname (or IP literal) we should route the TCP connection to (the target of the - SRV record, or the hostname from the URL/.well-known) + parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) - :type: bytes - """ + host = parsed_uri.hostname + port = parsed_uri.port - target_port = attr.ib() - """ - The port we should route the TCP connection to (the target of the SRV record, or - the port from the URL/.well-known, or 8448) + # If there is an explicit port or the host is an IP address we bypass + # SRV lookups and just use the given host/port. + if port or _is_ip_literal(host): + return [Server(host, port or 8448)] - :type: int + server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + + if server_list: + return server_list + + # No SRV records, so we fallback to host and 8448 + return [Server(host, 8448)] + + +def _is_ip_literal(host): + """Test if the given host name is either an IPv4 or IPv6 literal. + + Args: + host (bytes) + + Returns: + bool """ + + host = host.decode("ascii") + + try: + IPAddress(host) + return True + except AddrFormatError: + return False diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b32188766de7..bbda0a23f4d6 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -32,7 +32,7 @@ SERVER_CACHE = {} -@attr.s +@attr.s(slots=True, frozen=True) class Server(object): """ Our record of an individual server which can be tried to reach a destination. @@ -83,6 +83,35 @@ def pick_server_from_list(server_list): raise RuntimeError("pick_server_from_list got to end of eligible server list.") +def _sort_server_list(server_list): + """Given a list of SRV records sort them into priority order and shuffle + each priority with the given weight. + """ + priority_map = {} + + for server in server_list: + priority_map.setdefault(server.priority, []).append(server) + + results = [] + for priority in sorted(priority_map): + servers = priority_map.pop(priority) + + while servers: + total_weight = sum(s.weight for s in servers) + target_weight = random.randint(0, total_weight) + + for s in servers: + target_weight -= s.weight + + if target_weight <= 0: + break + + results.append(s) + servers.remove(s) + + return results + + class SrvResolver(object): """Interface to the dns client to do SRV lookups, with result caching. @@ -120,7 +149,7 @@ def resolve_service(self, service_name): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - return servers + return _sort_server_list(servers) try: answers, _, _ = yield make_deferred_yieldable( @@ -169,4 +198,4 @@ def resolve_service(self, service_name): ) self._cache[service_name] = list(servers) - return servers + return _sort_server_list(servers) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 2c568788b306..f97c8a59f6f1 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -41,9 +41,9 @@ from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache +from tests import unittest from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) @@ -67,7 +67,8 @@ def get_connection_factory(): return test_server_connection_factory -class MatrixFederationAgentTests(TestCase): +@unittest.DEBUG +class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -1056,8 +1057,64 @@ def test_well_known_cache_with_temp_failure(self): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_srv_fallbacks(self): + """Test that other SRV results are tried if the first one fails. + """ + + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + -class TestCachePeriodFromHeaders(TestCase): +class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self): # uppercase self.assertEqual( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 3b885ef64bde..df034ab2378d 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,8 +83,10 @@ def test_from_cache_expired_and_dns_fail(self): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 0 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @@ -105,8 +107,10 @@ def test_from_cache(self): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 999999999 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver( From c03e3e83010d0147515e3771353af6b89bf8cf03 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Aug 2019 15:33:22 +0100 Subject: [PATCH 2/9] Newsfile --- changelog.d/5864.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/5864.misc diff --git a/changelog.d/5864.misc b/changelog.d/5864.misc new file mode 100644 index 000000000000..40ac11db6449 --- /dev/null +++ b/changelog.d/5864.misc @@ -0,0 +1 @@ +Correctly retry all hosts returned from SRV when we fail to connect. From 7777d353bfffc840b79391da107e593338a1a2fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 11:46:54 +0100 Subject: [PATCH 3/9] Remove test debugs --- tests/federation/test_federation_server.py | 1 - tests/http/federation/test_matrix_federation_agent.py | 1 - tests/test_visibility.py | 1 - 3 files changed, 3 deletions(-) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index af15f4cc5a81..b08be451aa03 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -20,7 +20,6 @@ from tests import unittest -@unittest.DEBUG class ServerACLsTestCase(unittest.TestCase): def test_blacklisted_server(self): e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index f97c8a59f6f1..445a0e76abae 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -67,7 +67,6 @@ def get_connection_factory(): return test_server_connection_factory -@unittest.DEBUG class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0605dac2ffd..18f1a0035d6f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -74,7 +74,6 @@ def test_filtering(self): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - @tests.unittest.DEBUG @defer.inlineCallbacks def test_erased_user(self): # 4 message events, from erased and unerased users, with a membership From 1f9df1cc7ba7027aef3a38d01909a928ecf2a8c5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 11:49:44 +0100 Subject: [PATCH 4/9] Fixup _sort_server_list to be slightly more efficient Also document that we are using the algorithm described in RFC2782 and ensure we handle zero weight correctly. --- synapse/http/federation/srv_resolver.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index bbda0a23f4d6..110b112e8584 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -94,10 +94,18 @@ def _sort_server_list(server_list): results = [] for priority in sorted(priority_map): - servers = priority_map.pop(priority) + servers = priority_map[priority] + # This algorithms follows the algorithm described in RFC2782. + # + # N.B. Weights can be zero, which means that you should pick that server + # last *or* that its the only server in this priority. + + # We sort to ensure zero weighted items are first. + servers.sort(key=lambda s: s.weight) + + total_weight = sum(s.weight for s in servers) while servers: - total_weight = sum(s.weight for s in servers) target_weight = random.randint(0, total_weight) for s in servers: @@ -108,6 +116,7 @@ def _sort_server_list(server_list): results.append(s) servers.remove(s) + total_weight -= s.weight return results From 74f016d343fe270ab3affe79cc82266d94120e5c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 11:50:12 +0100 Subject: [PATCH 5/9] Remove now unused pick_server_from_list --- synapse/http/federation/srv_resolver.py | 30 ------------------------- 1 file changed, 30 deletions(-) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 110b112e8584..c8ca3fd0e9de 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -53,36 +53,6 @@ class Server(object): expires = attr.ib(default=0) -def pick_server_from_list(server_list): - """Randomly choose a server from the server list - - Args: - server_list (list[Server]): list of candidate servers - - Returns: - Tuple[bytes, int]: (host, port) pair for the chosen server - """ - if not server_list: - raise RuntimeError("pick_server_from_list called with empty list") - - # TODO: currently we only use the lowest-priority servers. We should maintain a - # cache of servers known to be "down" and filter them out - - min_priority = min(s.priority for s in server_list) - eligible_servers = list(s for s in server_list if s.priority == min_priority) - total_weight = sum(s.weight for s in eligible_servers) - target_weight = random.randint(0, total_weight) - - for s in eligible_servers: - target_weight -= s.weight - - if target_weight <= 0: - return s.host, s.port - - # this should be impossible. - raise RuntimeError("pick_server_from_list got to end of eligible server list.") - - def _sort_server_list(server_list): """Given a list of SRV records sort them into priority order and shuffle each priority with the given weight. From 29763f01c63f1c5d5053dad413b69f1980208131 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 12:38:06 +0100 Subject: [PATCH 6/9] Make changelog entry be a feature --- changelog.d/{5864.misc => 5864.feature} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changelog.d/{5864.misc => 5864.feature} (100%) diff --git a/changelog.d/5864.misc b/changelog.d/5864.feature similarity index 100% rename from changelog.d/5864.misc rename to changelog.d/5864.feature From e70f0081da5bbea772316dffda5f173e9568d1d3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Aug 2019 15:05:56 +0100 Subject: [PATCH 7/9] Fix logcontexts --- synapse/http/federation/matrix_federation_agent.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a8815f078a3e..62883c06a451 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -28,7 +28,7 @@ from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util import Clock logger = logging.getLogger(__name__) @@ -158,8 +158,9 @@ def request(self, method, uri, headers=None, bodyProducer=None): if not headers.hasHeader(b"host"): headers.addRawHeader(b"host", parsed_uri.netloc) - with PreserveLoggingContext(): - res = yield self._agent.request(method, uri, headers, bodyProducer) + res = yield make_deferred_yieldable( + self._agent.request(method, uri, headers, bodyProducer) + ) return res @@ -214,11 +215,14 @@ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri self._srv_resolver = srv_resolver - @defer.inlineCallbacks def connect(self, protocol_factory): """Implements IStreamClientEndpoint interface """ + return run_in_background(self._do_connect, protocol_factory) + + @defer.inlineCallbacks + def _do_connect(self, protocol_factory): first_exception = None server_list = yield self._resolve_server() From fbb758a7cef9282fee605eb6bc9f1b2d430d8d62 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Aug 2019 15:09:08 +0100 Subject: [PATCH 8/9] Fixup comments --- synapse/http/federation/matrix_federation_agent.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 62883c06a451..feae7de5bec2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -192,12 +192,19 @@ def endpointForURI(self, parsed_uri): class MatrixHostnameEndpoint(object): """An endpoint that resolves matrix:// URLs using Matrix server name resolution (i.e. via SRV). Does not check for well-known delegation. + + Args: + reactor (IReactor) + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + srv_resolver (SrvResolver): The SRV resolver to use + parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting + to connect to. """ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): self._reactor = reactor - # We reparse the URI so that defaultPort is -1 rather than 80 self._parsed_uri = parsed_uri # set up the TLS connection params @@ -272,6 +279,7 @@ def _resolve_server(self): # before now, due to needing to rewrite the Host header of the HTTP # request. + # We reparse the URI so that defaultPort is -1 rather than 80 parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) host = parsed_uri.hostname From 91caa5b4303bfa0b4604ecf95d56ae72a7074b0b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 27 Aug 2019 13:56:42 +0100 Subject: [PATCH 9/9] Fix off by one error in SRV result shuffling --- synapse/http/federation/srv_resolver.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index c8ca3fd0e9de..3fe4ffb9e5f4 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -66,17 +66,18 @@ def _sort_server_list(server_list): for priority in sorted(priority_map): servers = priority_map[priority] - # This algorithms follows the algorithm described in RFC2782. + # This algorithms roughly follows the algorithm described in RFC2782, + # changed to remove an off-by-one error. # - # N.B. Weights can be zero, which means that you should pick that server - # last *or* that its the only server in this priority. - - # We sort to ensure zero weighted items are first. - servers.sort(key=lambda s: s.weight) + # N.B. Weights can be zero, which means that they should be picked + # rarely. total_weight = sum(s.weight for s in servers) - while servers: - target_weight = random.randint(0, total_weight) + + # Total weight can become zero if there are only zero weight servers + # left, which we handle by just shuffling and appending to the results. + while servers and total_weight: + target_weight = random.randint(1, total_weight) for s in servers: target_weight -= s.weight @@ -88,6 +89,10 @@ def _sort_server_list(server_list): servers.remove(s) total_weight -= s.weight + if servers: + random.shuffle(servers) + results.extend(servers) + return results