From 13097c691db6b8cfc645e5620860c141e28fa696 Mon Sep 17 00:00:00 2001 From: Martin Richard Date: Mon, 29 Aug 2016 12:09:33 +0200 Subject: [PATCH] use aiodns.DNSResolver.gethostbyname() if available --- aiohttp/resolver.py | 18 ++++++++++ tests/test_resolver.py | 76 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index ac6b152cf5f..3d181919689 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -52,8 +52,25 @@ def __init__(self, loop=None, *args, **kwargs): self._loop = loop self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs) + if not hasattr(self._resolver, 'gethostbyname'): + # aiodns 1.1 is not available, fallback to DNSResolver.query + self.resolve = self.resolve_with_query + @asyncio.coroutine def resolve(self, host, port=0, family=socket.AF_INET): + hosts = [] + resp = yield from self._resolver.gethostbyname(host, family) + + for address in resp.addresses: + hosts.append( + {'hostname': host, + 'host': address, 'port': port, + 'family': family, 'proto': 0, + 'flags': socket.AI_NUMERICHOST}) + return hosts + + @asyncio.coroutine + def resolve_with_query(self, host, port=0, family=socket.AF_INET): if family == socket.AF_INET6: qtype = 'AAAA' else: @@ -68,6 +85,7 @@ def resolve(self, host, port=0, family=socket.AF_INET): 'host': rr.host, 'port': port, 'family': family, 'proto': 0, 'flags': socket.AI_NUMERICHOST}) + return hosts @asyncio.coroutine diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 2adc0e1d491..bb349ae2f9b 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -9,18 +9,30 @@ try: import aiodns + gethostbyname = hasattr(aiodns.DNSResolver, 'gethostbyname') except ImportError: aiodns = None + gethostbyname = False class FakeResult: + def __init__(self, addresses): + self.addresses = addresses + + +class FakeQueryResult: def __init__(self, host): self.host = host @asyncio.coroutine -def fake_result(result): - return [FakeResult(host=h) +def fake_result(addresses): + return FakeResult(addresses=tuple(addresses)) + + +@asyncio.coroutine +def fake_query_result(result): + return [FakeQueryResult(host=h) for h in result] @@ -36,23 +48,36 @@ def fake(*args, **kwargs): return fake -@pytest.mark.skipif(aiodns is None, reason="aiodns required") +@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") @asyncio.coroutine def test_async_resolver_positive_lookup(loop): - with patch('aiodns.DNSResolver.query') as mock_query: - mock_query.return_value = fake_result(['127.0.0.1']) + with patch('aiodns.DNSResolver') as mock: + mock().gethostbyname.return_value = fake_result(['127.0.0.1']) resolver = AsyncResolver(loop=loop) real = yield from resolver.resolve('www.python.org') ipaddress.ip_address(real[0]['host']) - mock_query.assert_called_with('www.python.org', 'A') + mock().gethostbyname.assert_called_with('www.python.org', + socket.AF_INET) @pytest.mark.skipif(aiodns is None, reason="aiodns required") @asyncio.coroutine +def test_async_resolver_query_positive_lookup(loop): + with patch('aiodns.DNSResolver') as mock: + del mock().gethostbyname + mock().query.return_value = fake_query_result(['127.0.0.1']) + resolver = AsyncResolver(loop=loop) + real = yield from resolver.resolve('www.python.org') + ipaddress.ip_address(real[0]['host']) + mock().query.assert_called_with('www.python.org', 'A') + + +@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") +@asyncio.coroutine def test_async_resolver_multiple_replies(loop): - with patch('aiodns.DNSResolver.query') as mock_query: + with patch('aiodns.DNSResolver') as mock: ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3', '127.0.0.4'] - mock_query.return_value = fake_result(ips) + mock().gethostbyname.return_value = fake_result(ips) resolver = AsyncResolver(loop=loop) real = yield from resolver.resolve('www.google.com') ips = [ipaddress.ip_address(x['host']) for x in real] @@ -61,9 +86,32 @@ def test_async_resolver_multiple_replies(loop): @pytest.mark.skipif(aiodns is None, reason="aiodns required") @asyncio.coroutine -def test_async_negative_lookup(loop): - with patch('aiodns.DNSResolver.query') as mock_query: - mock_query.side_effect = aiodns.error.DNSError() +def test_async_resolver_query_multiple_replies(loop): + with patch('aiodns.DNSResolver') as mock: + del mock().gethostbyname + ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3', '127.0.0.4'] + mock().query.return_value = fake_query_result(ips) + resolver = AsyncResolver(loop=loop) + real = yield from resolver.resolve('www.google.com') + ips = [ipaddress.ip_address(x['host']) for x in real] + + +@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") +@asyncio.coroutine +def test_async_resolver_negative_lookup(loop): + with patch('aiodns.DNSResolver') as mock: + mock().gethostbyname.side_effect = aiodns.error.DNSError() + resolver = AsyncResolver(loop=loop) + with pytest.raises(aiodns.error.DNSError): + yield from resolver.resolve('doesnotexist.bla') + + +@pytest.mark.skipif(aiodns is None, reason="aiodns required") +@asyncio.coroutine +def test_async_resolver_query_negative_lookup(loop): + with patch('aiodns.DNSResolver') as mock: + del mock().gethostbyname + mock().query.side_effect = aiodns.error.DNSError() resolver = AsyncResolver(loop=loop) with pytest.raises(aiodns.error.DNSError): yield from resolver.resolve('doesnotexist.bla') @@ -125,13 +173,13 @@ def test_default_loop_for_async_resolver(loop): @pytest.mark.skipif(aiodns is None, reason="aiodns required") @asyncio.coroutine def test_async_resolver_ipv6_positive_lookup(loop): - with patch('aiodns.DNSResolver.query') as mock_query: - mock_query.return_value = fake_result(['::1']) + with patch('aiodns.DNSResolver.gethostbyname') as mock_ghn: + mock_ghn.return_value = fake_result(['::1']) resolver = AsyncResolver(loop=loop) real = yield from resolver.resolve('www.python.org', family=socket.AF_INET6) ipaddress.ip_address(real[0]['host']) - mock_query.assert_called_with('www.python.org', 'AAAA') + mock_ghn.assert_called_with('www.python.org', socket.AF_INET6) def test_async_resolver_aiodns_not_present(loop, monkeypatch):