-
Notifications
You must be signed in to change notification settings - Fork 863
/
util.py
175 lines (138 loc) · 4.72 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Utility functions for TorchServe
"""
import enum
import inspect
import itertools
import json
import logging
import os
import re
from functools import wraps
from warnings import simplefilter, warn
# Turn on PendingDeprecationWarning (ignored by Python's default filters) and
# DeprecationWarning, showing only the first occurrence of each distinct
# warning ("once" action, regardless of call location). This makes warnings
# raised through ``deprecated`` (PendingDeprecationWarning by default) visible.
# NOTE(review): this runs at import time and edits the interpreter-wide
# warnings filter list, not just this module's warnings.
# One call per category: ``simplefilter``'s documented ``category`` argument is
# a single Warning subclass; a tuple only worked because matching happens to
# use issubclass().
simplefilter("once", category=PendingDeprecationWarning)
simplefilter("once", category=DeprecationWarning)
import yaml
class PT2Backend(str, enum.Enum):
    """
    Names of the PT2 compiler backends accepted by this module.

    Mixing in ``str`` makes every member a ``str`` instance, so members compare
    equal to their plain string values (``PT2Backend.INDUCTOR == "inductor"``).
    ``check_valid_pt2_backend`` validates a backend name against these values.

    NOTE(review): presumably these mirror backend names understood by
    ``torch.compile`` / TorchDynamo; some of them (e.g. nvfuser, ofi, fx2trt)
    may no longer be available in recent PyTorch releases -- confirm against
    the torch version in use.
    """

    EAGER = "eager"
    AOT_EAGER = "aot_eager"
    INDUCTOR = "inductor"
    NVFUSER = "nvfuser"
    AOT_NVFUSER = "aot_nvfuser"
    AOT_CUDAGRAPHS = "aot_cudagraphs"
    OFI = "ofi"
    FX2TRT = "fx2trt"
    ONNXRT = "onnxrt"
    IPEX = "ipex"
    TORCHXLA_TRACE_ONCE = "torchxla_trace_once"
# Logger shared by the helpers in this module, named after the module.
logger = logging.getLogger(name=__name__)

# Matches, for removal: any "<...>" tag (non-greedy, so one tag at a time) and
# "&...;" references -- named entities (lowercase letters/digits, e.g. "&amp;"),
# decimal references ("&#39;") and lowercase-hex references ("&#x27;").
CLEANUP_REGEX = re.compile(
    "<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});"
)
def list_classes_from_module(module, parent_class=None):
    """
    Collect the classes that are defined directly in ``module``.

    Classes that were merely imported into ``module`` are skipped: a member is
    kept only when its ``__module__`` equals ``module.__name__``.

    :param module: module object to scan
    :param parent_class: optional base class; when given, only classes for
        which ``issubclass(cls, parent_class)`` holds are returned (this
        includes ``parent_class`` itself if ``module`` defines it)
    :return: list of class objects, in the name-sorted order produced by
        ``inspect.getmembers``
    """

    def _defined_here(member):
        return inspect.isclass(member) and member.__module__ == module.__name__

    found = []
    for _name, klass in inspect.getmembers(module, _defined_here):
        if parent_class is None or issubclass(klass, parent_class):
            found.append(klass)
    return found
def check_valid_pt2_backend(backend: str) -> bool:
    """
    Report whether ``backend`` equals the value of some ``PT2Backend`` member.

    An unsupported name is logged as a warning before ``False`` is returned.

    :param backend: backend name to validate
    :return: True if the name is supported, False otherwise
    """
    supported = any(member.value == backend for member in PT2Backend)
    if not supported:
        logger.warning(f"{backend} is not a supported backend")
    return supported
def load_label_mapping(mapping_file_path):
    """
    Load a JSON mapping { class ID -> friendly class name }.
    Used in BaseHandler.

    Accepted layouts:

    * ``{"<id>": "<label>", ...}`` -- returned with the labels unchanged.
    * ``{"<id>": [..., "<label>"], ...}`` -- the last list element becomes
      the label.
    * Legacy ``{"object_type_names": ["<label0>", "<label1>", ...]}`` --
      converted to ``{"0": "<label0>", "1": "<label1>", ...}``; the list
      entries are returned as-is, without validation.

    :param mapping_file_path: path of the JSON file to read
    :return: the mapping dict, or None (after logging a warning) when
        ``mapping_file_path`` is not an existing file
    :raises Exception: if the JSON top level is not an object, or a label is
        neither a str nor a non-empty list whose last element is a str
    """
    if not os.path.isfile(mapping_file_path):
        logger.warning(
            f"{mapping_file_path!r} is missing. Inference output will not include class name."
        )
        return None

    # Read raw bytes and let json detect the encoding (UTF-8/16/32, optional
    # UTF-8 BOM) rather than decoding with the platform's locale encoding,
    # which can mangle or reject non-ASCII labels.
    with open(mapping_file_path, "rb") as f:
        mapping = json.load(f)

    if not isinstance(mapping, dict):
        raise Exception(
            'index->name JSON mapping should be in "class": "label" format.'
        )

    # Older examples had a different syntax than others. This code accommodates those.
    legacy_names = mapping.get("object_type_names")
    if isinstance(legacy_names, list):
        return {str(k): v for k, v in enumerate(legacy_names)}

    for key, value in mapping.items():
        # List values keep only their last element. An empty list is left
        # unchanged so it fails the str check below with the intended message
        # instead of an IndexError.
        label = value[-1] if isinstance(value, list) and value else value
        if not isinstance(label, str):
            raise Exception(
                "labels in index->name mapping must be either str or List[str]"
            )
        # Overwriting an existing key leaves the dict's size unchanged, which
        # is allowed while iterating over it.
        mapping[key] = label
    return mapping
def map_class_to_label(probs, mapping=None, lbl_classes=None):
    """
    Given a list of classes & probabilities, return a dictionary of
    { friendly class name -> probability }

    :param probs: list of rows (one per input); each row is a sequence of
        probabilities
    :param mapping: optional dict from class-ID string to label; when None,
        the class ID converted with ``str`` is used as the key
    :param lbl_classes: optional rows of class IDs aligned with ``probs``;
        when None, row ``i`` uses IDs ``0 .. len(probs[i]) - 1``
    :return: list with one ``{label: probability}`` dict per row
        (an empty list for an empty ``probs``)
    :raises Exception: if ``probs`` is not a list or ``mapping`` is not a dict
    :raises KeyError: if ``mapping`` is given and lacks a class ID
    """
    if not isinstance(probs, list):
        raise Exception("Convert classes to list before doing mapping")
    if mapping is not None and not isinstance(mapping, dict):
        raise Exception("Mapping must be a dict")

    if lbl_classes is None:
        # Positional IDs derived per row: an empty batch works (no probs[0]
        # access) and rows of different lengths are not truncated to the
        # length of the first row.
        lbl_classes = [range(len(row)) for row in probs]

    def _label(lbl_class):
        key = str(lbl_class)
        return mapping[key] if mapping is not None else key

    return [
        {_label(lbl_class): prob for lbl_class, prob in zip(classes, row)}
        for classes, row in zip(lbl_classes, probs)
    ]
def get_yaml_config(yaml_file_path):
    """
    Parse a YAML file with ``yaml.safe_load``.

    :param yaml_file_path: path of the YAML file to read
    :return: the parsed document, or ``{}`` when ``safe_load`` yields None
        (an empty file, only comments, or an explicit null document)
    """
    # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16
    # marked by a BOM) instead of decoding with the platform locale encoding.
    with open(yaml_file_path, "rb") as file:
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document; return an empty dict so
    # callers can rely on dict methods.
    return {} if config is None else config
class PredictionException(Exception):
    """
    Exception that carries a message together with an error code.

    ``str()`` of an instance is ``"<message> : <error_code>"``.

    :param message: description of the failure; also passed to ``Exception``
        as its only argument
    :param error_code: code stored as ``error_code`` (defaults to 500;
        presumably an HTTP-style status -- confirm with callers)
    """

    def __init__(self, message, error_code=500):
        super().__init__(message)
        self.error_code = error_code
        self.message = message

    def __str__(self):
        return f"{self.message} : {self.error_code}"
def deprecated(version, replacement="", klass=PendingDeprecationWarning):
    """This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used.

    The warning text is ``"<name> is deprecated in <version>"``, followed by
    ``" and moved to <replacement>"`` when a replacement is given.

    Args:
        version: The version reported as "deprecated in ..." in the warning.
        replacement: The replacement function, if any. When empty, the
            "and moved to ..." clause is omitted from the message.
        klass: The category of warning
    Returns:
        A decorator that wraps a function so each call emits the warning and
        then delegates to the original function.
    """

    def deprecator(func):
        # The message depends only on decoration-time values; build it once.
        message = f"{func.__name__} is deprecated in {version}"
        if replacement:
            message += f" and moved to {replacement}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the code that called the
            # deprecated function rather than to this wrapper.
            warn(message, klass, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return deprecator