#!/usr/bin/env python3
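"""Inference server.

Accepts inference clients on PORT and monitor clients on PORT2. All socket
I/O is handled with asyncio; a single dedicated thread runs the model
processor loop.

Run directly: python server.py
"""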
from __future__ import annotations
import asyncio
import json
import time
import traceback
from asyncio import StreamReader, StreamWriter
from collections import defaultdict
from dataclasses import dataclass
from itertools import count
from typing import ByteString, Dict, Optional
import numpy as np
from src.lib.layouts import TensorLayout
from src.lib.predecode import Predecoder
from src.modelconfig import ModelConfig, ProcessorConfig
from src.server import monitor_client
from src.server.comm import (
json_confirmation,
json_ping,
json_ready,
json_result,
)
from src.server.model_manager import ModelManager
from src.server.monitor_client import MonitorStats, image_preview
from src.server.reader import read_item
from src.server.work_distributor import SmartProcessor, WorkDistributor
from src.utils import get_predecoder
IP = "0.0.0.0"
PORT = 5678
PORT2 = 5680
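# PORT accepts inference clients; PORT2 accepts the monitoring client (see main()).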
def str_preview(s: ByteString, max_len: int = 16) -> str:
    """Return a hex preview of a byte string, truncating long inputs."""
    if len(s) <= max_len:
        return s.hex()
    return f"{s[:max_len - 6].hex()}...{s[-3:].hex()}"
@dataclass
class State:
    """Per-connection state, keyed by client guid in processor()."""

    model_config: Optional[ModelConfig] = None
    predecoder: Optional[Predecoder] = None
    tensor_layout: Optional[TensorLayout] = None
# TODO turn this into a class or something
def processor(work_distributor: WorkDistributor, monitor_stats: MonitorStats):
    """Process work items received from the work distributor.

    This runs on a single dedicated executor thread (see main()), so model
    state is never touched concurrently.
    """
model_manager = ModelManager()
smart_processor = SmartProcessor(work_distributor)
states: Dict[int, State] = defaultdict(State)
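    # A typical session: acquire -> init_postencoder -> ready, then repeated
    # predict/ping requests, and finally release -> terminate (see produce()).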
while True:
try:
guid, item = smart_processor.get()
request_type, item = item
state = states[guid]
if request_type == "terminate":
work_distributor.put(guid, None)
elif request_type == "acquire":
model_config = item
assert state.model_config is None
model_manager.acquire(model_config)
state.model_config = model_config
# TODO have client provide the tiled_layout
state.tensor_layout = model_manager.input_tensor_layout(
model_config
)
elif request_type == "release":
model_config = item
assert model_config == state.model_config
model_manager.release(model_config)
state.model_config = None
elif request_type == "init_postencoder":
postencoder_config = item
state.predecoder = get_predecoder(
postencoder_config, state.model_config, state.tensor_layout
)
elif request_type == "predict":
frame_number, buf = item
model_config = state.model_config
confirmation = json_confirmation(
frame_number=frame_number, num_bytes=len(buf)
)
confirmation = f"{confirmation}\n".encode("utf8")
work_distributor.put(guid, confirmation)
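                # The confirmation above acknowledges receipt immediately;
                # the prediction result is sent separately after inference.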
# TODO predecode_time separately from inference_time
t0 = time.time()
data_tensor = state.predecoder.run(buf)
data_tensor = data_tensor[np.newaxis, ...]
preds = model_manager.predict(model_config, data_tensor)
preds = model_manager.decode_predictions(model_config, preds)
t1 = time.time()
inference_time = int(1000 * (t1 - t0))
                result = json_result(
                    frame_number=frame_number,
                    inference_time=inference_time,
                    predictions=preds,
                )
result = f"{result}\n".encode("utf8")
work_distributor.put(guid, result)
monitor_stats.add(
frame_number=frame_number,
# data_shape=..., # TODO different shapes for data?
inference_time=inference_time,
predictions=preds,
data=image_preview(data_tensor),
)
elif request_type == "ready":
ready = json_ready(model_config=state.model_config)
ready = f"{ready}\n".encode("utf8")
work_distributor.put(guid, ready)
elif request_type == "ping":
id_ = item
response = f"{json_ping(id_)}\n".encode("utf8")
work_distributor.put(guid, response)
            else:
                raise ValueError(f"Unknown request type: {request_type}")
except Exception:
traceback.print_exc()
async def produce(reader: StreamReader, putter):
"""Reads from socket, and pushes requests to processor."""
    model_config: Optional[ModelConfig] = None
try:
while True:
print("Read begin")
input_type, item = await read_item(reader)
print("Read end")
if input_type == "terminate":
break
# TODO merge with processor()?
if input_type == "frame":
frame_number, buf = item
print(f"Produce: {frame_number} {str_preview(buf)}")
with open("frame.dat", "wb") as f:
f.write(buf)
await putter(("predict", item))
# TODO why are all json input types handled in this way?
elif input_type == "json":
                # TODO clarify the reacquisition logic below (why we track
                # prev_model_config and whether a model was already loaded)
print(f"Produce: {item}")
processor_config = ProcessorConfig.from_json_dict(item)
prev_model_config = model_config
model_config = processor_config.model_config
postencoder_config = processor_config.postencoder_config
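                # Release the old model and acquire the new one only when the
                # model actually changed; the postencoder is always re-initialized.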
prev_valid = prev_model_config is not None
changed = prev_valid and prev_model_config != model_config
if changed:
await putter(("release", prev_model_config))
if changed or not prev_valid:
await putter(("acquire", model_config))
await putter(("init_postencoder", postencoder_config))
await putter(("ready", None))
elif input_type == "ping":
await putter(("ping", item))
finally:
if model_config is not None:
await putter(("release", model_config))
await putter(("terminate", None))
async def consume(writer: StreamWriter, getter):
"""Receives items and writes them to socket."""
try:
for i in count():
item = await getter()
if item is None:
break
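            # Log a compact summary; predictions are dropped from the log
            # copy only, not from the bytes written to the socket.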
item_d = json.loads(item.decode("utf8"))
item_d.pop("predictions", None)
print(f"Consume {i}: {item_d}")
# print(json.dumps(item_d, indent=4))
print("Write begin")
writer.write(item)
print("Drain...")
await writer.drain()
print("Write end")
finally:
print("Closing client...")
writer.close()
def handle_client(work_distributor: WorkDistributor):
async def client_handler(reader: StreamReader, writer: StreamWriter):
print("New client...")
ip, port = writer.get_extra_info("peername")
print(f"Connected to {ip}:{port}")
putter, getter = work_distributor.register()
coros = [produce(reader, putter), consume(writer, getter)]
        tasks = [asyncio.create_task(coro) for coro in coros]
        await asyncio.wait(tasks)
return client_handler
async def main():
work_distributor = WorkDistributor()
monitor_stats = MonitorStats()
    loop = asyncio.get_running_loop()
    # Run the blocking processor loop on a worker thread of the default executor.
    loop.run_in_executor(None, processor, work_distributor, monitor_stats)
client_handler = handle_client(work_distributor)
server = await asyncio.start_server(client_handler, IP, PORT)
monitor_handler = monitor_client.handle_client(monitor_stats)
monitor_server = await asyncio.start_server(monitor_handler, IP, PORT2)
print("Started server")
    await asyncio.wait(
        [
            asyncio.create_task(server.serve_forever()),
            asyncio.create_task(monitor_server.serve_forever()),
        ]
    )
if __name__ == "__main__":
asyncio.run(main())
# TODO read, inference, write in parallel, no? (e.g. a ProcessPoolExecutor)