scene_search_model_server.py
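"""gRPC model server that exposes Hysia scene search (image + subtitle queries over pre-extracted multi-modal features) on port 50053."""
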
import json
import os
import os.path as osp
import time
from concurrent import futures

# rpc imports
import grpc
from hysia.search.search import DatabasePklSearch
from hysia.utils.logger import Logger
from rpc.rpccode import api2msl_pb2, api2msl_pb2_grpc

from .utils import StreamSuppressor

SERVER_ROOT = os.path.dirname(os.path.abspath(__file__)) + '/'

# Time constant
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# TVQA dataset for efficient test
# VIDEO_DATA_PATH = '/data/disk2/hysia_data/UNC_TVQA_DATASET'
# search_machine = BasicSearch(VIDEO_DATA_PATH)

logger = Logger(
    name='scene_search_model_server',
    severity_levels={'StreamHandler': 'INFO'}
)

video_path = osp.join(SERVER_ROOT, '../output/multi_features')


def load_search_machine():
    with StreamSuppressor():
        search_machine = DatabasePklSearch(video_path)
    return search_machine


# Custom request servicer
class Api2MslServicer(api2msl_pb2_grpc.Api2MslServicer):
    def __init__(self):
        super().__init__()
        # Pin this model server to a single GPU
        os.environ['CUDA_VISIBLE_DEVICES'] = '3'
        logger.info('Using GPU:' + os.environ['CUDA_VISIBLE_DEVICES'])
        self.search_machine = load_search_machine()

    def GetJson(self, request, context):
        meta = json.loads(request.meta)
        # The request buffer carries the path of the query image
        img_path = request.buf.decode()
        logger.info('Searching by ' + img_path)
        with StreamSuppressor():
            results = self.search_machine.search(
                image_query=img_path,
                subtitle_query=meta['text'],
                face_query=None,
                topK=5,
                # TODO Currently only supports one target video
                tv_name=meta['target_videos'][0] if len(meta['target_videos']) else None
            )
        # Convert numpy feature arrays to lists so the results are JSON-serializable
        for res in results:
            if not isinstance(res['FEATURE'], list):
                res['FEATURE'] = res['FEATURE'].tolist()
            try:
                if not isinstance(res['AUDIO_FEATURE'], list):
                    res['AUDIO_FEATURE'] = res['AUDIO_FEATURE'].tolist()
            except (KeyError, AttributeError):
                pass
            try:
                if not isinstance(res['SUBTITLE_FEATURE'], (list, str)):
                    res['SUBTITLE_FEATURE'] = res['SUBTITLE_FEATURE'].tolist()
            except (KeyError, AttributeError):
                pass
        return api2msl_pb2.JsonReply(json=json.dumps(results), meta='')


def main():
    # gRPC server configurations
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
    api2msl_pb2_grpc.add_Api2MslServicer_to_server(Api2MslServicer(), server)
    server.add_insecure_port('[::]:50053')
    server.start()
    logger.info('Listening on port 50053')
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        logger.info('Shutting down scene search model server')
        server.stop(0)


if __name__ == '__main__':
    main()
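

# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of the original server). It assumes the
# generated stub class is named `Api2MslStub` and that the request message has
# `buf` and `meta` fields, mirroring how GetJson() reads them above. The
# message name `BulkRequest` below is a placeholder; check api2msl.proto for
# the actual name before using this.
def example_client(image_path, text_query, target_video=None, host='localhost:50053'):
    channel = grpc.insecure_channel(host)
    stub = api2msl_pb2_grpc.Api2MslStub(channel)
    meta = json.dumps({
        'text': text_query,
        'target_videos': [target_video] if target_video else [],
    })
    # Placeholder request message name; replace with the one defined in api2msl.proto
    request = api2msl_pb2.BulkRequest(buf=image_path.encode(), meta=meta)
    reply = stub.GetJson(request)
    return json.loads(reply.json)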