-
Notifications
You must be signed in to change notification settings - Fork 230
/
Copy pathmxnet_model_service.py
189 lines (153 loc) · 6.85 KB
/
mxnet_model_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
MXNetModelService defines an API for MXNet service.
"""
import json
import os
import mxnet as mx
from mxnet.io import DataBatch
from model_handler import ModelHandler
class MXNetModelService(ModelHandler):
    """
    MXNetModelService defines the fundamental model loading and inference
    operations when serving an MXNet model. This is a base class and needs to
    be inherited.

    NOTE(review): ``self.error`` and ``self._batch_size`` are read below but
    never assigned in this class; presumably they are provided by
    ``ModelHandler`` -- confirm against that class.
    """

    def __init__(self):
        super(MXNetModelService, self).__init__()
        # All of these are populated by initialize().
        self.mxnet_ctx = None
        self.mx_model = None
        self.labels = None
        self.signature = None
        # Checkpoint epoch passed to mx.model.load_checkpoint.
        self.epoch = 0

    # noinspection PyMethodMayBeStatic
    def get_model_files_prefix(self, context):
        """
        Return the checkpoint file prefix for the model.

        :param context: context whose manifest holds ``model.modelName``.
        :return: value of ``context.manifest["model"]["modelName"]``.
        """
        return context.manifest["model"]["modelName"]

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time

        Reads ``signature.json`` (required) and ``synset.txt`` (optional) from
        the model directory, loads the checkpoint and binds an inference-only
        MXNet module.

        :param context: Initial context contains model server system properties.
        :return:
        :raises RuntimeError: if ``signature.json`` is missing.
        :raises AssertionError: if the configured batch size is not 1.
        """
        super(MXNetModelService, self).initialize(context)

        assert self._batch_size == 1, "Batch is not supported."

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")

        signature_file_path = os.path.join(model_dir, "signature.json")
        if not os.path.isfile(signature_file_path):
            raise RuntimeError("Missing signature.json file.")

        with open(signature_file_path) as f:
            self.signature = json.load(f)

        model_files_prefix = self.get_model_files_prefix(context)

        archive_synset = os.path.join(model_dir, "synset.txt")
        if os.path.isfile(archive_synset):
            # Context manager so the label file handle is closed promptly.
            with open(archive_synset) as f:
                self.labels = [line.strip() for line in f]

        data_names = []
        data_shapes = []
        for input_data in self.signature["inputs"]:
            data_name = input_data["data_name"]
            # Work on a copy: the signature itself must keep its 0 entries,
            # because check_input_shape() treats a 0 dimension as "do not
            # compare". Editing the list in place would erase them.
            data_shape = list(input_data["data_shape"])

            # Set batch size
            data_shape[0] = self._batch_size

            # Replace 0 entry in data shape with 1 for binding executor.
            data_shape = [1 if dim == 0 else dim for dim in data_shape]

            data_names.append(data_name)
            data_shapes.append((data_name, tuple(data_shape)))

        checkpoint_prefix = "{}/{}".format(model_dir, model_files_prefix)

        # Load MXNet module
        self.mxnet_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id)
        sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, self.epoch)

        # noinspection PyTypeChecker
        self.mx_model = mx.mod.Module(symbol=sym, context=self.mxnet_ctx,
                                      data_names=data_names, label_names=None)
        self.mx_model.bind(for_training=False, data_shapes=data_shapes)
        self.mx_model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)

    def preprocess(self, batch):
        """
        Transform raw input into model input data.

        For each request the payload is looked up under the first signature
        input's ``data_name``, then ``"body"``, then ``"data"``; every element
        of the payload is converted with ``mx.nd.array``.

        :param batch: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        :raises AssertionError: if ``len(batch)`` differs from the batch size.
        """
        assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch))

        ret = []
        param_name = self.signature['inputs'][0]['data_name']

        for request in batch:
            data = request.get(param_name)
            if data is None:
                data = request.get("body")

            if data is None:
                data = request.get("data")

            # Build a concrete list: on Python 3 ``map`` returns a one-shot
            # iterator, which would be exhausted after a single pass.
            ret.append([mx.nd.array(item) for item in data])

        return ret

    def inference(self, model_input):
        """
        Internal inference methods for MXNet. Run forward computation and
        return output.

        :param model_input: list of NDArray
            Preprocessed inputs in NDArray format.
        :return: list of NDArray
            Inference output, or None when ``self.error`` is set.
        """
        if self.error is not None:
            return None

        # Check input shape
        check_input_shape(model_input, self.signature)
        model_input = [item.as_in_context(self.mxnet_ctx) for item in model_input]
        self.mx_model.forward(DataBatch(model_input))
        outputs = self.mx_model.get_outputs()

        # Bypass lazy evaluation: get_outputs returns either a list of
        # NDArrays or a list of lists of NDArray, so wait on each array.
        for output in outputs:
            if isinstance(output, list):
                # Walk the nested list itself, not the outer list.
                for nested in output:
                    if isinstance(nested, mx.ndarray.ndarray.NDArray):
                        nested.wait_to_read()
            elif isinstance(output, mx.ndarray.ndarray.NDArray):
                output.wait_to_read()

        return outputs

    def postprocess(self, inference_output):
        """
        Convert inference output into a list of response strings.

        :param inference_output: iterable of objects exposing ``asnumpy()``.
        :return: ``[self.error] * batch_size`` when an error is recorded,
            otherwise one ``str`` of the nested-list form per output.
        """
        if self.error is not None:
            return [self.error] * self._batch_size

        return [str(d.asnumpy().tolist()) for d in inference_output]
def check_input_shape(inputs, signature):
    """
    Check input data shape consistency with signature.

    The first axis of every input is never compared, and a 0 entry in the
    signature shape is treated as unconstrained.

    Parameters
    ----------
    inputs : List of NDArray
        Input data in NDArray format.
    signature : dict
        Dictionary containing model signature. ``signature['inputs']`` is a
        list of dicts with ``data_name`` and ``data_shape`` keys.

    Raises
    ------
    AssertionError
        If ``inputs`` is not a list, the number of inputs differs from the
        signature, an input is not an NDArray, or a shape does not match.
    """
    assert isinstance(inputs, list), 'Input data must be a list.'

    sig_inputs = signature['inputs']
    assert len(inputs) == len(sig_inputs), \
        "Input number mismatches with " \
        "signature. %d expected but got %d." \
        % (len(sig_inputs), len(inputs))

    for input_data, sig_input in zip(inputs, sig_inputs):
        assert isinstance(input_data, mx.nd.NDArray), 'Each input must be NDArray.'

        expected_shape = sig_input['data_shape']
        actual_shape = input_data.shape
        assert len(actual_shape) == len(expected_shape), \
            'Shape dimension of input %s mismatches with ' \
            'signature. %d expected but got %d.' \
            % (sig_input['data_name'],
               len(expected_shape),
               len(actual_shape))

        # Skip axis 0; compare the remaining axes pairwise.
        for expected_dim, actual_dim in zip(expected_shape[1:], actual_shape[1:]):
            if expected_dim != 0:
                assert expected_dim == actual_dim, \
                    'Input %s has different shape with ' \
                    'signature. %s expected but got %s.' \
                    % (sig_input['data_name'], expected_shape,
                       actual_shape)