-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
asgi_cors.py
131 lines (120 loc) · 5.14 KB
/
asgi_cors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import fnmatch
import asyncio
from functools import wraps
def asgi_cors_decorator(
    allow_all=False,
    hosts=None,
    host_wildcards=None,
    callback=None,
    headers=None,
    methods=None,
    max_age=None,
):
    """Return a decorator that adds CORS response headers to an ASGI app.

    Args:
        allow_all: When true, every response gets
            ``access-control-allow-origin: *`` and the request's Origin
            header is not inspected.
        hosts: Iterable of exact origins (``str`` or ``bytes``) to allow.
        host_wildcards: Iterable of ``fnmatch``-style origin patterns
            (``str`` or ``bytes``) to allow.
        callback: Optional callable, sync or ``async``, invoked with the
            request's Origin header value as ``bytes``; a truthy result
            allows that origin.
        headers: Iterable of header names (``str`` or ``bytes``) sent in
            ``access-control-allow-headers``.
        methods: Iterable of method names (``str`` or ``bytes``) sent in
            ``access-control-allow-methods``.
        max_age: Value for ``access-control-max-age``; the header is
            omitted when this is falsy.

    Returns:
        A decorator taking an ASGI application callable and returning the
        wrapped application.

    Raises:
        AssertionError: If any entry of ``hosts`` or ``host_wildcards``
            ends with ``/``.
    """

    def _to_bytes(value):
        # ASGI carries header names and values as bytes.
        return value.encode("utf8") if isinstance(value, str) else value

    hosts = {_to_bytes(h) for h in hosts or []}
    host_wildcards = [_to_bytes(h) for h in host_wildcards or []]
    headers = [_to_bytes(h) for h in headers or []]
    methods = [_to_bytes(m) for m in methods or []]

    # Raise explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type and message are unchanged.
    if any(rule.endswith(b"/") for rule in (*hosts, *host_wildcards)):
        raise AssertionError("Error: CORS origin rules should never end in a /")

    # Everything below is constant per decorator, so compute it once instead
    # of on every response.
    allow_headers_value = b", ".join(headers) if headers else None
    allow_methods_value = b", ".join(methods) if methods else None
    max_age_value = None
    if max_age:
        # Header values must be bytes, not str.
        max_age_value = (
            max_age if isinstance(max_age, bytes) else str(max_age).encode("utf8")
        )
    callback_is_async = callback is not None and asyncio.iscoroutinefunction(
        callback
    )
    replaced_header_names = frozenset(
        (
            b"access-control-allow-origin",
            b"access-control-allow-headers",
            b"access-control-allow-methods",
            b"access-control-max-age",
        )
    )

    async def _resolve_allow_origin(scope):
        # Decide the access-control-allow-origin value, or None for no CORS.
        if allow_all:
            return b"*"
        if not (hosts or host_wildcards or callback):
            return None
        incoming_origin = dict(scope.get("headers") or []).get(b"origin")
        if not incoming_origin:
            return None
        matches_hosts = incoming_origin in hosts
        matches_wildcards = any(
            fnmatch.fnmatch(incoming_origin, host_wildcard)
            for host_wildcard in host_wildcards
        )
        matches_callback = False
        if callback is not None:
            # The callback is consulted even when a host rule already
            # matched, as callers may rely on it being invoked.
            if callback_is_async:
                matches_callback = await callback(incoming_origin)
            else:
                matches_callback = callback(incoming_origin)
        if matches_hosts or matches_wildcards or matches_callback:
            return incoming_origin
        return None

    def _asgi_cors_decorator(app):
        @wraps(app)
        async def app_wrapped_with_cors(scope, receive, send):
            async def wrapped_send(event):
                if event["type"] == "http.response.start":
                    allow_origin = await _resolve_allow_origin(scope)
                    if allow_origin is not None:
                        # Drop any CORS headers set by the wrapped app so
                        # the values configured here take precedence.
                        new_headers = [
                            pair
                            for pair in event.get("headers") or []
                            if pair[0] not in replaced_header_names
                        ]
                        if allow_origin:
                            new_headers.append(
                                [b"access-control-allow-origin", allow_origin]
                            )
                        if allow_headers_value is not None:
                            new_headers.append(
                                [b"access-control-allow-headers", allow_headers_value]
                            )
                        if allow_methods_value is not None:
                            new_headers.append(
                                [b"access-control-allow-methods", allow_methods_value]
                            )
                        if max_age_value is not None:
                            new_headers.append(
                                [b"access-control-max-age", max_age_value]
                            )
                        # Copy the event so keys other than "headers"
                        # (e.g. "trailers") are passed through unchanged.
                        event = {**event, "headers": new_headers}
                await send(event)

            await app(scope, receive, wrapped_send)

        return app_wrapped_with_cors

    return _asgi_cors_decorator
def asgi_cors(
    app,
    allow_all=False,
    hosts=None,
    host_wildcards=None,
    callback=None,
    headers=None,
    methods=None,
    max_age=None,
):
    """Wrap ``app`` with CORS handling in a single call.

    Builds a decorator with ``asgi_cors_decorator`` from the given options
    and applies it to ``app`` immediately, returning the wrapped application.
    """
    decorate = asgi_cors_decorator(
        allow_all=allow_all,
        hosts=hosts,
        host_wildcards=host_wildcards,
        callback=callback,
        headers=headers,
        methods=methods,
        max_age=max_age,
    )
    return decorate(app)