From 9e0ae2fedbf4abe12cbe823c168286f3990496c5 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 19:27:17 -0600 Subject: [PATCH 01/25] fix --- .gitignore | 1 + synapse/http/client.py | 39 +++++++++++++-- synapse/http/endpoint.py | 35 -------------- synapse/rest/media/v1/preview_url_resource.py | 9 +++- tests/rest/media/v1/test_url_preview.py | 48 +++++++++++++++++++ 5 files changed, 92 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index 3b2252ad8a5a..bc769581d371 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,4 @@ env/ .vscode/ .ropeproject/ +.coverage.* diff --git a/synapse/http/client.py b/synapse/http/client.py index 3d05f83b8c3f..96e152051468 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -29,6 +29,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone from twisted.web.client import ( + URI, Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, @@ -59,10 +60,20 @@ class SimpleHttpClient(object): A simple, no-frills HTTP client with methods that wrap up common ways of using HTTP in Matrix """ - def __init__(self, hs): + + def __init__(self, hs, treq_args=None, whitelist=None, blacklist=None, _treq=None): + + if not _treq: + self._treq = treq + else: + self._treq = _treq + self.hs = hs pool = HTTPConnectionPool(reactor) + self._extra_treq_args = treq_args + self.whitelist = whitelist + self.blacklist = blacklist # the pusher makes lots of concurrent SSL connections to sygnal, and # tends to do so in batches, so we need to allow the pool to keep lots @@ -81,6 +92,7 @@ def __init__(self, hs): ) self.user_agent = hs.version_string self.clock = hs.get_clock() + self.reactor = hs.get_reactor() if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) @@ -88,6 +100,22 @@ def __init__(self, hs): @defer.inlineCallbacks def request(self, method, uri, data=b'', headers=None): + + # Check our IP whitelists/blacklists before making the request. + if self.blacklist: + split_uri = URI.fromBytes(uri.encode('utf8')) + address = yield self.reactor.resolve(split_uri.host) + + from netaddr import IPAddress + + ip_address = IPAddress(address) + + if ip_address in self.blacklist: + if self.whitelist is None or ip_address not in self.whitelist: + raise ConnectError( + "Refusing to spider blacklisted IP address %s" % address + ) + # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -96,8 +124,13 @@ def request(self, method, uri, data=b'', headers=None): logger.info("Sending request %s %s", method, redact_uri(uri)) try: - request_deferred = treq.request( - method, uri, agent=self.agent, data=data, headers=headers + request_deferred = self._treq.request( + method, + uri, + agent=self.agent, + data=data, + headers=headers, + **self._extra_treq_args ) request_deferred = timeout_deferred( request_deferred, 60, self.hs.get_reactor(), diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 91025037a301..f86a0b624eae 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -218,41 +218,6 @@ def update_request_time(res): return d -class SpiderEndpoint(object): - """An endpoint which refuses to connect to blacklisted IP addresses - Implements twisted.internet.interfaces.IStreamClientEndpoint. 
- """ - def __init__(self, reactor, host, port, blacklist, whitelist, - endpoint=HostnameEndpoint, endpoint_kw_args={}): - self.reactor = reactor - self.host = host - self.port = port - self.blacklist = blacklist - self.whitelist = whitelist - self.endpoint = endpoint - self.endpoint_kw_args = endpoint_kw_args - - @defer.inlineCallbacks - def connect(self, protocolFactory): - address = yield self.reactor.resolve(self.host) - - from netaddr import IPAddress - ip_address = IPAddress(address) - - if ip_address in self.blacklist: - if self.whitelist is None or ip_address not in self.whitelist: - raise ConnectError( - "Refusing to spider blacklisted IP address %s" % address - ) - - logger.info("Connecting to %s:%s", address, self.port) - endpoint = self.endpoint( - self.reactor, address, self.port, **self.endpoint_kw_args - ) - connection = yield endpoint.connect(protocolFactory) - defer.returnValue(connection) - - class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index d0ecf241b624..cd539452937b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -35,7 +35,7 @@ from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.http.client import SpiderHttpClient +from synapse.http.client import SimpleHttpClient from synapse.http.server import ( respond_with_json, respond_with_json_bytes, @@ -69,7 +69,12 @@ def __init__(self, hs, media_repo, media_storage): self.max_spider_size = hs.config.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() - self.client = SpiderHttpClient(hs) + self.client = SimpleHttpClient( + hs, + treq_args={"browser_like_redirects": True}, + whitelist=hs.config.url_preview_ip_range_whitelist, + blacklist=hs.config.url_preview_ip_range_blacklist, + ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index c62f71b44aa6..39ca7ea89c05 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet.defer import Deferred +from twisted.web.http_headers import Headers from synapse.config.repository import MediaStorageProviderConfig from synapse.util.logcontext import make_deferred_yieldable @@ -39,6 +40,7 @@ def make_homeserver(self, reactor, clock): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 + config.url_preview_ip_range_blacklist = None config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -85,6 +87,7 @@ def write_to(r): self.media_repo = hs.get_media_repository_resource() preview_url = self.media_repo.children[b'preview_url'] + self._old_client = preview_url.client preview_url.client = client self.preview_url = preview_url @@ -240,3 +243,48 @@ def test_non_ascii_preview_content_type(self): self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + + def test_ipaddr(self): + """ + IP addresses can be previewed directly. 
+ """ + # We don't want a fully mocked out client, just a mocked out Treq + treq = Mock() + d = Deferred() + treq.request = Mock(return_value=d) + self.preview_url.client = self._old_client + self.preview_url.client._treq = treq + + request, channel = self.make_request( + "GET", "url_preview?url=http://8.8.8.8", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(treq.request.call_count, 1) + + end_content = ( + b'' + b'' + b'' + b'' + ) + + # Assemble a mocked out response + def deliver(to): + to.dataReceived(end_content) + to.connectionLost(Mock()) + + res = Mock() + res.code = 200 + res.headers = Headers({b"Content-Type": [b"text/html"]}) + res.deliverBody = deliver + + # Deliver the mocked out response + d.callback(res) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) From 1474953a06b22393540a5a7f770bb44b496c9a80 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 19:27:56 -0600 Subject: [PATCH 02/25] changelog --- changelog.d/4215.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/4215.misc diff --git a/changelog.d/4215.misc b/changelog.d/4215.misc new file mode 100644 index 000000000000..bb90594836a7 --- /dev/null +++ b/changelog.d/4215.misc @@ -0,0 +1 @@ +Getting URL previews of IP addresses no longer fails on Python 3. From 31e0fd9cd78ae3c8f9599baa6095c2910ab1d807 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 19:29:36 -0600 Subject: [PATCH 03/25] fixes --- synapse/http/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 96e152051468..326d6fea873d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -43,7 +43,6 @@ from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri -from synapse.http.endpoint import SpiderEndpoint from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable From 15202526a7287c3367f0d1486cbd495283d6b441 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 19:53:03 -0600 Subject: [PATCH 04/25] fixes --- synapse/http/client.py | 60 ++---------------------------------------- 1 file changed, 2 insertions(+), 58 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 326d6fea873d..2a94304c64f0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -60,16 +60,11 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args=None, whitelist=None, blacklist=None, _treq=None): - - if not _treq: - self._treq = treq - else: - self._treq = _treq - + def __init__(self, hs, treq_args={}}, whitelist=None, blacklist=None, _treq=treq): self.hs = hs pool = HTTPConnectionPool(reactor) + self._treq = _treq self._extra_treq_args = treq_args self.whitelist = whitelist self.blacklist = blacklist @@ -495,57 +490,6 @@ def post_urlencoded_get_raw(self, url, args={}): defer.returnValue(e.response) -class SpiderEndpointFactory(object): - def __init__(self, hs): - self.blacklist = hs.config.url_preview_ip_range_blacklist - self.whitelist = hs.config.url_preview_ip_range_whitelist - self.policyForHTTPS = hs.get_http_client_context_factory() - - def endpointForURI(self, uri): - logger.info("Getting endpoint for %s", uri.toBytes()) - - if uri.scheme == 
b"http": - endpoint_factory = HostnameEndpoint - elif uri.scheme == b"https": - tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) - - def endpoint_factory(reactor, host, port, **kw): - return wrapClientTLS( - tlsCreator, - HostnameEndpoint(reactor, host, port, **kw)) - else: - logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) - return None - return SpiderEndpoint( - reactor, uri.host, uri.port, self.blacklist, self.whitelist, - endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15), - ) - - -class SpiderHttpClient(SimpleHttpClient): - """ - Separate HTTP client for spidering arbitrary URLs. - Special in that it follows retries and has a UA that looks - like a browser. - - used by the preview_url endpoint in the content repo. - """ - def __init__(self, hs): - SimpleHttpClient.__init__(self, hs) - # clobber the base class's agent and UA: - self.agent = ContentDecoderAgent( - BrowserLikeRedirectAgent( - Agent.usingEndpointFactory( - reactor, - SpiderEndpointFactory(hs) - ) - ), [(b'gzip', GzipDecoder)] - ) - # We could look like Chrome: - # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) - # Chrome Safari" % hs.version_string) - - def encode_urlencode_args(args): return {k: encode_urlencode_arg(v) for k, v in args.items()} From ec14bcad8ae565f62854baf1bd8c83e0f9227c5c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 19:55:37 -0600 Subject: [PATCH 05/25] fixes --- synapse/http/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 2a94304c64f0..cc70258e90a4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -60,7 +60,7 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}}, whitelist=None, blacklist=None, _treq=treq): + def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq): self.hs = hs pool = HTTPConnectionPool(reactor) From 7af04df6ed134456e6a69d3cff81c5539d823e20 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 20:32:33 -0600 Subject: [PATCH 06/25] fix pep8 --- synapse/http/client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index cc70258e90a4..a3b95077b0a4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -22,18 +22,15 @@ import treq from canonicaljson import encode_canonical_json, json from prometheus_client import Counter +from twisted.internet.error import ConnectError from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web._newclient import ResponseDone from twisted.web.client import ( URI, Agent, - BrowserLikeRedirectAgent, - ContentDecoderAgent, - GzipDecoder, HTTPConnectionPool, PartialDownloadError, readBody, From 606e39bc895ef7c9ced863213158df4463a07bf7 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 20 Nov 2018 20:44:48 -0600 Subject: [PATCH 07/25] fix pep8 --- synapse/http/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index a3b95077b0a4..4a53d4c273e5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -22,11 +22,11 @@ import treq from canonicaljson import encode_canonical_json, json from prometheus_client import Counter -from twisted.internet.error import ConnectError from OpenSSL import 
SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl +from twisted.internet.error import ConnectError from twisted.web._newclient import ResponseDone from twisted.web.client import ( URI, From 759169be3b59b5d06ef21530d67fb560dc09dc0c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 23 Nov 2018 04:00:59 +1100 Subject: [PATCH 08/25] add some code coverage to blacklists --- synapse/http/client.py | 6 +- synapse/rest/media/v1/preview_url_resource.py | 3 + tests/rest/media/v1/test_url_preview.py | 70 +++++++++++++++---- 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 4a53d4c273e5..f6fce195fea7 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -98,13 +98,13 @@ def request(self, method, uri, data=b'', headers=None): address = yield self.reactor.resolve(split_uri.host) from netaddr import IPAddress - ip_address = IPAddress(address) if ip_address in self.blacklist: if self.whitelist is None or ip_address not in self.whitelist: - raise ConnectError( - "Refusing to spider blacklisted IP address %s" % address + raise SynapseError( + 403, "IP address blocked by IP blacklist entry", + Codes.UNKNOWN ) # A small wrapper around self.agent.request() so we can easily attach diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index cd539452937b..21e79df71640 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -323,6 +323,9 @@ def _download_url(self, url, user): length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) + except SynapseError as e: + # Pass SynapseErrors through directly. + raise except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 39ca7ea89c05..3c88201355bc 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -17,7 +17,9 @@ from mock import Mock -from twisted.internet.defer import Deferred +from netaddr import IPSet + +from twisted.internet.defer import Deferred, succeed from twisted.web.http_headers import Headers from synapse.config.repository import MediaStorageProviderConfig @@ -40,7 +42,8 @@ def make_homeserver(self, reactor, clock): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 - config.url_preview_ip_range_blacklist = None + config.url_preview_ip_range_blacklist = IPSet(("192.168.1.1", "1.0.0.0/8")) + config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -244,17 +247,34 @@ def test_non_ascii_preview_content_type(self): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + def make_response(self, body, headers): + + # Assemble a mocked out response + def deliver(to): + to.dataReceived(body) + to.connectionLost(Mock()) + + res = Mock() + res.code = 200 + res.headers = Headers(headers) + res.deliverBody = deliver + + return res + def test_ipaddr(self): """ IP addresses can be previewed directly. 
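        (The test stubs reactor.resolve, so no real DNS lookup is made.)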
""" - # We don't want a fully mocked out client, just a mocked out Treq + # Mock out Treq to one we control treq = Mock() d = Deferred() treq.request = Mock(return_value=d) self.preview_url.client = self._old_client self.preview_url.client._treq = treq + # Hardcode the URL resolving to the IP we want + self.reactor.resolve = lambda x: succeed("8.8.8.8") + request, channel = self.make_request( "GET", "url_preview?url=http://8.8.8.8", shorthand=False ) @@ -270,17 +290,8 @@ def test_ipaddr(self): b'' ) - # Assemble a mocked out response - def deliver(to): - to.dataReceived(end_content) - to.connectionLost(Mock()) - - res = Mock() - res.code = 200 - res.headers = Headers({b"Content-Type": [b"text/html"]}) - res.deliverBody = deliver - - # Deliver the mocked out response + # Build and deliver the mocked out response. + res = self.make_response(end_content, {b"Content-Type": [b"text/html"]}) d.callback(res) self.pump() @@ -288,3 +299,34 @@ def deliver(to): self.assertEqual( channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) + + def test_blacklisted_ip_range(self): + """ + Blacklisted IP addresses are not spidered. + """ + # Mock out Treq to one we control + treq = Mock() + d = Deferred() + treq.request = Mock(return_value=d) + self.preview_url.client = self._old_client + self.preview_url.client._treq = treq + + # Hardcode the URL resolving to the IP we want + self.reactor.resolve = lambda x: succeed("192.168.1.1") + + request, channel = self.make_request( + "GET", "url_preview?url=http://192.168.1.1", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Treq is NOT called, because it will be blacklisted + self.assertEqual(treq.request.call_count, 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) From a2f7ae1b439ae62bb29b14fbecab385231c6e4fc Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 26 Nov 2018 14:52:26 +1100 Subject: [PATCH 09/25] tests --- tests/rest/media/v1/test_url_preview.py | 75 ++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 3c88201355bc..0e3e35893e83 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -300,7 +300,7 @@ def test_ipaddr(self): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_range(self): + def test_blacklisted_ip_specific(self): """ Blacklisted IP addresses are not spidered. """ @@ -330,3 +330,76 @@ def test_blacklisted_ip_range(self): 'error': 'IP address blocked by IP blacklist entry', }, ) + + def test_blacklisted_ip_range(self): + """ + Blacklisted IP ranges are not spidered. 
+ """ + # Mock out Treq to one we control + treq = Mock() + d = Deferred() + treq.request = Mock(return_value=d) + self.preview_url.client = self._old_client + self.preview_url.client._treq = treq + + # Hardcode the URL resolving to the IP we want + self.reactor.resolve = lambda x: succeed("1.1.1.2") + + request, channel = self.make_request( + "GET", "url_preview?url=http://1.1.1.2", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Treq is NOT called, because it will be blacklisted + self.assertEqual(treq.request.call_count, 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + + def test_blacklisted_ip_range_whitelisted_ip(self): + """ + Blacklisted but then subsequently whitelisted IP addresses can be + spidered. + """ + # Mock out Treq to one we control + treq = Mock() + d = Deferred() + treq.request = Mock(return_value=d) + self.preview_url.client = self._old_client + self.preview_url.client._treq = treq + + # Hardcode the URL resolving to the IP we want. This is an IP that is + # caught by a blacklist range, but is then subsequently whitelisted. + self.reactor.resolve = lambda x: succeed("1.1.1.1") + + request, channel = self.make_request( + "GET", "url_preview?url=http://1.1.1.1", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + self.assertEqual(treq.request.call_count, 1) + + end_content = ( + b'' + b'' + b'' + b'' + ) + + # Build and deliver the mocked out response. + res = self.make_response(end_content, {b"Content-Type": [b"text/html"]}) + d.callback(res) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) From ca82572deea2553296f3c867fae862b164102fc7 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 26 Nov 2018 21:43:12 +1100 Subject: [PATCH 10/25] fix pep8 --- synapse/http/client.py | 1 - synapse/rest/media/v1/preview_url_resource.py | 2 +- tests/rest/media/v1/test_url_preview.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index f6fce195fea7..2c5e945fa453 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -26,7 +26,6 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl -from twisted.internet.error import ConnectError from twisted.web._newclient import ResponseDone from twisted.web.client import ( URI, diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 21e79df71640..e9667977e8b9 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -323,7 +323,7 @@ def _download_url(self, url, user): length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size, ) - except SynapseError as e: + except SynapseError: # Pass SynapseErrors through directly. 
raise except Exception as e: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 0e3e35893e83..5ffb5785254f 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -362,7 +362,6 @@ def test_blacklisted_ip_range(self): }, ) - def test_blacklisted_ip_range_whitelisted_ip(self): """ Blacklisted but then subsequently whitelisted IP addresses can be From fc1dd4800cf273a102802f3fae7dc35d3d1a485c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 28 Nov 2018 21:26:34 +1100 Subject: [PATCH 11/25] cleanup --- .coveragerc | 1 - 1 file changed, 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index ca333961f3b4..9873a3073882 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,4 +9,3 @@ source= [report] precision = 2 -ignore_errors = True From c70d2a1af39cd3ee3baefa334601ab1063b45d7c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 1 Dec 2018 02:30:37 +1100 Subject: [PATCH 12/25] log when we block --- synapse/http/client.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 2c5e945fa453..7cb0c5c09d8e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -94,16 +94,22 @@ def request(self, method, uri, data=b'', headers=None): # Check our IP whitelists/blacklists before making the request. if self.blacklist: split_uri = URI.fromBytes(uri.encode('utf8')) - address = yield self.reactor.resolve(split_uri.host) + address = yield make_deferred_yieldable( + self.reactor.resolve(split_uri.host) + ) from netaddr import IPAddress + ip_address = IPAddress(address) if ip_address in self.blacklist: if self.whitelist is None or ip_address not in self.whitelist: + logger.info( + "Blocked accessing %s because of blacklisted IP %s" + % (split_uri.host.decode('utf8'), ip_address) + ) raise SynapseError( - 403, "IP address blocked by IP blacklist entry", - Codes.UNKNOWN + 403, "IP address blocked by IP blacklist entry", Codes.UNKNOWN ) # A small wrapper around self.agent.request() so we can easily attach From 5756f6490b677d9838d6fdd51a4db432bfbd0902 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 1 Dec 2018 02:43:55 +1100 Subject: [PATCH 13/25] docstring --- synapse/http/client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/synapse/http/client.py b/synapse/http/client.py index 7cb0c5c09d8e..9071bbb3c80a 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -90,7 +90,16 @@ def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq) @defer.inlineCallbacks def request(self, method, uri, data=b'', headers=None): + """ + Args: + method (str): HTTP method to use. + uri (str): URI to query. + data (bytes): Data to send in the request body, if applicable. + headers (t.w.http_headers.Headers): Request headers. + Raises: + SynapseError: If the IP is blacklisted. + """ # Check our IP whitelists/blacklists before making the request. 
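        # Note that self.reactor.resolve() yields a single address, so only
        # the first DNS result for the host is vetted at this point; later
        # patches in this series move the check into a name resolver that
        # sees every result.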
if self.blacklist: split_uri = URI.fromBytes(uri.encode('utf8')) From 0421f0b0b0a444adcedb8eefb76800be6c70e96a Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 1 Dec 2018 02:44:35 +1100 Subject: [PATCH 14/25] docstring --- synapse/http/client.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/synapse/http/client.py b/synapse/http/client.py index 9071bbb3c80a..4d6234fb7296 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -57,6 +57,16 @@ class SimpleHttpClient(object): """ def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq): + """ + Args: + hs (synapse.server.HomeServer) + treq_args (dict): Extra keyword arguments to be given to treq.request. + blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + we may not request. + whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + request if it were otherwise caught in a blacklist. + _treq (treq): Treq implementation, can be overridden for testing. + """ self.hs = hs pool = HTTPConnectionPool(reactor) From e59a5eebda76f12f2fb003d0f6f68223184f5616 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 1 Dec 2018 02:47:45 +1100 Subject: [PATCH 15/25] comment --- synapse/rest/media/v1/preview_url_resource.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e9667977e8b9..6bdd1d3442f6 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -324,7 +324,10 @@ def _download_url(self, url, user): url, output_stream=f, max_size=self.max_spider_size, ) except SynapseError: - # Pass SynapseErrors through directly. + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. Currently, this is only if the IP we are + # trying to fetch from is blacklisted. 
raise except Exception as e: # FIXME: pass through 404s and other error messages nicely From d82e498216d967f7174db75d1a004e6378baf425 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 5 Dec 2018 23:15:07 +1100 Subject: [PATCH 16/25] fix --- synapse/http/client.py | 170 +++++++++---- tests/rest/media/v1/test_url_preview.py | 312 ++++++++++-------------- 2 files changed, 260 insertions(+), 222 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 4d6234fb7296..fb7a5a007645 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -23,7 +23,11 @@ from canonicaljson import encode_canonical_json, json from prometheus_client import Counter +from netaddr import IPAddress + +from hyperlink import URL from OpenSSL import SSL +from twisted.python.failure import Failure from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, protocol, reactor, ssl from twisted.web._newclient import ResponseDone @@ -43,12 +47,95 @@ from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable +from twisted.internet.address import IPv4Address, IPv6Address + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) incoming_responses_counter = Counter("synapse_http_client_responses", "", ["method", "code"]) +def check_against_blacklist(ip_address, whitelist, blacklist): + if ip_address in blacklist: + if whitelist is None or ip_address not in whitelist: + return True + return False + + +class IPBlacklistingResolver(object): + def __init__(self, reactor, whitelist, blacklist): + self._reactor = reactor + self._whitelist = whitelist + self._blacklist = blacklist + + def resolveHostName(recv, hostname, portNumber=0): + + r = recv() + d = Deferred() + + @provider(IResolutionReceiver) + class EndpointReceiver(object): + @staticmethod + def resolutionBegan(resolutionInProgress): + pass + @staticmethod + def addressResolved(address): + print(repr(address)) + ip_address = IPAddress(address.host) + + if check_against_blacklist(ip_address, self.whitelist, self._blacklist): + logger.info( + "Dropped %s from DNS resolution to %s" + % (ip_address, hostname) + ) + raise SynapseError(403, "IP address blocked by IP blacklist entry") + + addresses.append(address) + @staticmethod + def resolutionComplete(): + d.callback(addresses) + + self._reactor.nameResolver.resolveHostName( + EndpointReceiver, hostname, portNumber=portNumber + ) + + def _callback(addrs): + r.resolutionBegan(None) + for i in addrs: + r.addressResolved(i) + r.resolutionComplete() + + d.addCallback(_callback) + + return r + + +class BlacklistingAgentWrapper(Agent): + + def __init__(self, agent, reactor, *args, whitelist=None, blacklist=None, **kwargs): + self._agent = agent + self._whitelist = whitelist + self._blacklist = blacklist + + # Put in our own blacklisting resolver. 
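        # (This assigns to a private attribute of the wrapped Agent; a later
        # patch in this series replaces it with a pluggable reactor-level
        # name resolver instead.)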
+ agent._nameResolver = IPBlacklistingResolver(reactor, whitelist, blacklist) + + def request(self, method, uri, headers=None, bodyProducer=None): + h = URL.from_text(uri.decode('ascii')) + + try: + ip_address = IPAddress(h.host) + + if check_against_blacklist(ip_address, self._whitelist, self._blacklist): + logger.info("Blocking access to %s because of blacklist" % (ip_address,)) + e = SynapseError(403, "IP address blocked by IP blacklist entry") + return defer.fail(Failure(e)) + except: + # Not an IP + pass + + return self._agent.request(method, uri, headers=headers, bodyProducer=bodyProducer) + class SimpleHttpClient(object): """ @@ -56,7 +143,7 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq): + def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None): """ Args: hs (synapse.server.HomeServer) @@ -65,31 +152,13 @@ def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq) we may not request. whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - _treq (treq): Treq implementation, can be overridden for testing. """ self.hs = hs - pool = HTTPConnectionPool(reactor) - self._treq = _treq + self._whitelist = whitelist + self._blacklist = blacklist self._extra_treq_args = treq_args - self.whitelist = whitelist - self.blacklist = blacklist - - # the pusher makes lots of concurrent SSL connections to sygnal, and - # tends to do so in batches, so we need to allow the pool to keep lots - # of idle connections around. - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) - pool.cachedConnectionTimeout = 2 * 60 - - # The default context factory in Twisted 14.0.0 (which we require) is - # BrowserLikePolicyForHTTPS which will do regular cert validation - # 'like a browser' - self.agent = Agent( - reactor, - connectTimeout=15, - contextFactory=hs.get_http_client_context_factory(), - pool=pool, - ) + self.user_agent = hs.version_string self.clock = hs.get_clock() self.reactor = hs.get_reactor() @@ -97,6 +166,38 @@ def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None, _treq=treq) self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) self.user_agent = self.user_agent.encode('ascii') + self._make_agent() + + def _make_agent(self, _agent=False): + + if _agent: + self.agent = _agent + else: + # the pusher makes lots of concurrent SSL connections to sygnal, and + # tends to do so in batches, so we need to allow the pool to keep + # lots of idle connections around. + pool = HTTPConnectionPool(self.reactor) + pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + pool.cachedConnectionTimeout = 2 * 60 + + # The default context factory in Twisted 14.0.0 (which we require) is + # BrowserLikePolicyForHTTPS which will do regular cert validation + # 'like a browser' + self.agent = Agent( + reactor, + connectTimeout=15, + contextFactory=self.hs.get_http_client_context_factory(), + pool=pool, + ) + + # If we have an IP blacklist, use the blacklisting Agent wrapper. + if self._blacklist: + self.agent = BlacklistingAgentWrapper( + self.agent, + reactor, + whitelist=self._whitelist, + blacklist=self._blacklist, + ) @defer.inlineCallbacks def request(self, method, uri, data=b'', headers=None): @@ -110,27 +211,6 @@ def request(self, method, uri, data=b'', headers=None): Raises: SynapseError: If the IP is blacklisted. 
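        Returns:
            Deferred[twisted.web.iweb.IResponse]: the HTTP response to the
                request, as returned by treq.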
""" - # Check our IP whitelists/blacklists before making the request. - if self.blacklist: - split_uri = URI.fromBytes(uri.encode('utf8')) - address = yield make_deferred_yieldable( - self.reactor.resolve(split_uri.host) - ) - - from netaddr import IPAddress - - ip_address = IPAddress(address) - - if ip_address in self.blacklist: - if self.whitelist is None or ip_address not in self.whitelist: - logger.info( - "Blocked accessing %s because of blacklisted IP %s" - % (split_uri.host.decode('utf8'), ip_address) - ) - raise SynapseError( - 403, "IP address blocked by IP blacklist entry", Codes.UNKNOWN - ) - # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.labels(method).inc() @@ -139,7 +219,7 @@ def request(self, method, uri, data=b'', headers=None): logger.info("Sending request %s %s", method, redact_uri(uri)) try: - request_deferred = self._treq.request( + request_deferred = treq.request( method, uri, agent=self.agent, @@ -397,7 +477,7 @@ def get_file(self, url, output_stream, max_size=None, headers=None): resp_headers = dict(response.headers.getAllRawHeaders()) if (b'Content-Length' in resp_headers and - int(resp_headers[b'Content-Length']) > max_size): + int(resp_headers[b'Content-Length'][0]) > max_size): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 5ffb5785254f..7823bd0d374b 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -18,9 +18,13 @@ from mock import Mock from netaddr import IPSet +import attr +from twisted.web._newclient import ResponseDone from twisted.internet.defer import Deferred, succeed from twisted.web.http_headers import Headers +from twisted.web.client import Response +from twisted.python.failure import Failure from synapse.config.repository import MediaStorageProviderConfig from synapse.util.logcontext import make_deferred_yieldable @@ -29,10 +33,38 @@ from tests import unittest +@attr.s +class FakeResponse(object): + version = attr.ib() + code = attr.ib() + phrase = attr.ib() + headers = attr.ib() + body = attr.ib() + absoluteURI = attr.ib() + + @property + def request(self): + @attr.s + class FakeTransport(object): + absoluteURI = self.absoluteURI + + return FakeTransport() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" + end_content = ( + b'' + b'' + b'' + b'' + ) def make_homeserver(self, reactor, clock): @@ -67,72 +99,46 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): - self.fetches = [] + self.media_repo = hs.get_media_repository_resource() + self.preview_url = self.media_repo.children[b'preview_url'] - def get_file(url, output_stream, max_size): - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. 
- """ + class Agent(object): + def request(_self, *args, **kwargs): + return self._on_request(*args, **kwargs) - def write_to(r): - data, response = r - output_stream.write(data) - return response + # Load in the Agent we want + self.preview_url.client._make_agent(Agent()) - d = Deferred() - d.addCallback(write_to) - self.fetches.append((d, url)) - return make_deferred_yieldable(d) + def test_cache_returns_correct_type(self): - client = Mock() - client.get_file = get_file + calls = [0] - self.media_repo = hs.get_media_repository_resource() - preview_url = self.media_repo.children[b'preview_url'] - self._old_client = preview_url.client - preview_url.client = client - self.preview_url = preview_url + def _on_request(method, uri, headers=None, bodyProducer=None): - def test_cache_returns_correct_type(self): + calls[0] += 1 + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html; charset="utf8"'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) + + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=matrix.org", shorthand=False ) request.render(self.preview_url) - self.pump() - - # We've made one fetch - self.assertEqual(len(self.fetches), 1) - - end_content = ( - b'' - b'' - b'' - b'' - ) - - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), - ) - ) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) + self.assertEqual(calls[0], 1) + # Check the cache returns the correct response request, channel = self.make_request( "GET", "url_preview?url=matrix.org", shorthand=False @@ -141,7 +147,7 @@ def test_cache_returns_correct_type(self): self.pump() # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) + self.assertEqual(calls[0], 1) # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -162,7 +168,7 @@ def test_cache_returns_correct_type(self): self.pump() # Only one fetch, still, since we'll lean on the cache - self.assertEqual(len(self.fetches), 1) + self.assertEqual(calls[0], 1) # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -172,15 +178,6 @@ def test_cache_returns_correct_type(self): def test_non_ascii_preview_httpequiv(self): - request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False - ) - request.render(self.preview_url) - self.pump() - - # We've made one fetch - self.assertEqual(len(self.fetches), 1) - end_content = ( b'' b'' @@ -189,37 +186,31 @@ def test_non_ascii_preview_httpequiv(self): b'' ) - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - # This charset=utf-8 should be ignored, because the - # document has a meta tag overriding it. 
- b"Content-Type": [b'text/html; charset="utf8"'], - }, - "https://example.com", - 200, - ), - ) - ) + def _on_request(method, uri, headers=None, bodyProducer=None): - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + h = Headers( + { + b"Content-Length": [b"%d" % (len(end_content))], + # This charset=utf-8 should be ignored, because the + # document has a meta tag overriding it. + b"Content-Type": [b'text/html; charset="utf8"'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, end_content, uri) + return succeed(resp) - def test_non_ascii_preview_content_type(self): + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=matrix.org", shorthand=False ) request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") - # We've made one fetch - self.assertEqual(len(self.fetches), 1) + def test_non_ascii_preview_content_type(self): end_content = ( b'' @@ -228,72 +219,49 @@ def test_non_ascii_preview_content_type(self): b'' ) - self.fetches[0][0].callback( - ( - end_content, - ( - len(end_content), - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="windows-1251"'], - }, - "https://example.com", - 200, - ), + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(end_content))], + b"Content-Type": [b'text/html; charset="windows-1251"'], + } ) - ) + resp = FakeResponse(b"1.1", 200, b"OK", h, end_content, uri) + return succeed(resp) + + self._on_request = _on_request + request, channel = self.make_request( + "GET", "url_preview?url=matrix.org", shorthand=False + ) + request.render(self.preview_url) self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") - def make_response(self, body, headers): - - # Assemble a mocked out response - def deliver(to): - to.dataReceived(body) - to.connectionLost(Mock()) - - res = Mock() - res.code = 200 - res.headers = Headers(headers) - res.deliverBody = deliver - - return res - def test_ipaddr(self): """ IP addresses can be previewed directly. """ - # Mock out Treq to one we control - treq = Mock() - d = Deferred() - treq.request = Mock(return_value=d) - self.preview_url.client = self._old_client - self.preview_url.client._treq = treq - # Hardcode the URL resolving to the IP we want - self.reactor.resolve = lambda x: succeed("8.8.8.8") + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) + + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=http://8.8.8.8", shorthand=False ) request.render(self.preview_url) - self.pump() - - self.assertEqual(treq.request.call_count, 1) - - end_content = ( - b'' - b'' - b'' - b'' - ) - - # Build and deliver the mocked out response. - res = self.make_response(end_content, {b"Content-Type": [b"text/html"]}) - d.callback(res) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual( @@ -304,24 +272,25 @@ def test_blacklisted_ip_specific(self): """ Blacklisted IP addresses are not spidered. 
""" - # Mock out Treq to one we control - treq = Mock() - d = Deferred() - treq.request = Mock(return_value=d) - self.preview_url.client = self._old_client - self.preview_url.client._treq = treq - # Hardcode the URL resolving to the IP we want - self.reactor.resolve = lambda x: succeed("192.168.1.1") + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) + + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=http://192.168.1.1", shorthand=False ) request.render(self.preview_url) self.pump() - - # Treq is NOT called, because it will be blacklisted - self.assertEqual(treq.request.call_count, 0) self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, @@ -335,15 +304,19 @@ def test_blacklisted_ip_range(self): """ Blacklisted IP ranges are not spidered. """ - # Mock out Treq to one we control - treq = Mock() - d = Deferred() - treq.request = Mock(return_value=d) - self.preview_url.client = self._old_client - self.preview_url.client._treq = treq - # Hardcode the URL resolving to the IP we want - self.reactor.resolve = lambda x: succeed("1.1.1.2") + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) + + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=http://1.1.1.2", shorthand=False @@ -351,8 +324,6 @@ def test_blacklisted_ip_range(self): request.render(self.preview_url) self.pump() - # Treq is NOT called, because it will be blacklisted - self.assertEqual(treq.request.call_count, 0) self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, @@ -367,36 +338,23 @@ def test_blacklisted_ip_range_whitelisted_ip(self): Blacklisted but then subsequently whitelisted IP addresses can be spidered. """ - # Mock out Treq to one we control - treq = Mock() - d = Deferred() - treq.request = Mock(return_value=d) - self.preview_url.client = self._old_client - self.preview_url.client._treq = treq + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) - # Hardcode the URL resolving to the IP we want. This is an IP that is - # caught by a blacklist range, but is then subsequently whitelisted. - self.reactor.resolve = lambda x: succeed("1.1.1.1") + self._on_request = _on_request request, channel = self.make_request( "GET", "url_preview?url=http://1.1.1.1", shorthand=False ) request.render(self.preview_url) - self.pump() - - self.assertEqual(treq.request.call_count, 1) - - end_content = ( - b'' - b'' - b'' - b'' - ) - - # Build and deliver the mocked out response. 
- res = self.make_response(end_content, {b"Content-Type": [b"text/html"]}) - d.callback(res) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual( From 1210131e9eb5189c2902ac9715c01f2913375bc4 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 5 Dec 2018 23:18:29 +1100 Subject: [PATCH 17/25] merge in tests --- tests/rest/media/v1/test_url_preview.py | 75 +++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 7823bd0d374b..a49279c0e977 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -22,6 +22,10 @@ from twisted.web._newclient import ResponseDone from twisted.internet.defer import Deferred, succeed +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address +from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError from twisted.web.http_headers import Headers from twisted.web.client import Response from twisted.python.failure import Failure @@ -109,6 +113,32 @@ def request(_self, *args, **kwargs): # Load in the Agent we want self.preview_url.client._make_agent(Agent()) + self.lookups = {} + + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): + + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") + + for i in self.lookups[hostName]: + resolutionReceiver.addressResolved( + i[0]('TCP', i[1], portNumber) + ) + resolutionReceiver.resolutionComplete() + return resolutionReceiver + + self.reactor.nameResolver = Resolver() + def test_cache_returns_correct_type(self): calls = [0] @@ -127,6 +157,8 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self._on_request = _on_request + def test_cache_returns_correct_type(self): + request, channel = self.make_request( "GET", "url_preview?url=matrix.org", shorthand=False ) @@ -259,7 +291,7 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self._on_request = _on_request request, channel = self.make_request( - "GET", "url_preview?url=http://8.8.8.8", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() @@ -287,7 +319,7 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self._on_request = _on_request request, channel = self.make_request( - "GET", "url_preview?url=http://192.168.1.1", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() @@ -319,7 +351,7 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self._on_request = _on_request request, channel = self.make_request( - "GET", "url_preview?url=http://1.1.1.2", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() @@ -352,7 +384,7 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self._on_request = _on_request request, channel = self.make_request( - "GET", "url_preview?url=http://1.1.1.1", shorthand=False + "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() @@ -360,3 +392,38 @@ def _on_request(method, uri, headers=None, bodyProducer=None): self.assertEqual( channel.json_body, {"og:title": 
"~matrix~", "og:description": "hi"} ) + + def test_blacklisted_ip_with_external_ip(self): + """ + If a hostname resolves a blacklisted IP, even if there's a + non-blacklisted one, it will be rejected. + """ + def _on_request(method, uri, headers=None, bodyProducer=None): + + h = Headers( + { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b'text/html'], + } + ) + resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) + return succeed(resp) + + self._on_request = _on_request + + # Hardcode the URL resolving to the IP we want. + self.lookups[u"example.com"] = ["1.1.1.2", "8.8.8.8"] + + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) From 6fade491ea591ef1f8fb9e43c3bcabf5175bace6 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 6 Dec 2018 01:11:07 +1100 Subject: [PATCH 18/25] fix --- synapse/http/client.py | 188 +++++++------- tests/rest/media/v1/test_url_preview.py | 324 +++++++++++++----------- tests/server.py | 8 + 3 files changed, 285 insertions(+), 235 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index fb7a5a007645..3505125a9936 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -21,23 +21,21 @@ import treq from canonicaljson import encode_canonical_json, json -from prometheus_client import Counter - +from hyperlink import URL from netaddr import IPAddress +from prometheus_client import Counter +from zope.interface import implementer, provider -from hyperlink import URL from OpenSSL import SSL -from twisted.python.failure import Failure from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, reactor, ssl -from twisted.web._newclient import ResponseDone -from twisted.web.client import ( - URI, - Agent, - HTTPConnectionPool, - PartialDownloadError, - readBody, +from twisted.internet import defer, protocol, ssl +from twisted.internet.interfaces import ( + IReactorPluggableNameResolver, + IResolutionReceiver, ) +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone +from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers @@ -47,13 +45,13 @@ from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable -from twisted.internet.address import IPv4Address, IPv6Address - logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) -incoming_responses_counter = Counter("synapse_http_client_responses", "", - ["method", "code"]) +incoming_responses_counter = Counter( + "synapse_http_client_responses", "", ["method", "code"] +) + def check_against_blacklist(ip_address, whitelist, blacklist): if ip_address in blacklist: @@ -68,29 +66,32 @@ def __init__(self, reactor, whitelist, blacklist): self._whitelist = whitelist self._blacklist = blacklist - def resolveHostName(recv, hostname, portNumber=0): + def resolveHostName(self, recv, hostname, portNumber=0): r = recv() - d = Deferred() + d = defer.Deferred() + addresses = [] @provider(IResolutionReceiver) class EndpointReceiver(object): @staticmethod def resolutionBegan(resolutionInProgress): pass + 
@staticmethod def addressResolved(address): - print(repr(address)) ip_address = IPAddress(address.host) - if check_against_blacklist(ip_address, self.whitelist, self._blacklist): + if check_against_blacklist( + ip_address, self._whitelist, self._blacklist + ): logger.info( - "Dropped %s from DNS resolution to %s" - % (ip_address, hostname) + "Dropped %s from DNS resolution to %s" % (ip_address, hostname) ) raise SynapseError(403, "IP address blocked by IP blacklist entry") addresses.append(address) + @staticmethod def resolutionComplete(): d.callback(addresses) @@ -111,15 +112,11 @@ def _callback(addrs): class BlacklistingAgentWrapper(Agent): - def __init__(self, agent, reactor, *args, whitelist=None, blacklist=None, **kwargs): self._agent = agent self._whitelist = whitelist self._blacklist = blacklist - # Put in our own blacklisting resolver. - agent._nameResolver = IPBlacklistingResolver(reactor, whitelist, blacklist) - def request(self, method, uri, headers=None, bodyProducer=None): h = URL.from_text(uri.decode('ascii')) @@ -127,14 +124,18 @@ def request(self, method, uri, headers=None, bodyProducer=None): ip_address = IPAddress(h.host) if check_against_blacklist(ip_address, self._whitelist, self._blacklist): - logger.info("Blocking access to %s because of blacklist" % (ip_address,)) + logger.info( + "Blocking access to %s because of blacklist" % (ip_address,) + ) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) - except: + except Exception: # Not an IP pass - return self._agent.request(method, uri, headers=headers, bodyProducer=bodyProducer) + return self._agent.request( + method, uri, headers=headers, bodyProducer=bodyProducer + ) class SimpleHttpClient(object): @@ -163,7 +164,7 @@ def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None): self.clock = hs.get_clock() self.reactor = hs.get_reactor() if hs.config.user_agent_suffix: - self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) self.user_agent = self.user_agent.encode('ascii') self._make_agent() @@ -184,7 +185,7 @@ def _make_agent(self, _agent=False): # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent( - reactor, + self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, @@ -192,9 +193,29 @@ def _make_agent(self, _agent=False): # If we have an IP blacklist, use the blacklisting Agent wrapper. 
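        # This gives two layers of protection: the resolver below drops
        # hostnames that resolve to blacklisted addresses, and the wrapper
        # also rejects URIs whose host is already a literal blacklisted IP.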
if self._blacklist: + + nameResolver = IPBlacklistingResolver( + self.reactor, self._whitelist, self._blacklist + ) + + @implementer(IReactorPluggableNameResolver) + class Reactor(object): + def __getattr__(_self, attr): + if attr == "nameResolver": + return nameResolver + else: + return getattr(self.reactor, attr) + + self.agent = Agent( + Reactor(), + connectTimeout=15, + contextFactory=self.hs.get_http_client_context_factory(), + pool=pool, + ) + self.agent = BlacklistingAgentWrapper( self.agent, - reactor, + self.reactor, whitelist=self._whitelist, blacklist=self._blacklist, ) @@ -228,22 +249,29 @@ def request(self, method, uri, data=b'', headers=None): **self._extra_treq_args ) request_deferred = timeout_deferred( - request_deferred, 60, self.hs.get_reactor(), + request_deferred, + 60, + self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) response = yield make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() logger.info( - "Received response to %s %s: %s", - method, redact_uri(uri), response.code + "Received response to %s %s: %s", method, redact_uri(uri), response.code ) defer.returnValue(response) except Exception as e: + print(e) + import traceback + traceback.print_exc() incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", - method, redact_uri(uri), type(e).__name__, e.args[0] + method, + redact_uri(uri), + type(e).__name__, + e.args[0], ) raise @@ -268,8 +296,9 @@ def post_urlencoded_get_json(self, uri, args={}, headers=None): # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode( - encode_urlencode_args(args), True).encode("utf8") + query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( + "utf8" + ) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -279,10 +308,7 @@ def post_urlencoded_get_json(self, uri, args={}, headers=None): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=query_bytes + "POST", uri, headers=Headers(actual_headers), data=query_bytes ) if 200 <= response.code < 300: @@ -321,10 +347,7 @@ def post_json_get_json(self, uri, post_json, headers=None): actual_headers.update(headers) response = yield self.request( - "POST", - uri, - headers=Headers(actual_headers), - data=json_str + "POST", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -392,10 +415,7 @@ def put_json(self, uri, json_body, args={}, headers=None): actual_headers.update(headers) response = yield self.request( - "PUT", - uri, - headers=Headers(actual_headers), - data=json_str + "PUT", uri, headers=Headers(actual_headers), data=json_str ) body = yield make_deferred_yieldable(readBody(response)) @@ -427,17 +447,11 @@ def get_raw(self, uri, args={}, headers=None): query_bytes = urllib.parse.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - uri, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", uri, headers=Headers(actual_headers)) body = yield make_deferred_yieldable(readBody(response)) @@ -462,22 +476,18 @@ def get_file(self, url, output_stream, max_size=None, headers=None): 
headers, absolute URI of the response and HTTP response code. """ - actual_headers = { - b"User-Agent": [self.user_agent], - } + actual_headers = {b"User-Agent": [self.user_agent]} if headers: actual_headers.update(headers) - response = yield self.request( - "GET", - url, - headers=Headers(actual_headers), - ) + response = yield self.request("GET", url, headers=Headers(actual_headers)) resp_headers = dict(response.headers.getAllRawHeaders()) - if (b'Content-Length' in resp_headers and - int(resp_headers[b'Content-Length'][0]) > max_size): + if ( + b'Content-Length' in resp_headers + and int(resp_headers[b'Content-Length'][0]) > max_size + ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, @@ -487,26 +497,20 @@ def get_file(self, url, output_stream, max_size=None, headers=None): if response.code > 299: logger.warn("Got %d when downloading %s" % (response.code, url)) - raise SynapseError( - 502, - "Got error %d" % (response.code,), - Codes.UNKNOWN, - ) + raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it # straight back in again try: - length = yield make_deferred_yieldable(_readBodyToFile( - response, output_stream, max_size, - )) + length = yield make_deferred_yieldable( + _readBodyToFile(response, output_stream, max_size) + ) except Exception as e: logger.exception("Failed to download body") raise SynapseError( - 502, - ("Failed to download remote body: %s" % e), - Codes.UNKNOWN, + 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN ) defer.returnValue( @@ -515,13 +519,14 @@ def get_file(self, url, output_stream, max_size=None, headers=None): resp_headers, response.request.absoluteURI.decode('ascii'), response.code, - ), + ) ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. + class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): self.stream = stream @@ -533,11 +538,13 @@ def dataReceived(self, data): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -555,6 +562,7 @@ def connectionLost(self, reason): # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. 
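Two details in the download path are easy to miss. First, get_file enforces max_size twice on purpose: the Content-Length header is advisory and may be absent or wrong, so _ReadBodyToFileProtocol re-counts bytes as they arrive and aborts the connection the moment the cap is crossed (note the warn line in the hunk formats self.max_size, an attribute nothing in this class sets, where only the max_size argument is in scope). Second, swapping in a fresh Deferred before loseConnection keeps the subsequent connectionLost from errbacking a deferred that has already fired. A condensed sketch of the streaming cap, assuming only Twisted:

    from twisted.internet import defer, protocol

    class SizeCappedWriter(protocol.Protocol):
        """Writes a response body to a stream, aborting past max_size."""

        def __init__(self, stream, deferred, max_size):
            self.stream = stream
            self.deferred = deferred
            self.length = 0
            self.max_size = max_size

        def dataReceived(self, data):
            self.stream.write(data)
            self.length += len(data)
            if self.max_size is not None and self.length >= self.max_size:
                self.deferred.errback(ValueError(
                    "Requested file is too large > %r bytes" % (self.max_size,)
                ))
                # Park a throwaway Deferred so connectionLost has somewhere
                # harmless to report the abort we are about to trigger.
                self.deferred = defer.Deferred()
                self.transport.loseConnection()
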
+ def _readBodyToFile(response, stream, max_size): d = defer.Deferred() response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) @@ -577,10 +585,12 @@ def post_urlencoded_get_raw(self, url, args={}): "POST", url, data=query_bytes, - headers=Headers({ - b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [self.user_agent], - }) + headers=Headers( + { + b"Content-Type": [b"application/x-www-form-urlencoded"], + b"User-Agent": [self.user_agent], + } + ), ) try: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index a49279c0e977..5693c224cceb 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -15,26 +15,21 @@ import os -from mock import Mock - -from netaddr import IPSet import attr +from netaddr import IPSet -from twisted.web._newclient import ResponseDone -from twisted.internet.defer import Deferred, succeed from twisted.internet._resolver import HostResolution -from twisted.internet.address import IPv4Address -from twisted.internet.defer import Deferred +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.web.http_headers import Headers -from twisted.web.client import Response from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol +from twisted.web._newclient import ResponseDone from synapse.config.repository import MediaStorageProviderConfig -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.module_loader import load_module from tests import unittest +from tests.server import FakeTransport @attr.s @@ -78,7 +73,7 @@ def make_homeserver(self, reactor, clock): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 - config.url_preview_ip_range_blacklist = IPSet(("192.168.1.1", "1.0.0.0/8")) + config.url_preview_ip_range_blacklist = IPSet(("192.168.1.1", "1.0.0.0/8", "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "2001:800::/21")) config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -106,13 +101,6 @@ def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b'preview_url'] - class Agent(object): - def request(_self, *args, **kwargs): - return self._on_request(*args, **kwargs) - - # Load in the Agent we want - self.preview_url.client._make_agent(Agent()) - self.lookups = {} class Resolver(object): @@ -131,56 +119,44 @@ def resolveHostName( raise DNSLookupError("OH NO") for i in self.lookups[hostName]: - resolutionReceiver.addressResolved( - i[0]('TCP', i[1], portNumber) - ) + resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) resolutionReceiver.resolutionComplete() return resolutionReceiver self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): - - calls = [0] - - def _on_request(method, uri, headers=None, bodyProducer=None): - - calls[0] += 1 - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html; charset="utf8"'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) - - self._on_request = _on_request - - def test_cache_returns_correct_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( - "GET", 
"url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + self.pump() self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - self.assertEqual(calls[0], 1) - # Check the cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(calls[0], 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -188,20 +164,17 @@ def test_cache_returns_correct_type(self): ) # Clear the in-memory cache - self.assertIn("matrix.org", self.preview_url._cache) - self.preview_url._cache.pop("matrix.org") - self.assertNotIn("matrix.org", self.preview_url._cache) + self.assertIn("http://matrix.org", self.preview_url._cache) + self.preview_url._cache.pop("http://matrix.org") + self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) self.pump() - # Only one fetch, still, since we'll lean on the cache - self.assertEqual(calls[0], 1) - # Check the cache response has the same content self.assertEqual(channel.code, 200) self.assertEqual( @@ -209,6 +182,7 @@ def test_cache_returns_correct_type(self): ) def test_non_ascii_preview_httpequiv(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( b'' @@ -218,31 +192,31 @@ def test_non_ascii_preview_httpequiv(self): b'' ) - def _on_request(method, uri, headers=None, bodyProducer=None): - - h = Headers( - { - b"Content-Length": [b"%d" % (len(end_content))], - # This charset=utf-8 should be ignored, because the - # document has a meta tag overriding it. 
- b"Content-Type": [b'text/html; charset="utf8"'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, end_content, uri) - return succeed(resp) - - self._on_request = _on_request - request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + ) + % (len(end_content),) + + end_content + ) self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( b'' @@ -251,23 +225,25 @@ def test_non_ascii_preview_content_type(self): b'' ) - def _on_request(method, uri, headers=None, bodyProducer=None): - - h = Headers( - { - b"Content-Length": [b"%d" % (len(end_content))], - b"Content-Type": [b'text/html; charset="windows-1251"'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, end_content, uri) - return succeed(resp) - - self._on_request = _on_request - request, channel = self.make_request( - "GET", "url_preview?url=matrix.org", shorthand=False + "GET", "url_preview?url=http://matrix.org", shorthand=False ) request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" + ) + % (len(end_content),) + + end_content + ) + self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") @@ -276,24 +252,24 @@ def test_ipaddr(self): """ IP addresses can be previewed directly. """ - - def _on_request(method, uri, headers=None, bodyProducer=None): - - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) - - self._on_request = _on_request + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + self.pump() self.assertEqual(channel.code, 200) self.assertEqual( @@ -302,27 +278,18 @@ def _on_request(method, uri, headers=None, bodyProducer=None): def test_blacklisted_ip_specific(self): """ - Blacklisted IP addresses are not spidered. + Blacklisted IP addresses, found via DNS, are not spidered. 
""" - - def _on_request(method, uri, headers=None, bodyProducer=None): - - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) - - self._on_request = _on_request + self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, @@ -334,24 +301,52 @@ def _on_request(method, uri, headers=None, bodyProducer=None): def test_blacklisted_ip_range(self): """ - Blacklisted IP ranges are not spidered. + Blacklisted IP ranges, IPs found over DNS, are not spidered. """ + self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] - def _on_request(method, uri, headers=None, bodyProducer=None): + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) - self._on_request = _on_request + def test_blacklisted_ip_specific_direct(self): + """ + Blacklisted IP addresses, accessed directly, are not spidered. + """ + request, channel = self.make_request( + "GET", "url_preview?url=http://192.168.1.1", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + def test_blacklisted_ip_range_direct(self): + """ + Blacklisted IP ranges, accessed directly, are not spidered. + """ request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "url_preview?url=http://1.1.1.2", shorthand=False ) request.render(self.preview_url) self.pump() @@ -370,23 +365,26 @@ def test_blacklisted_ip_range_whitelisted_ip(self): Blacklisted but then subsequently whitelisted IP addresses can be spidered. 
""" - def _on_request(method, uri, headers=None, bodyProducer=None): - - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) - - self._on_request = _on_request + self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + self.pump() self.assertEqual(channel.code, 200) self.assertEqual( @@ -398,27 +396,61 @@ def test_blacklisted_ip_with_external_ip(self): If a hostname resolves a blacklisted IP, even if there's a non-blacklisted one, it will be rejected. """ - def _on_request(method, uri, headers=None, bodyProducer=None): + # Hardcode the URL resolving to the IP we want. + self.lookups[u"example.com"] = [ + (IPv4Address, "1.1.1.2"), + (IPv4Address, "8.8.8.8"), + ] - h = Headers( - { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'text/html'], - } - ) - resp = FakeResponse(b"1.1", 200, b"OK", h, self.end_content, uri) - return succeed(resp) + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) - self._on_request = _on_request + def test_blacklisted_ipv6_specific(self): + """ + Blacklisted IP addresses, found via DNS, are not spidered. + """ + self.lookups["example.com"] = [(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")] - # Hardcode the URL resolving to the IP we want. - self.lookups[u"example.com"] = ["1.1.1.2", "8.8.8.8"] + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # No requests made. + self.assertEqual(len(self.reactor.tcpClients), 0) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body, + { + 'errcode': 'M_UNKNOWN', + 'error': 'IP address blocked by IP blacklist entry', + }, + ) + + def test_blacklisted_ipv6_range(self): + """ + Blacklisted IP ranges, IPs found over DNS, are not spidered. 
+ """ + self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False ) request.render(self.preview_url) self.pump() + self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, diff --git a/tests/server.py b/tests/server.py index ceec2f2d4e92..db43fa0db8df 100644 --- a/tests/server.py +++ b/tests/server.py @@ -383,8 +383,16 @@ def abortConnection(self): self.disconnecting = True def pauseProducing(self): + if not self.producer: + return + self.producer.pauseProducing() + def resumeProducing(self): + if not self.producer: + return + self.producer.resumeProducing() + def unregisterProducer(self): if not self.producer: return From 168a4940fadd54ead831beb3ded5032271a6d211 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 6 Dec 2018 01:11:38 +1100 Subject: [PATCH 19/25] fix --- synapse/http/client.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3505125a9936..5612ee7b82c0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -262,9 +262,6 @@ def request(self, method, uri, data=b'', headers=None): ) defer.returnValue(response) except Exception as e: - print(e) - import traceback - traceback.print_exc() incoming_responses_counter.labels(method, "ERR").inc() logger.info( "Error sending request to %s %s: %s %s", From 920625f05134b33e670361a8e9978e3c0027a2f9 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 6 Dec 2018 01:13:12 +1100 Subject: [PATCH 20/25] fix --- tests/rest/media/v1/test_url_preview.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 5693c224cceb..650ce95a6f6d 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -73,7 +73,14 @@ def make_homeserver(self, reactor, clock): config = self.default_config() config.url_preview_enabled = True config.max_spider_size = 9999999 - config.url_preview_ip_range_blacklist = IPSet(("192.168.1.1", "1.0.0.0/8", "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "2001:800::/21")) + config.url_preview_ip_range_blacklist = IPSet( + ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", + ) + ) config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) config.url_preview_url_blacklist = [] config.media_store_path = self.storage_path @@ -420,7 +427,9 @@ def test_blacklisted_ipv6_specific(self): """ Blacklisted IP addresses, found via DNS, are not spidered. 
""" - self.lookups["example.com"] = [(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")] + self.lookups["example.com"] = [ + (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + ] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False From 7d17b4756aaa780e8571284bd3e58ed639619c13 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 6 Dec 2018 04:03:02 +1100 Subject: [PATCH 21/25] fix py2 --- synapse/http/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 5612ee7b82c0..f8242bf613f3 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -112,7 +112,7 @@ def _callback(addrs): class BlacklistingAgentWrapper(Agent): - def __init__(self, agent, reactor, *args, whitelist=None, blacklist=None, **kwargs): + def __init__(self, agent, reactor, whitelist=None, blacklist=None): self._agent = agent self._whitelist = whitelist self._blacklist = blacklist From 79a565520c3680d7ec61668a8cf0ad241778f617 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 15 Dec 2018 01:10:37 +1100 Subject: [PATCH 22/25] fix missing hyperlink --- synapse/config/repository.py | 10 ++++++++++ synapse/http/client.py | 3 ++- tests/rest/media/v1/test_url_preview.py | 5 +++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 06c62ab62c0b..b6360588952b 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -23,6 +23,10 @@ "Missing netaddr library. This is required for URL preview API." ) +MISSING_HYPERLINK = ( + "Missing hyperlink library. This is required for URL preview API." +) + MISSING_LXML = ( """Missing lxml library. This is required for URL preview API. @@ -151,6 +155,12 @@ def read_config(self, config): except ImportError: raise ConfigError(MISSING_LXML) + try: + import hyperlink + hyperlink # To stop unused lint. 
+ except ImportError: + raise ConfigError(MISSING_HYPERLINK) + try: from netaddr import IPSet except ImportError: diff --git a/synapse/http/client.py b/synapse/http/client.py index f8242bf613f3..3ccc3043b910 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -21,7 +21,6 @@ import treq from canonicaljson import encode_canonical_json, json -from hyperlink import URL from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -118,6 +117,8 @@ def __init__(self, agent, reactor, whitelist=None, blacklist=None): self._blacklist = blacklist def request(self, method, uri, headers=None, bodyProducer=None): + from hyperlink import URL + h = URL.from_text(uri.decode('ascii')) try: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 650ce95a6f6d..b4353cd91fdf 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -56,6 +56,11 @@ def deliverBody(self, protocol): class URLPreviewTests(unittest.HomeserverTestCase): + try: + from hyperlink import URL + except ImportError: + skip = "Hyperlink is missing -- running on an older Twisted" + hijack_auth = True user_id = "@test:user" end_content = ( From 34952cefe719eb03e004913499778c353f119b92 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Dec 2018 18:47:56 +1100 Subject: [PATCH 23/25] review comments --- synapse/http/client.py | 132 +++++++++++------- synapse/rest/media/v1/preview_url_resource.py | 4 +- 2 files changed, 81 insertions(+), 55 deletions(-) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3ccc3043b910..84aa4eb3dcf7 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -52,18 +52,35 @@ ) -def check_against_blacklist(ip_address, whitelist, blacklist): - if ip_address in blacklist: - if whitelist is None or ip_address not in whitelist: +def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): + """ + Args: + ip_address (netaddr.IPAddress) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ + if ip_address in ip_blacklist: + if ip_whitelist is None or ip_address not in ip_whitelist: return True return False class IPBlacklistingResolver(object): - def __init__(self, reactor, whitelist, blacklist): + """ + A proxy for reactor.nameResolver which only produces non-blacklisted IP + addresses, preventing DNS rebinding attacks on URL preview. + """ + + def __init__(self, reactor, ip_whitelist, ip_blacklist): + """ + Args: + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ self._reactor = reactor - self._whitelist = whitelist - self._blacklist = blacklist + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist def resolveHostName(self, recv, hostname, portNumber=0): @@ -82,7 +99,7 @@ def addressResolved(address): ip_address = IPAddress(address.host) if check_against_blacklist( - ip_address, self._whitelist, self._blacklist + ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( "Dropped %s from DNS resolution to %s" % (ip_address, hostname) @@ -111,10 +128,22 @@ def _callback(addrs): class BlacklistingAgentWrapper(Agent): - def __init__(self, agent, reactor, whitelist=None, blacklist=None): + """ + An Agent wrapper which will prevent access to IP addresses being accessed + directly (without an IP address lookup). 
+ """ + + def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + """ + Args: + agent (twisted.web.client.Agent): The Agent to wrap. + reactor (twisted.internet.reactor) + ip_whitelist (netaddr.IPSet) + ip_blacklist (netaddr.IPSet) + """ self._agent = agent - self._whitelist = whitelist - self._blacklist = blacklist + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): from hyperlink import URL @@ -124,7 +153,9 @@ def request(self, method, uri, headers=None, bodyProducer=None): try: ip_address = IPAddress(h.host) - if check_against_blacklist(ip_address, self._whitelist, self._blacklist): + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): logger.info( "Blocking access to %s because of blacklist" % (ip_address,) ) @@ -145,58 +176,35 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, whitelist=None, blacklist=None): + def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): """ Args: hs (synapse.server.HomeServer) treq_args (dict): Extra keyword arguments to be given to treq.request. - blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that we may not request. - whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. """ self.hs = hs - self._whitelist = whitelist - self._blacklist = blacklist + self._ip_whitelist = ip_whitelist + self._ip_blacklist = ip_blacklist self._extra_treq_args = treq_args self.user_agent = hs.version_string self.clock = hs.get_clock() - self.reactor = hs.get_reactor() if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) self.user_agent = self.user_agent.encode('ascii') - self._make_agent() - - def _make_agent(self, _agent=False): - - if _agent: - self.agent = _agent - else: - # the pusher makes lots of concurrent SSL connections to sygnal, and - # tends to do so in batches, so we need to allow the pool to keep - # lots of idle connections around. - pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) - pool.cachedConnectionTimeout = 2 * 60 - - # The default context factory in Twisted 14.0.0 (which we require) is - # BrowserLikePolicyForHTTPS which will do regular cert validation - # 'like a browser' - self.agent = Agent( - self.reactor, - connectTimeout=15, - contextFactory=self.hs.get_http_client_context_factory(), - pool=pool, - ) - - # If we have an IP blacklist, use the blacklisting Agent wrapper. - if self._blacklist: + if self._ip_blacklist: + real_reactor = hs.get_reactor() + # If we have an IP blacklist, we need to use a DNS resolver which + # filters out blacklisted IP addresses, to prevent DNS rebinding. 
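The comment above is the crux of the whole series: filtering inside the reactor's name resolver closes a time-of-check/time-of-use gap that no resolve-then-connect scheme can. If a client resolves a hostname, vets the answer, and then hands the hostname back to the connection layer, a hostile DNS server can simply answer differently the second time. A toy illustration of that rebinding race, with a stand-in resolver serving attacker-controlled answers (the addresses are illustrative only):

    class RebindingDNS(object):
        """A DNS stub whose answers change between queries."""

        def __init__(self, answers):
            self._answers = iter(answers)

        def resolve(self, hostname):
            return next(self._answers)

    dns = RebindingDNS(["93.184.216.34", "169.254.169.254"])

    checked = dns.resolve("attacker.example")    # vetted against the blacklist
    connected = dns.resolve("attacker.example")  # where the socket actually goes
    assert checked != connected

Because IPBlacklistingResolver sits on the reactor itself, the connection can only ever be attempted against addresses that already passed check_against_blacklist.
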
nameResolver = IPBlacklistingResolver( - self.reactor, self._whitelist, self._blacklist + real_reactor, self._ip_whitelist, self._ip_blacklist ) @implementer(IReactorPluggableNameResolver) @@ -205,20 +213,38 @@ def __getattr__(_self, attr): if attr == "nameResolver": return nameResolver else: - return getattr(self.reactor, attr) + return getattr(real_reactor, attr) - self.agent = Agent( - Reactor(), - connectTimeout=15, - contextFactory=self.hs.get_http_client_context_factory(), - pool=pool, - ) + self.reactor = Reactor() + else: + self.reactor = hs.get_reactor() + + # the pusher makes lots of concurrent SSL connections to sygnal, and + # tends to do so in batches, so we need to allow the pool to keep + # lots of idle connections around. + pool = HTTPConnectionPool(self.reactor) + pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + pool.cachedConnectionTimeout = 2 * 60 + + # The default context factory in Twisted 14.0.0 (which we require) is + # BrowserLikePolicyForHTTPS which will do regular cert validation + # 'like a browser' + self.agent = Agent( + self.reactor, + connectTimeout=15, + contextFactory=self.hs.get_http_client_context_factory(), + pool=pool, + ) + if self._ip_blacklist: + # If we have an IP blacklist, we then install the blacklisting Agent + # which prevents direct access to IP addresses, that are not caught + # by the DNS resolution. self.agent = BlacklistingAgentWrapper( self.agent, self.reactor, - whitelist=self._whitelist, - blacklist=self._blacklist, + ip_whitelist=self._ip_whitelist, + ip_blacklist=self._ip_blacklist, ) @defer.inlineCallbacks diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 6bdd1d3442f6..fb6acecebe86 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -72,8 +72,8 @@ def __init__(self, hs, media_repo, media_storage): self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, - whitelist=hs.config.url_preview_ip_range_whitelist, - blacklist=hs.config.url_preview_ip_range_blacklist, + ip_whitelist=hs.config.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.url_preview_ip_range_blacklist, ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path From 9e64ab8a929b1a2b2469a8c575aa8c194746b8cb Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Dec 2018 18:49:17 +1100 Subject: [PATCH 24/25] review comments --- synapse/rest/media/v1/preview_url_resource.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb6acecebe86..ba3ab1d37dae 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -326,8 +326,7 @@ def _download_url(self, url, user): except SynapseError: # Pass SynapseErrors through directly, so that the servlet # handler will return a SynapseError to the client instead of - # blank data or a 500. Currently, this is only if the IP we are - # trying to fetch from is blacklisted. + # blank data or a 500. 
raise except Exception as e: # FIXME: pass through 404s and other error messages nicely From 3b9e19009da98ba387ddecf054b04a961414b5ca Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Dec 2018 20:53:03 +1100 Subject: [PATCH 25/25] urllib does what we need here --- synapse/config/repository.py | 10 ---------- synapse/http/client.py | 6 ++---- tests/rest/media/v1/test_url_preview.py | 5 ----- 3 files changed, 2 insertions(+), 19 deletions(-) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b6360588952b..06c62ab62c0b 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -23,10 +23,6 @@ "Missing netaddr library. This is required for URL preview API." ) -MISSING_HYPERLINK = ( - "Missing hyperlink library. This is required for URL preview API." -) - MISSING_LXML = ( """Missing lxml library. This is required for URL preview API. @@ -155,12 +151,6 @@ def read_config(self, config): except ImportError: raise ConfigError(MISSING_LXML) - try: - import hyperlink - hyperlink # To stop unused lint. - except ImportError: - raise ConfigError(MISSING_HYPERLINK) - try: from netaddr import IPSet except ImportError: diff --git a/synapse/http/client.py b/synapse/http/client.py index 84aa4eb3dcf7..afcf698b294e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -146,12 +146,10 @@ def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): - from hyperlink import URL - - h = URL.from_text(uri.decode('ascii')) + h = urllib.parse.urlparse(uri.decode('ascii')) try: - ip_address = IPAddress(h.host) + ip_address = IPAddress(h.hostname) if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index b4353cd91fdf..650ce95a6f6d 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -56,11 +56,6 @@ def deliverBody(self, protocol): class URLPreviewTests(unittest.HomeserverTestCase): - try: - from hyperlink import URL - except ImportError: - skip = "Hyperlink is missing -- running on an older Twisted" - hijack_auth = True user_id = "@test:user" end_content = (
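The final commit's point is that hostname extraction needs nothing beyond the standard library: urllib.parse.urlparse exposes .hostname (port stripped, brackets removed for IPv6 literals), which is all the agent wrapper ever asked of hyperlink. A sketch of the replacement parse, assuming Python 3's urllib.parse (which client.py already uses elsewhere) plus netaddr:

    import urllib.parse

    from netaddr import AddrFormatError, IPAddress

    def host_of(uri):
        return urllib.parse.urlparse(uri).hostname

    assert host_of("http://1.1.1.2:8080/preview") == "1.1.1.2"
    assert host_of("http://example.com/preview") == "example.com"
    assert host_of("http://[2001:800::1]/preview") == "2001:800::1"

    try:
        IPAddress(host_of("http://example.com/"))
    except AddrFormatError:
        pass  # not an IP literal; the resolver-level filter covers hostnames
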