forked from luca-medeiros/lang-segment-anything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcp_server.py
95 lines (83 loc) · 2.91 KB
/
tcp_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
'''
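Adapted from the example in the README.
Its main purpose is to have the model weights cached in the built image;
it also serves LangSAM segmentation masks to one client at a time over a
length-prefixed TCP protocol on 127.0.0.1:12345.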
'''
from PIL import Image
from lang_sam import LangSAM
import numpy as np
import socket
import gc
import torch
import cv2
def _recv_exact(conn, n, chunk_size=65536):
    """Read exactly ``n`` bytes from ``conn`` and return them.

    ``socket.recv`` may legitimately return fewer bytes than requested, so
    this keeps reading until ``n`` bytes have arrived. Each read is capped at
    the number of bytes still missing, so bytes belonging to a following
    message are never consumed.

    Raises:
        ConnectionError: if the peer closes the connection (``recv`` returns
            ``b''``) before ``n`` bytes have been received.
    """
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = conn.recv(min(remaining, chunk_size))
        if not chunk:
            raise ConnectionError(
                f'connection closed after {n - remaining} of {n} bytes')
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _recv_uint32(conn):
    """Read one 4-byte big-endian unsigned integer from ``conn``."""
    return int.from_bytes(_recv_exact(conn, 4), byteorder='big')


def _expect_ack(conn):
    """Block until the peer sends its 2-byte ``b'ok'`` acknowledgement.

    Raises:
        ConnectionError: if the peer closes the connection or replies with
            anything other than ``b'ok'``. An explicit raise is used instead
            of ``assert`` so the check survives ``python -O``.
    """
    ack = _recv_exact(conn, 2)
    if ack != b'ok':
        raise ConnectionError(f'expected acknowledgement b"ok", got {ack!r}')


def send_arr_to_tcp(arr, conn):
    """Send a numpy array over ``conn`` and wait for the acknowledgement.

    Wire format: the number of dimensions, then each dimension, all as
    4-byte big-endian integers, followed by the raw array bytes in C order
    (``arr.tobytes()``). The dtype is NOT transmitted, and
    ``recv_arr_from_tcp`` decodes the payload as ``uint8``, so callers should
    send ``uint8`` arrays.

    Raises:
        ConnectionError: if the peer does not acknowledge with ``b'ok'``.
        OverflowError: if a dimension does not fit in 4 bytes.
    """
    header = b''.join(
        int(n).to_bytes(4, byteorder='big') for n in (arr.ndim, *arr.shape))
    # Send the whole header in one write rather than one tiny write per
    # dimension before blocking on the acknowledgement.
    conn.sendall(header)
    conn.sendall(arr.tobytes())
    _expect_ack(conn)


def recv_arr_from_tcp(conn):
    """Receive an array written by ``send_arr_to_tcp`` and acknowledge it.

    Returns:
        A ``uint8`` numpy array with the transmitted shape. It is a read-only
        view over the received bytes (``np.frombuffer`` on ``bytes``).

    Raises:
        ConnectionError: if the peer closes the connection mid-message.
    """
    ndim = _recv_uint32(conn)
    shape = tuple(_recv_uint32(conn) for _ in range(ndim))
    nbytes = 1  # the product of an empty shape is 1 (a 0-d array)
    for dim in shape:
        nbytes *= dim
    payload = _recv_exact(conn, nbytes)
    arr = np.frombuffer(payload, dtype=np.uint8).reshape(shape)
    conn.sendall(b'ok')
    return arr


def send_str_to_tcp(s, conn):
    """Send ``s`` as UTF-8 over ``conn`` and wait for the acknowledgement.

    Wire format: the length of the UTF-8 encoding in BYTES as a 4-byte
    big-endian integer, followed by the encoded bytes. ``len(s)`` would
    undercount non-ASCII text, because such characters encode to several
    bytes each.

    Raises:
        ConnectionError: if the peer does not acknowledge with ``b'ok'``.
    """
    payload = s.encode('utf-8')
    conn.sendall(len(payload).to_bytes(4, byteorder='big') + payload)
    _expect_ack(conn)


def recv_str_from_tcp(conn):
    """Receive a string written by ``send_str_to_tcp``, ack it, and return it.

    Raises:
        ConnectionError: if the peer closes the connection mid-message.
        UnicodeDecodeError: if the payload is not valid UTF-8. The
            acknowledgement has already been sent at that point.
    """
    length = _recv_uint32(conn)
    payload = _recv_exact(conn, length)
    conn.sendall(b'ok')
    return payload.decode('utf-8')
if __name__ == '__main__':
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 12345))
s.listen()
model = LangSAM(sam_type='vit_t')
print('initialized.')
while True:
conn, addr = s.accept()
with conn:
print(f"Connected by {addr}")
# model = LangSAM()
img = recv_arr_from_tcp(conn)
text_prompt = recv_str_from_tcp(conn)
if len(text_prompt) == 0:
text_prompt = 'bottle, can'
pil_img = Image.fromarray(img, mode="RGB")
masks, boxes, phrases, logits = model.predict(pil_img, text_prompt)
print(masks.shape)
if len(masks.shape) == 2:
masks = np.expand_dims(masks, axis=0)
print('type 2', masks.shape)
elif len(masks.shape) == 1:
# masks = np.zeros((1, img.shape[0], img.shape[1]))
masks = np.zeros((0, img.shape[0], img.shape[1]))
print('type 1')
else:
masks = masks.detach().cpu().numpy()
print('all ok')
send_arr_to_tcp(masks.astype(np.uint8), conn)
# model.cpu()
# del model
# gc.collect()
# torch.cuda.empty_cache()
# print('model deleted.')