Skip to content

Commit

Permalink
Test connector (#2721)
Browse files Browse the repository at this point in the history
* fixes #2717 Makes unittest example work (#2718)

* Convert a couplt asserts

* Convert test

* Convert other test

* Convert more

* Cleanup
  • Loading branch information
asvetlov authored Feb 11, 2018
1 parent 00b1f7f commit 07d8995
Showing 1 changed file with 119 additions and 138 deletions.
257 changes: 119 additions & 138 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import hashlib
import os.path
import platform
import shutil
import socket
import ssl
import tempfile
import unittest
import uuid
from unittest import mock

Expand Down Expand Up @@ -42,6 +39,29 @@ def ssl_key():
return ('localhost', 80, True)


@pytest.fixture
def unix_sockname(tmpdir):
sock_path = tmpdir / 'socket.sock'
return str(sock_path)


@pytest.fixture
def unix_server(loop, unix_sockname):
runners = []

async def go(app):
runner = web.AppRunner(app)
runners.append(runner)
await runner.setup()
site = web.UnixSite(runner, unix_sockname)
await site.start()

yield go

for runner in runners:
loop.run_until_complete(runner.cleanup())


def test_del(loop):
conn = aiohttp.BaseConnector(loop=loop)
proto = mock.Mock(should_close=False)
Expand Down Expand Up @@ -1813,173 +1833,134 @@ def test_default_use_dns_cache(loop):
assert conn.use_dns_cache


class TestHttpClientConnector(unittest.TestCase):
async def test_resolver_not_called_with_address_is_ip(loop):
resolver = mock.MagicMock()
connector = aiohttp.TCPConnector(resolver=resolver)

def setUp(self):
self.handler = None
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
req = ClientRequest('GET',
URL('http://127.0.0.1:{}'.format(unused_port())),
loop=loop,
response_class=mock.Mock())

def tearDown(self):
if self.handler:
self.loop.run_until_complete(self.handler.shutdown())
self.loop.stop()
self.loop.run_forever()
self.loop.close()
gc.collect()
with pytest.raises(OSError):
await connector.connect(req)

async def create_server(self, method, path, handler, ssl_context=None):
app = web.Application()
app.router.add_route(method, path, handler)

port = unused_port()
self.handler = app.make_handler(loop=self.loop, tcp_keepalive=False)
srv = await self.loop.create_server(
self.handler, '127.0.0.1', port, ssl=ssl_context)
scheme = 's' if ssl_context is not None else ''
url = "http{}://127.0.0.1:{}".format(scheme, port) + path
self.addCleanup(srv.close)
return app, srv, url

async def create_unix_server(self, method, path, handler):
tmpdir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, tmpdir)
app = web.Application()
app.router.add_route(method, path, handler)

self.handler = app.make_handler(
loop=self.loop, tcp_keepalive=False, access_log=None)
sock_path = os.path.join(tmpdir, 'socket.sock')
srv = await self.loop.create_unix_server(
self.handler, sock_path)
url = "http://127.0.0.1" + path
self.addCleanup(srv.close)
return app, srv, url, sock_path

def test_tcp_connector_raise_connector_ssl_error(self):
async def handler(request):
return web.Response()

here = os.path.join(os.path.dirname(__file__), '..', 'tests')
keyfile = os.path.join(here, 'sample.key')
certfile = os.path.join(here, 'sample.crt')
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.load_cert_chain(certfile, keyfile)

app, srv, url = self.loop.run_until_complete(
self.create_server('get', '/', handler, ssl_context=sslcontext)
)
resolver.resolve.assert_not_called()

port = unused_port()
conn = aiohttp.TCPConnector(loop=self.loop,
local_addr=('127.0.0.1', port))

session = aiohttp.ClientSession(connector=conn)
async def test_tcp_connector_raise_connector_ssl_error(aiohttp_server):
async def handler(request):
return web.Response()

with pytest.raises(aiohttp.ClientConnectorSSLError) as ctx:
self.loop.run_until_complete(session.request('get', url))
app = web.Application()
app.router.add_get('/', handler)

self.assertIsInstance(ctx.value.os_error, ssl.SSLError)
self.assertIsInstance(ctx.value, aiohttp.ClientSSLError)
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
keyfile = os.path.join(here, 'sample.key')
certfile = os.path.join(here, 'sample.crt')
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.load_cert_chain(certfile, keyfile)

self.loop.run_until_complete(session.close())
conn.close()
srv = await aiohttp_server(app, ssl=sslcontext)

def test_tcp_connector_do_not_raise_connector_ssl_error(self):
async def handler(request):
return web.Response()
port = unused_port()
conn = aiohttp.TCPConnector(local_addr=('127.0.0.1', port))

here = os.path.join(os.path.dirname(__file__), '..', 'tests')
keyfile = os.path.join(here, 'sample.key')
certfile = os.path.join(here, 'sample.crt')
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.load_cert_chain(certfile, keyfile)
session = aiohttp.ClientSession(connector=conn)
url = srv.make_url('/')

app, srv, url = self.loop.run_until_complete(
self.create_server('get', '/', handler, ssl_context=sslcontext)
)
with pytest.raises(aiohttp.ClientConnectorSSLError) as ctx:
print(url)
await session.get(url)

assert isinstance(ctx.value.os_error, ssl.SSLError)
assert isinstance(ctx.value, aiohttp.ClientSSLError)

await session.close()

port = unused_port()
conn = aiohttp.TCPConnector(loop=self.loop,
local_addr=('127.0.0.1', port))

session = aiohttp.ClientSession(connector=conn)
async def test_tcp_connector_do_not_raise_connector_ssl_error(aiohttp_server):
async def handler(request):
return web.Response()

r = self.loop.run_until_complete(
session.request('get', url, ssl=sslcontext))
app = web.Application()
app.router.add_get('/', handler)

r.release()
first_conn = next(iter(conn._conns.values()))[0][0]
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
keyfile = os.path.join(here, 'sample.key')
certfile = os.path.join(here, 'sample.crt')
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.load_cert_chain(certfile, keyfile)

try:
_sslcontext = first_conn.transport._ssl_protocol._sslcontext
except AttributeError:
_sslcontext = first_conn.transport._sslcontext
srv = await aiohttp_server(app, ssl=sslcontext)
port = unused_port()
conn = aiohttp.TCPConnector(local_addr=('127.0.0.1', port))

self.assertIs(_sslcontext, sslcontext)
r.close()
session = aiohttp.ClientSession(connector=conn)
url = srv.make_url('/')

self.loop.run_until_complete(session.close())
conn.close()
r = await session.get(url, ssl=sslcontext)

def test_tcp_connector_uses_provided_local_addr(self):
async def handler(request):
return web.Response()
r.release()
first_conn = next(iter(conn._conns.values()))[0][0]

app, srv, url = self.loop.run_until_complete(
self.create_server('get', '/', handler)
)
try:
_sslcontext = first_conn.transport._ssl_protocol._sslcontext
except AttributeError:
_sslcontext = first_conn.transport._sslcontext

port = unused_port()
conn = aiohttp.TCPConnector(loop=self.loop,
local_addr=('127.0.0.1', port))
assert _sslcontext is sslcontext
r.close()

session = aiohttp.ClientSession(connector=conn)
await session.close()
conn.close()

r = self.loop.run_until_complete(
session.request('get', url)
)

r.release()
first_conn = next(iter(conn._conns.values()))[0][0]
self.assertEqual(
first_conn.transport._sock.getsockname(), ('127.0.0.1', port))
r.close()
self.loop.run_until_complete(session.close())
conn.close()
async def test_tcp_connector_uses_provided_local_addr(aiohttp_server):
async def handler(request):
return web.Response()

@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'requires unix')
def test_unix_connector(self):
async def handler(request):
return web.Response()
app = web.Application()
app.router.add_get('/', handler)
srv = await aiohttp_server(app)

app, srv, url, sock_path = self.loop.run_until_complete(
self.create_unix_server('get', '/', handler))
port = unused_port()
conn = aiohttp.TCPConnector(local_addr=('127.0.0.1', port))

connector = aiohttp.UnixConnector(sock_path, loop=self.loop)
self.assertEqual(sock_path, connector.path)
session = aiohttp.ClientSession(connector=conn)
url = srv.make_url('/')

session = client.ClientSession(
connector=connector, loop=self.loop)
r = self.loop.run_until_complete(
session.request('get', url))
self.assertEqual(r.status, 200)
r.close()
self.loop.run_until_complete(session.close())
r = await session.get(url)
r.release()

def test_resolver_not_called_with_address_is_ip(self):
resolver = mock.MagicMock()
connector = aiohttp.TCPConnector(resolver=resolver, loop=self.loop)
first_conn = next(iter(conn._conns.values()))[0][0]
assert first_conn.transport.get_extra_info(
'sockname') == ('127.0.0.1', port)
r.close()
await session.close()
conn.close()

req = ClientRequest('GET',
URL('http://127.0.0.1:{}'.format(unused_port())),
loop=self.loop,
response_class=mock.Mock())

with self.assertRaises(OSError):
self.loop.run_until_complete(connector.connect(req))
@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason='requires UNIX sockets')
async def test_unix_connector(unix_server, unix_sockname):
async def handler(request):
return web.Response()

resolver.resolve.assert_not_called()
app = web.Application()
app.router.add_get('/', handler)
await unix_server(app)

url = "http://127.0.0.1/"

connector = aiohttp.UnixConnector(unix_sockname)
assert unix_sockname == connector.path

session = client.ClientSession(connector=connector)
r = await session.get(url)
assert r.status == 200
r.close()
await session.close()


class TestDNSCacheTable:
Expand Down

0 comments on commit 07d8995

Please sign in to comment.