This repository has been archived by the owner on Jun 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrfc6555.py
319 lines (255 loc) · 10.4 KB
/
rfc6555.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Copyright 2021 Seth Michael Larson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python implementation of the Happy Eyeballs Algorithm described in RFC 6555"""
import errno
import os
import socket

try:
    from selectors import EVENT_WRITE, DefaultSelector
except (ImportError, AttributeError):
    from selectors2 import EVENT_WRITE, DefaultSelector

# time.perf_counter() is defined in Python 3.3
try:
    from time import perf_counter
except (ImportError, AttributeError):
    from time import time as perf_counter
# This tuple exists because socket.error and IOError were not
# subclasses of OSError until later versions of Python.
_SOCKET_ERRORS = (socket.error, OSError, IOError)


def _detect_ipv6():
    """Return True if an IPv6 TCP socket can be allocated and bound to ``::1``.

    Returns False when the interpreter lacks IPv6 support or when creating
    or binding the probe socket raises a socket error. The probe socket is
    always closed before returning.
    """
    if not (getattr(socket, "has_ipv6", False) and hasattr(socket, "AF_INET6")):
        return False
    probe = None
    try:
        probe = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        probe.bind(("::1", 0))
        return True
    except _SOCKET_ERRORS:
        return False
    finally:
        # Close on success too; otherwise the probe lingers until GC.
        if probe is not None:
            probe.close()


# Probed once at import time.
_HAS_IPV6 = _detect_ipv6()
# Error numbers that a non-blocking connect_ex() may report while the
# connection is still being established. RFC 6555 treats these as
# "in progress" rather than as failures.
_ASYNC_ERRNOS = set([errno.EINPROGRESS, errno.EAGAIN, errno.EWOULDBLOCK])
# Winsock reports an in-progress non-blocking connect as WSAEWOULDBLOCK.
# Note the "E": the errno module exposes the Winsock codes as WSAE*, so a
# lookup of "WSAWOULDBLOCK" would never match.
if hasattr(errno, "WSAEWOULDBLOCK"):
    _ASYNC_ERRNOS.add(errno.WSAEWOULDBLOCK)
_DEFAULT_CACHE_DURATION = 60 * 10  # 10 minutes according to the RFC.
__all__ = ["RFC6555_ENABLED", "create_connection", "cache"]

# Package metadata.
__version__ = "0.1.0"
__author__ = "Seth Michael Larson"
__email__ = "sethmichaellarson@gmail.com"
__license__ = "Apache-2.0"

# Global switch for the RFC 6555 code path. It defaults to whether this host
# can allocate IPv6 sockets; set it to False to make create_connection()
# always try addresses one after another.
RFC6555_ENABLED = _HAS_IPV6
class _RFC6555CacheManager(object):
    """Remember which address family last connected successfully per address.

    Entries map an address to ``(family, expiry)``. ``expiry`` is a
    ``perf_counter()`` timestamp after which the entry is discarded.
    """

    def __init__(self):
        # Lifetime of a new entry, in seconds.
        self.validity_duration = _DEFAULT_CACHE_DURATION
        # When False, add_entry() does nothing and get_entry() always misses.
        self.enabled = True
        # address -> (family, expiry timestamp)
        self.entries = {}

    def add_entry(self, address, family):
        """Record that ``family`` connected successfully to ``address``.

        A still-valid entry for the same family is left untouched, so its
        expiry is not pushed back on every successful connection. A missing
        or expired entry, or one recording a different family, is (re)written
        with a fresh expiry.
        """
        if not self.enabled:
            return
        current_time = perf_counter()
        entry = self.entries.get(address)
        # Don't over-write live entries to reset their expiry. Live means not
        # yet expired under the same rule get_entry() uses.
        if entry is not None and entry[0] == family and entry[1] >= current_time:
            return
        self.entries[address] = (family, current_time + self.validity_duration)

    def get_entry(self, address):
        """Return the cached family for ``address``, or None.

        Returns None when caching is disabled, no entry exists, or the entry
        has expired. Expired entries are evicted as a side effect.
        """
        if not self.enabled:
            return None
        entry = self.entries.get(address)
        if entry is None:
            return None
        family, expiry = entry
        if perf_counter() > expiry:
            del self.entries[address]
            return None
        return family
# Process-wide cache that _RFC6555ConnectionManager reads (get_entry) and
# updates (add_entry) on every RFC 6555 connection. It is exported via
# __all__ so callers can adjust cache.enabled or cache.validity_duration.
cache = _RFC6555CacheManager()
class _RFC6555ConnectionManager(object):
    """Perform one RFC 6555 ("Happy Eyeballs") connection to ``address``.

    Candidate addresses from getaddrinfo() are started one at a time as
    non-blocking connects. After starting each one, the manager waits at
    most 0.2 seconds for any pending attempt to complete before starting the
    next. Once every candidate has been started, it waits for the remaining
    timeout. The first socket to connect wins; all others are closed.

    An instance is single-use: create_connection() closes its selector.
    """

    def __init__(
        self, address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None
    ):
        # (host, port) to connect to.
        self.address = address
        # Seconds, None for no timeout, or the _GLOBAL_DEFAULT_TIMEOUT
        # sentinel (resolved to a real value in _create_socket()).
        self.timeout = timeout
        # Optional address to bind each socket to before connecting.
        self.source_address = source_address
        # Most recent error; re-raised if no attempt succeeds.
        self._error = None
        self._selector = DefaultSelector()
        # Sockets with a connect in progress, registered with the selector.
        self._sockets = []
        self._start_time = None

    def create_connection(self):
        """Connect to ``self.address`` and return the connected socket.

        Raises the last recorded socket error if every attempt failed, or
        socket.timeout if the timeout elapsed first.
        """
        try:
            return self._create_connection()
        finally:
            # The selector holds an OS resource (e.g. an epoll fd); release it
            # now rather than leaving it to garbage collection.
            self._selector.close()

    def _create_connection(self):
        # Body of create_connection(), split out so the selector always closes.
        self._start_time = perf_counter()
        host, port = self.address
        addr_info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)

        # If we don't get any results back then just skip to the end.
        if not addr_info:
            raise socket.error("getaddrinfo returns an empty list")

        ret = self._connect_with_cached_family(addr_info)
        # If it's not a list, the cached family already produced a socket.
        if not isinstance(ret, list):
            cache.add_entry(self.address, ret.family)
            return ret

        # `ret` holds the remaining entries. It may be empty when only the
        # cached family existed; then the error below comes from that attempt
        # rather than a misleading "empty getaddrinfo" message.
        sock = self._attempt_connect_with_addr_info(ret)
        if sock is not None:
            cache.add_entry(self.address, sock.family)
            return sock
        if self._error is not None:
            raise self._error
        raise socket.timeout()

    def _attempt_connect_with_addr_info(self, addr_info):
        """Race connects to the entries of ``addr_info``; return the winner or None.

        Any pending sockets that did not win are closed before returning.
        """
        sock = None
        try:
            for family, socktype, proto, _, sockaddr in addr_info:
                self._create_socket(family, socktype, proto, sockaddr)
                sock = self._wait_for_connection(False)
                if sock is not None:
                    break
            if sock is None:
                sock = self._wait_for_connection(True)
        finally:
            self._remove_all_sockets()
        return sock

    def _connect_with_cached_family(self, addr_info):
        """Try the address family cached for this address first.

        Returns a connected socket if the cached family wins. Otherwise
        returns the list of ``addr_info`` entries still to be tried (all of
        them when nothing is cached).
        """
        family = cache.get_entry(self.address)
        if family is None:
            return addr_info
        is_family = [info for info in addr_info if info[0] == family]
        not_family = [info for info in addr_info if info[0] != family]
        sock = self._attempt_connect_with_addr_info(is_family)
        if sock is not None:
            return sock
        return not_family

    def _create_socket(self, family, socktype, proto, sockaddr):
        """Start a non-blocking connect to ``sockaddr`` and track it as pending.

        Any failure is recorded in ``self._error`` and the socket is closed.
        """
        sock = None
        try:
            sock = socket.socket(family, socktype, proto)

            # If we're using the 'default' socket timeout we have
            # to set it to a real value here as this is the earliest
            # opportunity to without pre-allocating a socket just for
            # this purpose.
            if self.timeout is socket._GLOBAL_DEFAULT_TIMEOUT:
                self.timeout = sock.gettimeout()

            if self.source_address:
                sock.bind(self.source_address)

            # Make the socket non-blocking so we can use our selector.
            sock.settimeout(0.0)

            if self._is_acceptable_errno(sock.connect_ex(sockaddr)):
                self._selector.register(sock, EVENT_WRITE)
                self._sockets.append(sock)
                return
        except _SOCKET_ERRORS as e:
            self._error = e

        # Creation/bind raised, or connect_ex() failed outright: this socket
        # will never be used, so release it now.
        if sock is not None:
            _RFC6555ConnectionManager._close_socket(sock)

    def _wait_for_connection(self, last_wait):
        """Wait for a pending socket to connect; return it, or None.

        With ``last_wait`` False, wait at most 0.2 seconds, as recommended
        by RFC 6555 before starting the next attempt. With ``last_wait``
        True, keep waiting until a socket connects, every pending socket has
        failed, or the overall timeout expires.
        """
        self._remove_all_errored_sockets()

        # With no pending sockets we may not have resolved the default
        # timeout yet (see _create_socket), so never compute a wait time.
        while self._sockets:
            if last_wait:
                select_timeout = self._get_remaining_time()
            else:
                select_timeout = self._get_select_time()

            # Wait for any socket to become writable as a sign of being connected.
            for key, _ in self._selector.select(select_timeout):
                sock = key.fileobj
                if self._is_socket_errored(sock):
                    # Reading SO_ERROR clears the pending error, so this socket
                    # would look healthy on the next check; drop it now.
                    self._remove_socket(sock)
                    continue
                # Restore the old proper timeout of the socket.
                sock.settimeout(self.timeout)
                # Remove it from this list to exempt the socket from cleanup.
                self._sockets.remove(sock)
                self._selector.unregister(sock)
                return sock

            if not last_wait or self._get_remaining_time() == 0.0:
                break
        return None

    def _get_remaining_time(self):
        # Seconds left of the overall timeout (never negative), or None.
        if self.timeout is None:
            return None
        return max(self.timeout - (perf_counter() - self._start_time), 0.0)

    def _get_select_time(self):
        # The RFC 6555 0.2 s stagger, capped by the remaining timeout.
        if self.timeout is None:
            return 0.2
        return min(0.2, self._get_remaining_time())

    def _remove_socket(self, sock):
        # Stop tracking a pending socket and close it.
        self._selector.unregister(sock)
        self._sockets.remove(sock)
        _RFC6555ConnectionManager._close_socket(sock)

    def _remove_all_errored_sockets(self):
        # Snapshot first: _remove_socket() mutates self._sockets.
        for sock in [s for s in self._sockets if self._is_socket_errored(s)]:
            self._remove_socket(sock)

    @staticmethod
    def _close_socket(sock):
        # Best-effort close; errors while closing are irrelevant here.
        try:
            sock.close()
        except _SOCKET_ERRORS:
            pass

    def _is_acceptable_errno(self, error_number):
        """Return True if ``error_number`` means connected or still connecting.

        Any other value is recorded in ``self._error`` as a socket error
        carrying that errno and its OS description.
        """
        if error_number == 0 or error_number in _ASYNC_ERRNOS:
            return True
        self._error = socket.error(error_number, os.strerror(error_number))
        return False

    def _is_socket_errored(self, sock):
        # True (and the error recorded) if the socket's connect failed.
        # Note that reading SO_ERROR also clears it.
        return not self._is_acceptable_errno(
            sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        )

    def _remove_all_sockets(self):
        # Close every still-pending socket (the losers of the race).
        for sock in self._sockets:
            self._selector.unregister(sock)
            _RFC6555ConnectionManager._close_socket(sock)
        self._sockets = []
def create_connection(
    address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None
):
    """Connect to ``address`` (a ``(host, port)`` pair) and return the socket.

    Drop-in equivalent of ``socket.create_connection()``. When
    RFC6555_ENABLED is set and IPv6 is available, the candidate addresses
    are raced with _RFC6555ConnectionManager; otherwise each address is
    tried in turn with a blocking connect.
    """
    if RFC6555_ENABLED and _HAS_IPV6:
        happy_eyeballs = _RFC6555ConnectionManager(address, timeout, source_address)
        return happy_eyeballs.create_connection()

    # This code is the same as socket.create_connection() but is
    # here to make sure the same code is used across all Python versions as
    # the source_address parameter was added to socket.create_connection() in 3.2
    # This segment of code is licensed under the Python Software Foundation License
    # See LICENSE: https://github.com/python/cpython/blob/3.6/LICENSE
    host, port = address
    last_error = None
    candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
    for af, socktype, proto, _, sa in candidates:
        candidate = None
        try:
            candidate = socket.socket(af, socktype, proto)
            if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                candidate.settimeout(timeout)
            if source_address:
                candidate.bind(source_address)
            candidate.connect(sa)
            return candidate
        except socket.error as exc:
            # Remember the failure and move on to the next candidate.
            last_error = exc
            if candidate is not None:
                candidate.close()

    if last_error is None:
        raise socket.error("getaddrinfo returns an empty list")
    raise last_error