From cccb9cbf267cd753c958297f6f3784d4f20799db Mon Sep 17 00:00:00 2001 From: Sergey Skripnick Date: Thu, 3 Dec 2015 23:51:38 +0200 Subject: [PATCH] Fix wsgi.environment unix socket issue --- aiohttp/wsgi.py | 24 +++++++++++++++--------- tests/test_wsgi.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/aiohttp/wsgi.py b/aiohttp/wsgi.py index 841ce14dc73..b4239bf94ab 100644 --- a/aiohttp/wsgi.py +++ b/aiohttp/wsgi.py @@ -88,16 +88,22 @@ def create_wsgi_environ(self, message, payload): # http://www.ietf.org/rfc/rfc3875 remote = self.transport.get_extra_info('peername') - environ['REMOTE_ADDR'] = remote[0] - environ['REMOTE_PORT'] = remote[1] - - sockname = self.transport.get_extra_info('sockname') - environ['SERVER_PORT'] = str(sockname[1]) - host = message.headers.get("HOST", None) - if host: - environ['SERVER_NAME'] = host.split(":")[0] + if remote: + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = remote[1] + _host, port = self.transport.get_extra_info('sockname') + environ['SERVER_PORT'] = str(port) + host = message.headers.get("HOST", None) + # SERVER_NAME should be set to value of Host header, but this + # header is not required. In this case we shoud set it to local + # address of socket + environ['SERVER_NAME'] = host.split(":")[0] if host else _host else: - environ['SERVER_NAME'] = sockname[0] + # Dealing with unix socket, so request was received from client by + # upstream server and this data may be found in the headers + for header in ('REMOTE_ADDR', 'REMOTE_PORT', + 'SERVER_NAME', 'SERVER_PORT'): + environ[header] = message.headers.get(header, '') path_info = uri_parts.path if script_name: diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 2d3593c8249..b28eb7ba896 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -275,3 +275,16 @@ def test_http_1_0_no_host(self): environ = self._make_one() self.assertEqual(environ['SERVER_NAME'], '2.3.4.5') self.assertEqual(environ['SERVER_PORT'], '80') + + def test_unix_socket(self): + self.transport.get_extra_info = unittest.mock.Mock(return_value=None) + headers = multidict.MultiDict({ + 'SERVER_NAME': '1.2.3.4', 'SERVER_PORT': '5678', + 'REMOTE_ADDR': '4.3.2.1', 'REMOTE_PORT': '8765'}) + self.message = protocol.RawRequestMessage( + 'GET', '/', (1, 0), headers, True, 'deflate') + environ = self._make_one() + self.assertEqual(environ['SERVER_NAME'], '1.2.3.4') + self.assertEqual(environ['SERVER_PORT'], '5678') + self.assertEqual(environ['REMOTE_ADDR'], '4.3.2.1') + self.assertEqual(environ['REMOTE_PORT'], '8765')