cgutils.py

import tvm.relay as r
from utils import *  # provides get_op_name, is_any_type and desc_expr used below

def simple_convish_weight(conv):
    """Return the weight argument of a plain [q]nn.conv2d / [q]nn.dense call."""
    return conv.args[1]

def conv_shape(conv):
    return f'{simple_convish_weight(conv).checked_type.shape}'

def is_conv(call):
    return get_op_name(call.op) in ['nn.conv2d', 'qnn.conv2d']

def is_dense(call):
    return get_op_name(call.op) in ['nn.dense', 'qnn.dense']

def is_simple_convish(expr):
    if not isinstance(expr, r.Call):
        return False
    return is_conv(expr) or is_dense(expr)

def is_convish(expr):
    """A convish can be a [q]{nn.conv2d, nn.dense} or a {TupleGetItem(0)+nn.batch_norm,
    nn.bias_add}-wrapped [q]{nn.conv2d, nn.dense}.
    It represents a small logical group in which the operators' parameters' shapes
    change together."""
    if is_simple_convish(expr):
        return True
    if isinstance(expr, r.TupleGetItem) and \
            expr.index == 0 and isinstance(expr.tuple_value, r.Call) and \
            get_op_name(expr.tuple_value.op) == 'nn.batch_norm' and \
            is_simple_convish(expr.tuple_value.args[0]):
        return True
    if not isinstance(expr, r.Call):
        return False
    op_name = get_op_name(expr.op)
    if op_name == 'nn.bias_add' and \
            isinstance(expr.args[0], r.Call) and is_simple_convish(expr.args[0]):
        return True
    return False
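
# Illustrative sketch, not part of the original module: a tiny Relay graph showing
# what the predicates above accept. The variable names, shapes and attributes are
# assumptions chosen only for this example.
def _demo_is_convish():
    x = r.var('x', shape=(1, 3, 8, 8))
    w = r.var('w', shape=(4, 3, 3, 3))
    b = r.var('b', shape=(4,))
    conv = r.nn.conv2d(x, w, kernel_size=(3, 3), channels=4)
    biased = r.nn.bias_add(conv, b)
    assert is_simple_convish(conv)            # a bare conv2d is a convish on its own
    assert is_convish(biased)                 # so is a bias_add-wrapped conv2d
    assert not is_convish(r.nn.relu(biased))  # relu is not part of the convish group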

def convish_components(convish):
    """Return the calls making up a convish, ordered from the core conv/dense call
    outwards, e.g. [conv2d, bias_add] or [conv2d, batch_norm, TupleGetItem]."""
    assert is_convish(convish)
    if is_simple_convish(convish):
        return [convish]
    if isinstance(convish, r.TupleGetItem):
        return [convish.tuple_value.args[0], convish.tuple_value, convish]
    op_name = get_op_name(convish.op)
    if op_name == 'nn.bias_add':
        return [convish.args[0], convish]
    raise NotImplementedError(f'{op_name} is not supported')

def is_dense_convish(convish):
    return is_dense(convish_components(convish)[0])

def is_shape_preserving(call):
    op_name = get_op_name(call.op)
    return not any(x in op_name for x in ['pool'])

def convish_weight(convish):
    return simple_convish_weight(convish_components(convish)[0])

def is_layerish(expr):
    if is_any_type(expr, r.Call, r.Var, r.Tuple, r.Constant):
        return True
    if isinstance(expr, r.TupleGetItem) and \
            expr.index == 0 and \
            is_any_type(expr.tuple_value, 'nn.batch_norm', 'nn.dropout'):
        return True
    return False

def layerish_components(layerish):
    assert is_layerish(layerish)
    if isinstance(layerish, r.TupleGetItem):
        return [layerish.tuple_value, layerish]
    return [layerish]

def layerish_core(layerish):
    return layerish_components(layerish)[0]

def layerish_parents(layerish, layerishs_only=False, pred=None):
    """Return the dataflow inputs (arguments) of the layerish's core expression,
    optionally filtered to layerishs and/or by a predicate."""
    core = layerish_core(layerish)
    if isinstance(core, r.Call):
        return [x for x in core.args if
                (not layerishs_only or is_layerish(x)) and
                (not pred or pred(x))]
    if isinstance(core, r.Tuple):
        return [x for x in core.fields if
                (not layerishs_only or is_layerish(x)) and
                (not pred or pred(x))]
    return []

class ConvishVisitor(r.ExprVisitor):
    """An ExprVisitor that calls visit_convish once per convish group while
    traversing all other nodes as usual."""

    def __init__(self, post_order=False):
        super().__init__()
        self.handled = set()  # TODO: Check correctness
        self.post_order = post_order

    def visit_convish(self, convish):
        raise NotImplementedError()

    def visit_maybe_convish(self, convish, superf):
        if convish in self.handled or not is_convish(convish):
            return superf(convish)
        for c in convish_components(convish):
            self.handled.add(c)
        if self.post_order:
            ret = superf(convish)
            self.visit_convish(convish)
            return ret
        self.visit_convish(convish)
        return superf(convish)

    def visit_call(self, call):
        return self.visit_maybe_convish(call, super().visit_call)

    def visit_tuple_getitem(self, t):
        return self.visit_maybe_convish(t, super().visit_tuple_getitem)
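
# Illustrative sketch, not part of the original module: a minimal ConvishVisitor
# subclass that collects the weight of every convish it encounters. The class name
# is an assumption for this example.
class _ConvishWeightCollector(ConvishVisitor):
    def __init__(self):
        super().__init__(post_order=True)
        self.weights = []

    def visit_convish(self, convish):
        # convish_weight handles bare, bias_add- and batch_norm-wrapped convishes
        self.weights.append(convish_weight(convish))

# Usage sketch (assuming `expr` is some Relay expression):
#   collector = _ConvishWeightCollector(); collector.visit(expr); collector.weights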

class LayerishVisitor(r.ExprVisitor):
    """An ExprVisitor that calls visit_layerish once per layerish, treating a
    TupleGetItem(0) over nn.batch_norm / nn.dropout together with the wrapped call
    as a single unit."""

    def __init__(self, post_order=False):
        super().__init__()
        # A set to allow the passthrough traversal behaviour for the first
        # n-1 components in a complex layerish.
        self.passthru = set()
        self.post_order = post_order

    def visit_layerish(self, layerish):
        raise NotImplementedError()

    def visit_maybe_layerish(self, layerish, superf):
        if layerish in self.passthru:
            return superf(layerish)
        if not is_layerish(layerish):
            raise NotImplementedError(f'{desc_expr(layerish)} is not a layerish')
        for c in layerish_components(layerish)[:-1]:
            self.passthru.add(c)
        if self.post_order:
            ret = superf(layerish)
            self.visit_layerish(layerish)
            return ret
        self.visit_layerish(layerish)
        return superf(layerish)

    def visit_call(self, call):
        return self.visit_maybe_layerish(call, super().visit_call)

    def visit_tuple_getitem(self, t):
        return self.visit_maybe_layerish(t, super().visit_tuple_getitem)

    def visit_var(self, var):
        return self.visit_maybe_layerish(var, super().visit_var)

    def visit_tuple(self, var):
        # Handcrafted Relay tuples
        return self.visit_maybe_layerish(var, super().visit_tuple)

    def visit_constant(self, const):
        return self.visit_maybe_layerish(const, super().visit_constant)

class LayerishChildrenFinder(LayerishVisitor):
    """Builds a map from each layerish to the set of layerishs that consume it
    (its dataflow children)."""

    def __init__(self):
        super().__init__()
        self.children = {}

    def visit_layerish(self, layerish):
        core = layerish_core(layerish)
        for parent in layerish_parents(layerish, layerishs_only=True):
            if parent not in self.children:
                self.children[parent] = set()
            self.children[parent].add(layerish)
        if layerish not in self.children:  # For the first visited layerish
            self.children[layerish] = set()

    def get_children_map(self):
        return self.children

    def get_children(self, layerish):
        if layerish not in self.children:
            raise KeyError(f'{desc_expr(layerish)}')
        return self.children[layerish]
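
# Illustrative sketch, not part of the original module: collect the consumer map of
# an expression. `expr` stands for any Relay expression (e.g. a Function body) and
# is an assumption of this example.
def _demo_children_map(expr):
    finder = LayerishChildrenFinder()
    finder.visit(expr)
    return finder.get_children_map()  # {layerish: set of layerishs that consume it}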

class BPVisitor(LayerishVisitor):
    """A visitor that ensures every node's children (consumers) are visited before
    the node itself. Suitable for backpropagation."""

    def __init__(self, expr):
        super().__init__()
        self.expr = expr
        self.handled = set()
        self.cf = LayerishChildrenFinder()
        self.cf.visit(self.expr)
        # [print(f'{desc_expr(x)}') for x in self.cf.get_children_map().keys()]

    def visit(self, expr):
        # A hack to disable memo_map as we want to defer the visited check
        self.memo_map = {}
        return super().visit(expr)

    def visit_maybe_layerish(self, layerish, superf):
        if layerish in self.passthru:
            return superf(layerish)
        if layerish in self.handled:
            return  # We're not a mutator so just return nothing
        children = self.cf.get_children(layerish)
        if not children.issubset(self.handled):
            return  # We're not a mutator so just return nothing
        for c in layerish_components(layerish)[:-1]:
            self.passthru.add(c)
        self.handled.add(layerish)
        self.visit_layerish(layerish)
        return superf(layerish)

    def run(self):
        self.handled.clear()
        return super().visit(self.expr)
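
# Illustrative sketch, not part of the original module: a BPVisitor subclass that
# records layerishs in the order a backward pass would process them (every consumer
# before its producers). The class name and `expr` are assumptions for this example.
class _BackwardOrderRecorder(BPVisitor):
    def __init__(self, expr):
        super().__init__(expr)
        self.order = []

    def visit_layerish(self, layerish):
        self.order.append(layerish_core(layerish))

# Usage sketch: _BackwardOrderRecorder(expr).run(); then inspect `.order`.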