diff --git a/changelog.d/5864.feature b/changelog.d/5864.feature new file mode 100644 index 000000000000..40ac11db6449 --- /dev/null +++ b/changelog.d/5864.feature @@ -0,0 +1 @@ +Correctly retry all hosts returned from SRV when we fail to connect. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 64f62aaeeccd..feae7de5bec2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,21 +14,21 @@ # limitations under the License. import logging +import urllib -import attr -from netaddr import IPAddress +from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IAgentEndpointFactory -from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util import Clock logger = logging.getLogger(__name__) @@ -36,8 +36,9 @@ @implementer(IAgent) class MatrixFederationAgent(object): - """An Agent-like thing which provides a `request` method which will look up a matrix - server and send an HTTP request to it. + """An Agent-like thing which provides a `request` method which correctly + handles resolving matrix server names when using matrix://. Handles standard + https URIs as normal. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -65,17 +66,19 @@ def __init__( ): self._reactor = reactor self._clock = Clock(reactor) - - self._tls_client_options_factory = tls_client_options_factory - if _srv_resolver is None: - _srv_resolver = SrvResolver() - self._srv_resolver = _srv_resolver - self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 + self._agent = Agent.usingEndpointFactory( + self._reactor, + MatrixHostnameEndpointFactory( + reactor, tls_client_options_factory, _srv_resolver + ), + pool=self._pool, + ) + if _well_known_resolver is None: _well_known_resolver = WellKnownResolver( self._reactor, @@ -93,19 +96,15 @@ def request(self, method, uri, headers=None, bodyProducer=None): """ Args: method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): HTTP headers to send with the request, or None to send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have no body. - Returns: Deferred[twisted.web.iweb.IResponse]: fires when the header of the response has been received (regardless of the @@ -113,210 +112,207 @@ def request(self, method, uri, headers=None, bodyProducer=None): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + # We use urlparse as that will set `port` to None if there is no + # explicit port. + parsed_uri = urllib.parse.urlparse(uri) - # set up the TLS connection params + # If this is a matrix:// URI check if the server has delegated matrix + # traffic using well-known delegation. # - # XXX disabling TLS is really only supported here for the benefit of the - # unit tests. We should make the UTs cope with TLS rather than having to make - # the code support the unit tests. - if self._tls_client_options_factory is None: - tls_options = None - else: - tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + # We have to do this here and not in the endpoint as we need to rewrite + # the host header with the delegated server name. + delegated_server = None + if ( + parsed_uri.scheme == b"matrix" + and not _is_ip_literal(parsed_uri.hostname) + and not parsed_uri.port + ): + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.hostname ) + delegated_server = well_known_result.delegated_server + + if delegated_server: + # Ok, the server has delegated matrix traffic to somewhere else, so + # lets rewrite the URL to replace the server with the delegated + # server name. + uri = urllib.parse.urlunparse( + ( + parsed_uri.scheme, + delegated_server, + parsed_uri.path, + parsed_uri.params, + parsed_uri.query, + parsed_uri.fragment, + ) + ) + parsed_uri = urllib.parse.urlparse(uri) - # make sure that the Host header is set correctly + # We need to make sure the host header is set to the netloc of the + # server. if headers is None: headers = Headers() else: headers = headers.copy() if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) + headers.addRawHeader(b"host", parsed_uri.netloc) - class EndpointFactory(object): - @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port - ) - if tls_options is not None: - ep = wrapClientTLS(tls_options, ep) - return ep - - agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) res = yield make_deferred_yieldable( - agent.request(method, uri, headers, bodyProducer) + self._agent.request(method, uri, headers, bodyProducer) ) + return res - @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): - """Helper for `request`: determine the routing for a Matrix URI - Args: - parsed_uri (twisted.web.client.URI): uri to route. Note that it should be - parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 - if there is no explicit port given. +@implementer(IAgentEndpointFactory) +class MatrixHostnameEndpointFactory(object): + """Factory for MatrixHostnameEndpoint for parsing to an Agent. + """ - lookup_well_known (bool): True if we should look up the .well-known file if - there is no SRV record. + def __init__(self, reactor, tls_client_options_factory, srv_resolver): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory - Returns: - Deferred[_RoutingResult] - """ - # check for an IP literal - try: - ip_address = IPAddress(parsed_uri.host.decode("ascii")) - except Exception: - # not an IP address - ip_address = None - - if ip_address: - port = parsed_uri.port - if port == -1: - port = 8448 - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) + if srv_resolver is None: + srv_resolver = SrvResolver() - if parsed_uri.port != -1: - # there is an explicit port - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) + self._srv_resolver = srv_resolver - if lookup_well_known: - # try a .well-known lookup - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.host - ) - well_known_server = well_known_result.delegated_server - - if well_known_server: - # if we found a .well-known, start again, but don't do another - # .well-known lookup. - - # parse the server name in the .well-known response into host/port. - # (This code is lifted from twisted.web.client.URI.fromBytes). - if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) - try: - well_known_port = int(well_known_port) - except ValueError: - # the part after the colon could not be parsed as an int - # - we assume it is an IPv6 literal with no port (the closing - # ']' stops it being parsed as an int) - well_known_host, well_known_port = well_known_server, -1 - else: - well_known_host, well_known_port = well_known_server, -1 - - new_uri = URI( - scheme=parsed_uri.scheme, - netloc=well_known_server, - host=well_known_host, - port=well_known_port, - path=parsed_uri.path, - params=parsed_uri.params, - query=parsed_uri.query, - fragment=parsed_uri.fragment, - ) + def endpointForURI(self, parsed_uri): + return MatrixHostnameEndpoint( + self._reactor, + self._tls_client_options_factory, + self._srv_resolver, + parsed_uri, + ) - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - return res - - # try a SRV lookup - service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield self._srv_resolver.resolve_service(service_name) - - if not server_list: - target_host = parsed_uri.host - port = 8448 - logger.debug( - "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), - target_host.decode("ascii"), - port, - ) + +@implementer(IStreamClientEndpoint) +class MatrixHostnameEndpoint(object): + """An endpoint that resolves matrix:// URLs using Matrix server name + resolution (i.e. via SRV). Does not check for well-known delegation. + + Args: + reactor (IReactor) + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + srv_resolver (SrvResolver): The SRV resolver to use + parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting + to connect to. + """ + + def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + self._reactor = reactor + + self._parsed_uri = parsed_uri + + # set up the TLS connection params + # + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + + if tls_client_options_factory is None: + self._tls_options = None else: - target_host, port = pick_server_from_list(server_list) - logger.debug( - "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), - port, - parsed_uri.host.decode("ascii"), + self._tls_options = tls_client_options_factory.get_options( + self._parsed_uri.host.decode("ascii") ) - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - ) + self._srv_resolver = srv_resolver + def connect(self, protocol_factory): + """Implements IStreamClientEndpoint interface + """ -@implementer(IStreamClientEndpoint) -class LoggingHostnameEndpoint(object): - """A wrapper for HostnameEndpint which logs when it connects""" + return run_in_background(self._do_connect, protocol_factory) - def __init__(self, reactor, host, port, *args, **kwargs): - self.host = host - self.port = port - self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) + @defer.inlineCallbacks + def _do_connect(self, protocol_factory): + first_exception = None + + server_list = yield self._resolve_server() + + for server in server_list: + host = server.host + port = server.port + + try: + logger.info("Connecting to %s:%i", host.decode("ascii"), port) + endpoint = HostnameEndpoint(self._reactor, host, port) + if self._tls_options: + endpoint = wrapClientTLS(self._tls_options, endpoint) + result = yield make_deferred_yieldable( + endpoint.connect(protocol_factory) + ) - def connect(self, protocol_factory): - logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) - return self.ep.connect(protocol_factory) + return result + except Exception as e: + logger.info( + "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e + ) + if not first_exception: + first_exception = e + # We return the first failure because that's probably the most interesting. + if first_exception: + raise first_exception -@attr.s -class _RoutingResult(object): - """The result returned by `_route_matrix_uri`. + # This shouldn't happen as we should always have at least one host/port + # to try and if that doesn't work then we'll have an exception. + raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - Contains the parameters needed to direct a federation connection to a particular - server. + @defer.inlineCallbacks + def _resolve_server(self): + """Resolves the server name to a list of hosts and ports to attempt to + connect to. - Where a SRV record points to several servers, this object contains a single server - chosen from the list. - """ + Returns: + Deferred[list[Server]] + """ - host_header = attr.ib() - """ - The value we should assign to the Host header (host:port from the matrix - URI, or .well-known). + if self._parsed_uri.scheme != b"matrix": + return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] - :type: bytes - """ + # Note: We don't do well-known lookup as that needs to have happened + # before now, due to needing to rewrite the Host header of the HTTP + # request. - tls_server_name = attr.ib() - """ - The server name we should set in the SNI (typically host, without port, from the - matrix URI or .well-known) + # We reparse the URI so that defaultPort is -1 rather than 80 + parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) - :type: bytes - """ + host = parsed_uri.hostname + port = parsed_uri.port - target_host = attr.ib() - """ - The hostname (or IP literal) we should route the TCP connection to (the target of the - SRV record, or the hostname from the URL/.well-known) + # If there is an explicit port or the host is an IP address we bypass + # SRV lookups and just use the given host/port. + if port or _is_ip_literal(host): + return [Server(host, port or 8448)] - :type: bytes - """ + server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) - target_port = attr.ib() - """ - The port we should route the TCP connection to (the target of the SRV record, or - the port from the URL/.well-known, or 8448) + if server_list: + return server_list + + # No SRV records, so we fallback to host and 8448 + return [Server(host, 8448)] + + +def _is_ip_literal(host): + """Test if the given host name is either an IPv4 or IPv6 literal. - :type: int + Args: + host (bytes) + + Returns: + bool """ + + host = host.decode("ascii") + + try: + IPAddress(host) + return True + except AddrFormatError: + return False diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b32188766de7..3fe4ffb9e5f4 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -32,7 +32,7 @@ SERVER_CACHE = {} -@attr.s +@attr.s(slots=True, frozen=True) class Server(object): """ Our record of an individual server which can be tried to reach a destination. @@ -53,34 +53,47 @@ class Server(object): expires = attr.ib(default=0) -def pick_server_from_list(server_list): - """Randomly choose a server from the server list +def _sort_server_list(server_list): + """Given a list of SRV records sort them into priority order and shuffle + each priority with the given weight. + """ + priority_map = {} - Args: - server_list (list[Server]): list of candidate servers + for server in server_list: + priority_map.setdefault(server.priority, []).append(server) - Returns: - Tuple[bytes, int]: (host, port) pair for the chosen server - """ - if not server_list: - raise RuntimeError("pick_server_from_list called with empty list") + results = [] + for priority in sorted(priority_map): + servers = priority_map[priority] + + # This algorithms roughly follows the algorithm described in RFC2782, + # changed to remove an off-by-one error. + # + # N.B. Weights can be zero, which means that they should be picked + # rarely. + + total_weight = sum(s.weight for s in servers) + + # Total weight can become zero if there are only zero weight servers + # left, which we handle by just shuffling and appending to the results. + while servers and total_weight: + target_weight = random.randint(1, total_weight) - # TODO: currently we only use the lowest-priority servers. We should maintain a - # cache of servers known to be "down" and filter them out + for s in servers: + target_weight -= s.weight - min_priority = min(s.priority for s in server_list) - eligible_servers = list(s for s in server_list if s.priority == min_priority) - total_weight = sum(s.weight for s in eligible_servers) - target_weight = random.randint(0, total_weight) + if target_weight <= 0: + break - for s in eligible_servers: - target_weight -= s.weight + results.append(s) + servers.remove(s) + total_weight -= s.weight - if target_weight <= 0: - return s.host, s.port + if servers: + random.shuffle(servers) + results.extend(servers) - # this should be impossible. - raise RuntimeError("pick_server_from_list got to end of eligible server list.") + return results class SrvResolver(object): @@ -120,7 +133,7 @@ def resolve_service(self, service_name): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - return servers + return _sort_server_list(servers) try: answers, _, _ = yield make_deferred_yieldable( @@ -169,4 +182,4 @@ def resolve_service(self, service_name): ) self._cache[service_name] = list(servers) - return servers + return _sort_server_list(servers) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index af15f4cc5a81..b08be451aa03 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -20,7 +20,6 @@ from tests import unittest -@unittest.DEBUG class ServerACLsTestCase(unittest.TestCase): def test_blacklisted_server(self): e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index f5f46acee149..cfcd98ff7d8b 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -41,9 +41,9 @@ from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache +from tests import unittest from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ def get_connection_factory(): return test_server_connection_factory -class MatrixFederationAgentTests(TestCase): +class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -1074,8 +1074,64 @@ def test_well_known_cache_with_temp_failure(self): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_srv_fallbacks(self): + """Test that other SRV results are tried if the first one fails. + """ + + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + -class TestCachePeriodFromHeaders(TestCase): +class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self): # uppercase self.assertEqual( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 3b885ef64bde..df034ab2378d 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,8 +83,10 @@ def test_from_cache_expired_and_dns_fail(self): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 0 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @@ -105,8 +107,10 @@ def test_from_cache(self): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 999999999 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver( diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0605dac2ffd..18f1a0035d6f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -74,7 +74,6 @@ def test_filtering(self): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - @tests.unittest.DEBUG @defer.inlineCallbacks def test_erased_user(self): # 4 message events, from erased and unerased users, with a membership