diff --git a/synapse/http/site.py b/synapse/http/site.py index 7421c172e487..299fe2af9a81 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -16,6 +16,10 @@ import time from typing import Optional, Union +import attr +from zope.interface import implementer + +from twisted.internet.interfaces import IAddress from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -336,37 +340,71 @@ class XForwardedForRequest(SynapseRequest): """ Add a layer on top of another request that only uses the value of an X-Forwarded-For header as the result of C{getClientIP}. - - XXX: I think the right way to do this is with request.setHost(). """ - def __init__(self, *args, **kw): - SynapseRequest.__init__(self, *args, **kw) + # the client IP and ssl flag, as extracted from the headers. If no X-F-F header + # is found, then _client_ip is 'None'. + _client_ip = None # type: Optional[str] + _force_ssl = False # type: bool + + def requestReceived(self, command, path, version): + # this method is called by the Channel once the full request has been + # received, to dispatch the request to a resource. + # We can use it to set the IP address and port according to the + # headers. + self._process_forwarded_headers() + return super().requestReceived(command, path, version) + + def _process_forwarded_headers(self): + headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") + if not headers: + return + + # for now, we just use the first x-forwarded-for header. Really, we ought + # to start from the client IP address, and check whether it is trusted; if it + # is, work backwards through the headers until we find an untrusted address. + # see https://github.com/matrix-org/synapse/issues/9471 + self._client_ip = headers[0].split(b",")[0].strip().decode("ascii") - forwarded_header = self.getHeader(b"x-forwarded-proto") - if forwarded_header is not None: - self._is_secure = forwarded_header.lower() == b"https" + header = self.getHeader(b"x-forwarded-proto") + if header is not None: + self._force_ssl = header.lower() == b"https" else: - logger.warning( - "received request lacks an x-forwarded-proto header: assuming https" + # we have to assume http if there is no x-forwarded-proto, since that is + # a common convention. + logger.info( + "forwarded request lacks an x-forwarded-proto header: assuming http" ) - self._is_secure = True + self._force_ssl = False def isSecure(self): - return self._is_secure + if self._force_ssl: + return True + return super().isSecure() + + def getClientIP(self) -> str: + """ + Return the IP address of the client who submitted this request. - def getClientIP(self): + This method is deprecated. Use getClientAddress() instead. """ - @return: The client address (the first address) in the value of the - I{X-Forwarded-For header}. If the header is not present, return - C{b"-"}. + if self._client_ip is not None: + return self._client_ip + return super().getClientIP() + + def getClientAddress(self) -> IAddress: """ - return ( - self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] - .split(b",")[0] - .strip() - .decode("ascii") - ) + Return the address of the client who submitted this request. + """ + if self._client_ip is not None: + return _XForwardedForAddress(self._client_ip) + return super().getClientAddress() + + +@implementer(IAddress) +@attr.s(frozen=True, slots=True) +class _XForwardedForAddress: + host = attr.ib(type=str) class SynapseSite(Site):