-
Notifications
You must be signed in to change notification settings - Fork 105
/
anyio.py
145 lines (131 loc) · 5.09 KB
/
anyio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import ssl
import typing
import anyio
from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class AnyIOStream(AsyncNetworkStream):
    """An `AsyncNetworkStream` backed by an anyio byte stream.

    Every I/O method applies an optional per-call deadline with
    `anyio.fail_after` (``None`` disables the deadline) and translates anyio
    exceptions into this package's exception types via `map_exceptions`.
    """

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        # The wrapped anyio stream: a plain connected stream, or the
        # `TLSStream` produced by `start_tls`.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive up to `max_bytes` bytes.

        Returns ``b""`` when anyio reports `EndOfStream` (the peer closed
        its side of the connection).

        Raises:
            ReadTimeout: nothing was received within `timeout` seconds.
            ReadError: the stream is broken or has already been closed.
        """
        exc_map = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                try:
                    return await self._stream.receive(max_bytes=max_bytes)
                except anyio.EndOfStream:  # pragma: nocover
                    return b""

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send `buffer`; an empty buffer returns immediately without I/O.

        Raises:
            WriteTimeout: the send did not complete within `timeout` seconds.
            WriteError: the stream is broken or has already been closed.
        """
        if not buffer:
            return

        exc_map = {
            TimeoutError: WriteTimeout,
            anyio.BrokenResourceError: WriteError,
            anyio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the wrapped anyio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Run a client-side TLS handshake over this stream.

        Returns a new `AnyIOStream` wrapping the resulting `TLSStream`. If
        the handshake raises, this stream is closed before the error
        propagates.

        Raises:
            ConnectTimeout: the handshake did not finish within `timeout`.
            ConnectError: the connection broke, the peer hung up, or the TLS
                handshake itself failed (e.g. certificate verification).
        """
        exc_map = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            # With `standard_compatible=False` anyio reports a peer hanging up
            # mid-handshake as `EndOfStream`, and other handshake failures
            # (bad certificate, protocol mismatch, ...) surface as raw
            # `ssl.SSLError`. Both are connection failures to our callers.
            anyio.EndOfStream: ConnectError,
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with anyio.fail_after(timeout):
                    ssl_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception:  # pragma: nocover
                # Don't leak the transport when the handshake fails.
                await self.aclose()
                raise
        return AnyIOStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Look up connection metadata by name.

        Supported keys: ``"ssl_object"``, ``"client_addr"`` (local address),
        ``"server_addr"`` (remote address), ``"socket"`` (raw socket) and
        ``"is_readable"``. Unknown keys, or attributes the wrapped stream
        does not provide, yield ``None``.
        """
        if info == "ssl_object":
            return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
        if info == "client_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
        if info == "server_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
        if info == "socket":
            return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
        if info == "is_readable":
            sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            # NOTE(review): `sock` may be None here; presumably
            # `is_socket_readable` accepts None — confirm in `.._utils`.
            return is_socket_readable(sock)
        return None
class AnyIOBackend(AsyncNetworkBackend):
    """An `AsyncNetworkBackend` that opens connections using anyio."""

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Call ``setsockopt(*option)`` on `stream`'s raw socket per option.

        If any option is rejected, `stream` is force-closed before the error
        propagates, so a half-configured connection is never leaked.
        """
        try:
            for option in socket_options:
                # Public typed-attribute lookup instead of the private
                # `_raw_socket` attribute of anyio's socket stream classes.
                raw_socket = stream.extra(anyio.abc.SocketAttribute.raw_socket)
                raw_socket.setsockopt(*option)  # pragma: no cover
        except BaseException:
            await anyio.aclose_forcefully(stream)
            raise

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to `host`:`port` and wrap it in `AnyIOStream`.

        Args:
            host: Remote host name or address.
            port: Remote port.
            timeout: Deadline in seconds for connecting and applying socket
                options; ``None`` disables it.
            local_address: Optional local host to bind before connecting.
            socket_options: ``setsockopt`` argument tuples applied after the
                connection is established.

        Raises:
            ConnectTimeout: the deadline expired.
            ConnectError: the connection or a socket option failed.
        """
        if socket_options is None:
            socket_options = []  # pragma: no cover
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
                # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Connect to the Unix domain socket at `path` and wrap it in `AnyIOStream`.

        `timeout` and `socket_options` behave as in `connect_tcp`.

        Raises:
            ConnectTimeout: the deadline expired.
            ConnectError: the connection or a socket option failed.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for `seconds` seconds via `anyio.sleep`."""
        await anyio.sleep(seconds)  # pragma: nocover