diff --git a/aiohttp/client.py b/aiohttp/client.py index bd542f01505..9194ab6cb13 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -726,6 +726,26 @@ def on_request_createconn_start(self): def on_request_createconn_end(self): return self._connector.on_createconn_end + @property + def on_request_reuseconn(self): + return self._connector.on_reuseconn + + @property + def on_request_resolvehost_start(self): + return self._connector.on_resolvehost_start + + @property + def on_request_resolvehost_end(self): + return self._connector.on_resolvehost_end + + @property + def on_request_dnscache_hit(self): + return self._connector.on_dnscache_hit + + @property + def on_request_dnscache_miss(self): + return self._connector.on_dnscache_miss + # req resp signals @property diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 7a0c06fc698..510f898326e 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -205,6 +205,7 @@ def __init__(self, *, keepalive_timeout=sentinel, self._on_queued_end = FuncSignal() self._on_createconn_start = FuncSignal() self._on_createconn_end = FuncSignal() + self._on_reuseconn = FuncSignal() def __del__(self, _warnings=warnings): if self._closed: @@ -407,7 +408,10 @@ def connect(self, req, trace_context=None): self.on_createconn_start.send(trace_context) try: - proto = yield from self._create_connection(req) + proto = yield from self._create_connection( + req, + trace_context=trace_context + ) if self._closed: proto.close() raise ClientConnectionError("Connector is closed.") @@ -424,6 +428,8 @@ def connect(self, req, trace_context=None): self._acquired_per_host[key].remove(placeholder) self.on_createconn_end.send(trace_context) + else: + self.on_reuseconn.send(trace_context) self._acquired.add(proto) self._acquired_per_host[key].add(proto) @@ -518,7 +524,7 @@ def _release(self, key, protocol, *, should_close=False): self, '_cleanup', self._keepalive_timeout, self._loop) @asyncio.coroutine - def _create_connection(self, req): + def _create_connection(self, req, trace_context=None): raise NotImplementedError() @property @@ -537,6 +543,10 @@ def on_createconn_start(self): def on_createconn_end(self): return self._on_createconn_end + @property + def on_reuseconn(self): + return self._on_reuseconn + _SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0) @@ -659,6 +669,11 @@ def __init__(self, *, verify_ssl=True, fingerprint=None, self._family = family self._local_addr = local_addr + self._on_resolvehost_start = FuncSignal() + self._on_resolvehost_end = FuncSignal() + self._on_dnscache_hit = FuncSignal() + self._on_dnscache_miss = FuncSignal() + def close(self): """Close all ongoing DNS calls.""" for ev in self._throttle_dns_events.values(): @@ -723,32 +738,40 @@ def clear_dns_cache(self, host=None, port=None): self._cached_hosts.clear() @asyncio.coroutine - def _resolve_host(self, host, port): + def _resolve_host(self, host, port, trace_context=None): if is_ip_address(host): return [{'hostname': host, 'host': host, 'port': port, 'family': self._family, 'proto': 0, 'flags': 0}] if not self._use_dns_cache: - return (yield from self._resolver.resolve( + self.on_resolvehost_start.send(trace_context) + res = (yield from self._resolver.resolve( host, port, family=self._family)) + self.on_resolvehost_end.send(trace_context) + return res key = (host, port) if (key in self._cached_hosts) and\ (not self._cached_hosts.expired(key)): + self.on_dnscache_hit.send(trace_context) return self._cached_hosts.next_addrs(key) if key in self._throttle_dns_events: + self.on_dnscache_hit.send(trace_context) yield from self._throttle_dns_events[key].wait() else: + self.on_dnscache_miss.send(trace_context) self._throttle_dns_events[key] = \ EventResultOrError(self._loop) try: + self.on_resolvehost_start.send(trace_context) addrs = yield from \ asyncio.shield(self._resolver.resolve(host, port, family=self._family), loop=self._loop) + self.on_resolvehost_end.send(trace_context) self._cached_hosts.add(key, addrs) self._throttle_dns_events[key].set() except Exception as e: @@ -762,15 +785,21 @@ def _resolve_host(self, host, port): return self._cached_hosts.next_addrs(key) @asyncio.coroutine - def _create_connection(self, req): + def _create_connection(self, req, trace_context=None): """Create connection. Has same keyword arguments as BaseEventLoop.create_connection. """ if req.proxy: - _, proto = yield from self._create_proxy_connection(req) + _, proto = yield from self._create_proxy_connection( + req, + trace_context=None + ) else: - _, proto = yield from self._create_direct_connection(req) + _, proto = yield from self._create_direct_connection( + req, + trace_context=None + ) return proto @@ -814,11 +843,15 @@ def _get_fingerprint_and_hashfunc(self, req): return (None, None) @asyncio.coroutine - def _create_direct_connection(self, req): + def _create_direct_connection(self, req, trace_context=None): sslcontext = self._get_ssl_context(req) fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req) - hosts = yield from self._resolve_host(req.url.raw_host, req.port) + hosts = yield from self._resolve_host( + req.url.raw_host, + req.port, + trace_context=trace_context + ) for hinfo in hosts: try: @@ -859,7 +892,7 @@ def _create_direct_connection(self, req): raise ClientConnectorError(req.connection_key, exc) from exc @asyncio.coroutine - def _create_proxy_connection(self, req): + def _create_proxy_connection(self, req, trace_context=None): headers = {} if req.proxy_headers is not None: headers = req.proxy_headers @@ -937,6 +970,22 @@ def _create_proxy_connection(self, req): return transport, proto + @property + def on_resolvehost_start(self): + return self._on_resolvehost_start + + @property + def on_resolvehost_end(self): + return self._on_resolvehost_end + + @property + def on_dnscache_hit(self): + return self._on_dnscache_hit + + @property + def on_dnscache_miss(self): + return self._on_dnscache_miss + class UnixConnector(BaseConnector): """Unix socket connector. @@ -970,7 +1019,7 @@ def path(self): return self._path @asyncio.coroutine - def _create_connection(self, req): + def _create_connection(self, req, trace_context=None): _, proto = yield from self._loop.create_unix_connection( self._factory, self._path) return proto diff --git a/tests/test_client_session.py b/tests/test_client_session.py index d92d1ff94b2..cf1416eda04 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -379,7 +379,7 @@ def test_reraise_os_error(create_session): session = create_session(request_class=req_factory) @asyncio.coroutine - def create_connection(req): + def create_connection(req, trace_context=None): # return self.transport, self.protocol return mock.Mock() session._connector._create_connection = create_connection @@ -585,6 +585,15 @@ def test_request_tracing_proxies_connector_signals(loop): id(connector.on_createconn_start) assert id(session.on_request_createconn_end) ==\ id(connector.on_createconn_end) + assert id(session.on_request_reuseconn) == id(connector.on_reuseconn) + assert id(session.on_request_resolvehost_start) ==\ + id(connector.on_resolvehost_start) + assert id(session.on_request_resolvehost_end) ==\ + id(connector.on_resolvehost_end) + assert id(session.on_request_dnscache_hit) ==\ + id(connector.on_dnscache_hit) + assert id(session.on_request_dnscache_miss) ==\ + id(connector.on_dnscache_miss) @asyncio.coroutine diff --git a/tests/test_connector.py b/tests/test_connector.py index d0d94a2d4a9..29c0ef14a8d 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -507,6 +507,45 @@ def test_tcp_connector_dns_throttle_requests_cancelled_when_close( yield from f +@asyncio.coroutine +def test_tcp_connector_dns_tracing(loop, dns_response): + trace_context = mock.Mock() + on_resolvehost_start = mock.Mock() + on_resolvehost_end = mock.Mock() + on_dnscache_hit = mock.Mock() + on_dnscache_miss = mock.Mock() + + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=True, + ttl_dns_cache=10 + ) + conn.on_resolvehost_start.append(on_resolvehost_start) + conn.on_resolvehost_end.append(on_resolvehost_end) + conn.on_dnscache_hit.append(on_dnscache_hit) + conn.on_dnscache_miss.append(on_dnscache_miss) + + m_resolver().resolve.return_value = dns_response() + + yield from conn._resolve_host( + 'localhost', + 8080, + trace_context=trace_context + ) + on_resolvehost_start.assert_called_once_with(trace_context) + on_resolvehost_end.assert_called_once_with(trace_context) + on_dnscache_miss.assert_called_once_with(trace_context) + assert not on_dnscache_hit.called + + yield from conn._resolve_host( + 'localhost', + 8080, + trace_context=trace_context + ) + on_dnscache_hit.assert_called_once_with(trace_context) + + def test_get_pop_empty_conns(loop): # see issue #473 conn = aiohttp.BaseConnector(loop=loop) @@ -946,6 +985,26 @@ def f(): conn.close() +@asyncio.coroutine +def test_connect_reuseconn_tracing(loop, key): + proto = mock.Mock() + proto.is_connected.return_value = True + trace_context = mock.Mock() + on_reuseconn = mock.Mock() + + req = ClientRequest('GET', URL('http://localhost1:80'), + loop=loop, + response_class=mock.Mock()) + + conn = aiohttp.BaseConnector(loop=loop, limit=1) + conn.on_reuseconn.append(on_reuseconn) + conn._conns[key] = [(proto, loop.time())] + yield from conn.connect(req, trace_context=trace_context) + + on_reuseconn.assert_called_with(trace_context) + conn.close() + + @asyncio.coroutine def test_connect_with_limit_and_limit_per_host(loop, key): proto = mock.Mock()