-
Notifications
You must be signed in to change notification settings - Fork 115
/
netadapt_pruner.py
397 lines (311 loc) · 15.4 KB
/
netadapt_pruner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import copy
import math
import multiprocessing as mp
import os
import shutil
import sys
import typing
import torch
from tinynn.graph import modifier
from tinynn.util.util import get_logger
from tinynn.prune import OneShotChannelPruner
# `concurrent.futures.ProcessPoolExecutor` only accepts the `mp_context`, `initializer` and
# `initargs` arguments (used by `NetAdaptPruner.prune`) from Python 3.7 onwards, so older
# interpreters fall back to the `futures3` backport.
if sys.version_info < (3, 7):
    from futures3.process import ProcessPoolExecutor
else:
    from concurrent.futures import ProcessPoolExecutor

log = get_logger(__name__)
def device_init(device_ids):
    """Initializer for pool worker processes that binds the worker to one CUDA device.

    Does nothing when CUDA is unavailable. Otherwise it takes one device index from
    ``device_ids`` (a queue pre-filled with the indices ``0..n-1`` by ``NetAdaptPruner.prune``),
    so every worker claims a distinct GPU, and makes it the current CUDA device.

    Args:
        device_ids: queue-like object providing device indices through ``get()``
    """
    if not torch.cuda.is_available():
        return
    device_id = device_ids.get()
    log.info(f'Init pool process with cuda id {device_id}')
    torch.cuda.set_device(device_id)
class NetAdaptPruner(OneShotChannelPruner):
    """Iterative channel pruner following the NetAdapt scheme under a FLOPs budget.

    Each round of :meth:`prune` derives a FLOPs target from the current FLOPs,
    ``budget_reduce_rate_init`` and ``budget_reduce_rate_decay``. For every subgraph, a worker
    process searches for the smallest channel reduction that brings the model below that target,
    finetunes the pruned copy for a short while and validates it. The best-scoring candidate is
    applied to the model in the main process and checkpointed under ``netadapt_dir``.
    """

    required_params = (
        'budget_type',
        'metrics',
        'netadapt_max_iter',
        'budget_reduce_rate_init',
        'budget_reduce_rate_decay',
        'netadapt_lr',
    )
    required_context_params = ('val_loader', 'train_loader', 'train_func', 'validate_func', 'optimizer', 'criterion')
    default_values = {'netadapt_dir': 'netadapt_train/', 'netadapt_max_rounds': -1, 'netadapt_min_feature_size': 8}
    # Context items that may alternatively be supplied through the listed config keys
    # NOTE(review): resolution order is handled by the base class — confirm there.
    context_from_params_dict = {
        'optimizer': ['netadapt_optimizer', 'optimizer'],
        'criterion': ['netadapt_criterion', 'criterion'],
    }
    # Validators for config values
    condition_dict = {
        'netadapt_max_iter': lambda x: 0 < x,
        'budget_reduce_rate_init': lambda x: 0 <= x <= 1,
        'budget_reduce_rate_decay': lambda x: 0 <= x <= 1,
        'netadapt_lr': lambda x: 0 < x < 1,
        'netadapt_max_rounds': lambda x: x >= -1,
        'budget_type': lambda x: x in ('flops', 'weights', 'latency'),
    }

    # Bookkeeping entries that `prune` stores next to the per-node sparsity values in
    # `iter_{n}_sparsity.yml`; they must not be treated as sparsity when restoring.
    _metadata_keys = ('best_flops', 'iteration')

    budget: int
    budget_ratio: float
    budget_type: str
    budget_reduce_rate_init: float
    budget_reduce_rate_decay: float
    netadapt_max_iter: int
    # NOTE(review): `netadapt_lr` is validated and stored but not read anywhere in this class
    netadapt_lr: float
    netadapt_max_rounds: int
    netadapt_min_feature_size: int
    netadapt_dir: str
    init_flops: int

    def __init__(self, model, dummy_input, config, context):
        """Creates the pruner.

        Args:
            model: the model to prune
            dummy_input: example input for ``model`` (presumably used by the base class to trace it)
            config: pruner configuration, checked by :meth:`parse_config`
            context: object carrying the items named in ``required_context_params``
        """
        # Set before the base initializer runs
        # NOTE(review): presumably because base-class initialization relies on them — confirm
        self.default_sparsity = 0.0
        self.original_channels = {}
        super().__init__(model, dummy_input, config)
        self.parse_context(context)
        self.iteration = 1

        # Remember the initial size of dim 0 of each center node's weight (the output channels
        # for conv/linear weights) so `get_sparsity_state` can report sparsity w.r.t. the
        # original model.
        for node in self.center_nodes:
            self.original_channels[node.unique_name] = node.module.weight.shape[0]

    def parse_config(self):
        """Parses the config and copies the needed items to the pruner.

        Sets every required/default parameter as an attribute, resolves ``metrics`` to a
        function in ``modifier``, and computes ``self.init_flops`` and the absolute FLOPs
        ``self.budget`` from exactly one of ``budget`` (absolute) or ``budget_ratio``
        (fraction of the initial FLOPs).

        Raises:
            Exception: if ``metrics`` is unknown, ``budget``/``budget_ratio`` have a wrong type or
                value, or not exactly one of them is given
            NotImplementedError: if ``budget_type`` is not ``"flops"``
        """
        # Deliberately skips `OneShotChannelPruner.parse_config`
        # NOTE(review): presumably because this pruner has no `sparsity` entry — confirm
        super(OneShotChannelPruner, self).parse_config()

        all_param_keys = list(self.required_params) + list(self.default_values.keys())
        for param_key in all_param_keys:
            if param_key not in ['sparsity', 'metrics']:
                setattr(self, param_key, self.config[param_key])

        metrics = self.config['metrics']
        if hasattr(modifier, metrics):
            self.metric_func = getattr(modifier, metrics)
        else:
            raise Exception(f'{metrics} is not a known metrics for {type(self).__name__}')

        budget = None
        budget_ratio = None

        if self.budget_type != 'flops':
            # TODO: Implement budget type: weights, latency
            raise NotImplementedError('Only `budget_type == "flops"` is supported')

        if 'budget' in self.config:
            budget = self.config['budget']
            if not isinstance(budget, int):
                raise Exception('The type of `budget` should be int')
            if budget < 0:
                raise Exception('The value of `budget` doesn\'t meet the requirement: x >= 0')

        if 'budget_ratio' in self.config:
            budget_ratio = self.config['budget_ratio']
            if not isinstance(budget_ratio, float):
                raise Exception('The type of `budget_ratio` should be float')
            if budget_ratio < 0 or budget_ratio > 1:
                raise Exception('The value of `budget_ratio` doesn\'t meet the requirement: 0 <= x <= 1')

        # Exactly one of the two must be given
        if budget is not None and budget_ratio is not None:
            raise Exception('You should define either `budget` or `budget_ratio` for NetAdaptPruner, not both of them')
        if budget is None and budget_ratio is None:
            raise Exception('You should define either `budget` or `budget_ratio` for NetAdaptPruner')

        self.init_flops = self.calc_flops()
        if budget_ratio is not None:
            self.budget = int(budget_ratio * self.init_flops)
        else:
            self.budget = budget
        log.info(f'Global Target/Initial FLOPS: {self.budget}/{self.init_flops}')

    def generate_config(self, path: str) -> None:
        """Generates a new copy of the updated configuration at the given path"""
        super(OneShotChannelPruner, self).generate_config(path)

    def get_sparsity_state(self) -> typing.Dict[str, float]:
        """Calculates the sparsity of every center node relative to the original model.

        Returns:
            Mapping from node name to ``(original - current) / original`` of dim 0 of its weight.
        """
        sparsity = {}
        for node in self.center_nodes:
            original = self.original_channels[node.unique_name]
            sparsity[node.unique_name] = (original - node.module.weight.shape[0]) / original
        return sparsity

    def get_pruned_subgraph_info(
        self, iteration: int, subgraph_id: int, current_flops: float, target_flops: float
    ) -> typing.Tuple[float, typing.Dict[str, float], int]:
        """Prunes one subgraph so that the model FLOPs < ``target_flops``, then finetunes and validates it.

        Intended to run in a pool worker: it works on a deep copy of the model. The finetuned
        weights are saved to ``{netadapt_dir}/iter_{iteration}_subgraph_{subgraph_id}.pth``.

        Args:
            iteration: index of the current NetAdapt round
            subgraph_id: index of the subgraph to prune
            current_flops: FLOPs of the model before pruning
            target_flops: FLOPs the pruned model must stay below

        Returns:
            ``(accuracy, sparsity, new_flops)``, where accuracy is whatever ``validate_func``
            returns, or ``(-1, {}, -1)`` when no pruning plan meets the target.

        Raises:
            ValueError: if ``context.train_loader`` is empty
        """
        if torch.cuda.is_available():
            self.context.device = torch.device("cuda", torch.cuda.current_device())
        else:
            self.context.device = torch.device("cpu")

        # Copies the model, otherwise the one in the main process will be updated as well
        self.model = copy.deepcopy(self.model)
        self.model.eval()
        self.reset()

        log.info(f'Processing subgraph index {subgraph_id} at iteration: {iteration}, device: {self.context.device}')
        sparsity, new_flops = self.find_prune_plan(subgraph_id, current_flops, target_flops)
        if len(sparsity) == 0:
            # Cannot find a plan, early stop
            log.warning(f'Subgraph: {subgraph_id}, Iteration: {iteration}, cannot find a plan')
            return -1, {}, -1

        # Apply the mask to get a pruned model
        super().apply_mask()

        # Regenerate the optimizer, since the model has changed
        self.context.optimizer = type(self.context.optimizer)(
            self.model.parameters(), **self.context.optimizer.defaults
        )

        # Use fork to speed up data loading in child processes
        mp_context = mp.get_context('fork')
        self.context.train_loader.multiprocessing_context = mp_context
        self.context.val_loader.multiprocessing_context = mp_context

        num_iter_per_epoch = len(self.context.train_loader)
        if num_iter_per_epoch == 0:
            raise ValueError('`train_loader` is empty, cannot finetune the pruned model')

        # Split `netadapt_max_iter` iterations into full epochs plus a possibly shorter last one
        max_epoch = math.ceil(self.netadapt_max_iter / num_iter_per_epoch)
        max_iter_last_epoch = (self.netadapt_max_iter - 1) % num_iter_per_epoch + 1
        for epoch in range(1, max_epoch + 1):
            self.context.epoch = epoch
            if epoch == max_epoch:
                # Cap the last epoch (assumes `train_func` honours `context.max_iteration`)
                # and always restore the caller's value afterwards.
                old_max_iter = self.context.max_iteration
                self.context.max_iteration = max_iter_last_epoch
                try:
                    self.context.train_func(self.model, self.context)
                finally:
                    self.context.max_iteration = old_max_iter
            else:
                self.context.train_func(self.model, self.context)

        save_path = os.path.join(self.netadapt_dir, f'iter_{iteration}_subgraph_{subgraph_id}.pth')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        log.info("Saving model to {}".format(save_path))
        torch.save(self.model.state_dict(), save_path)

        acc = self.context.validate_func(self.model, self.context)
        self.context.best_epoch = self.context.epoch
        log.info(f'Subgraph: {subgraph_id}, Iteration: {iteration}, FLOPS: {new_flops}, Accuracy: {acc}')

        del self.model
        return acc, sparsity, new_flops

    def find_prune_plan(
        self, subgraph_id: int, pre_flops: float, target_flops: float
    ) -> typing.Tuple[typing.Dict[str, float], int]:
        """Finds the smallest channel reduction of a subgraph that gets the model FLOPs < ``target_flops``.

        Candidate widths are the current width followed by every multiple of
        ``netadapt_min_feature_size`` strictly below it. They are tried in decreasing order, and
        the first one meeting the target wins. The masks for the returned plan are left
        registered. Once FLOPs are detected to drop by a constant amount per step, further FLOPs
        values are extrapolated instead of recomputed, so the returned FLOPs may be an estimate.

        Args:
            subgraph_id: index of the subgraph to prune
            pre_flops: FLOPs of the model before pruning
            target_flops: FLOPs the pruned model must stay below

        Returns:
            ``(sparsity, flops)``: a copy of the sparsity dict and the resulting FLOPs, or
            ``({}, -1)`` when no candidate meets the target.
        """
        nodes = [
            m.node
            for m in self.graph_modifier.sub_graphs[subgraph_id]
            if m.node in self.center_nodes and m.output_modify_
        ]
        if not nodes:
            log.warning(f'Subgraph: {subgraph_id}, no prunable center node found')
            return {}, -1

        num_out_channels = nodes[0].next_tensors[0].shape[1]
        step = self.netadapt_min_feature_size

        # All possible number of channels according to `netadapt_min_feature_size`.
        # The first entry must be the current width, since the FLOPs bookkeeping below starts
        # from `pre_flops`. The following entries are the multiples of `step` strictly below it,
        # so a width that is not a multiple (e.g. 60 with step 8) tries 56 first instead of
        # skipping it.
        candidate_channels = [num_out_channels] + list(range((num_out_channels - 1) // step * step, step - 1, -step))

        # Sparsity for pruning down to each candidate after the current width
        sparsity_per_step = [(num_out_channels - c) / num_out_channels for c in candidate_channels[1:]]

        # Reset every entry of the sparsity dictionary to 0.0
        self.sparsity = self.sparsity.fromkeys(self.sparsity, 0.0)
        post_flops = pre_flops
        diff_flops = None  # constant FLOPs drop per step, once detected
        possible_diff_flops = None  # drop observed at idx 1, confirmed at idx 2
        for idx, s in enumerate(sparsity_per_step):
            for node in nodes:
                self.sparsity[node.unique_name] = s

            # Fast-forward when possible: extrapolate FLOPs and skip the expensive mask
            # registration until the estimate meets the target
            if diff_flops is not None:
                post_flops -= diff_flops
                if not post_flops < target_flops:
                    continue

            # Reset masks (none are registered yet at the first step)
            if idx != 0:
                self.graph_modifier.unregister_masker()
            self.graph_modifier.reset_masker()

            # Get the updated masks
            super().register_mask()

            # Update flops for the current graph
            if diff_flops is None:
                next_flops = self.calc_flops()

                if idx == 1:
                    # Skip further calculations if FLOPs varies proportional to output channel size
                    num_cur_flops = post_flops
                    num_next_flops = next_flops
                    if (
                        num_cur_flops % candidate_channels[idx] == 0
                        and num_next_flops % candidate_channels[idx + 1] == 0
                        and num_cur_flops // candidate_channels[idx] == num_next_flops // candidate_channels[idx + 1]
                    ):
                        diff_flops = num_cur_flops - num_next_flops
                    else:
                        possible_diff_flops = num_cur_flops - num_next_flops
                elif idx == 2:
                    # Skip further calculations if FLOPS(n) - FLOP(n-1) is constant
                    if post_flops - next_flops == possible_diff_flops:
                        diff_flops = possible_diff_flops
                        possible_diff_flops = None

                post_flops = next_flops

            log.info(
                f'Subgraph: {subgraph_id}, '
                f'Channels: {candidate_channels[idx + 1]}/{num_out_channels}, '
                f'FLOPS(pre/post/target): {pre_flops}/{post_flops}/{target_flops:.2f}'
            )

            # Early stop if we get the desired sparsity
            if post_flops < target_flops:
                return copy.deepcopy(self.sparsity), post_flops

        return {}, -1

    def prune_subgraph(self, sparsity: typing.Dict[str, float]) -> None:
        """Prunes the model with the sparsity dictionary given"""
        self.sparsity = copy.deepcopy(sparsity)
        super().prune()

    def prune(self):
        """The main function for pruning.

        Runs NetAdapt rounds until the model FLOPs are within ``self.budget`` or
        ``netadapt_max_rounds`` rounds have run (``-1`` means unlimited). Every round evaluates
        all subgraphs in a process pool (one worker per CUDA device, or a single worker on CPU),
        applies the best-scoring valid candidate to ``self.model``, and writes
        ``iter_{n}_sparsity.yml`` and ``iter_{n}.pth`` to ``netadapt_dir``. Stops early when no
        subgraph yields a valid plan.
        """
        # PyTorch forbids initialization CUDA in the default fork settings
        # So `spawn` is used here instead
        mp_context = mp.get_context('spawn')
        if torch.cuda.is_available():
            max_workers = torch.cuda.device_count()
        else:
            max_workers = 1

        # Each worker takes one device index from this queue in `device_init`
        available_cores = mp_context.Queue()
        for i in range(max_workers):
            available_cores.put(i)

        with ProcessPoolExecutor(
            max_workers=max_workers, mp_context=mp_context, initializer=device_init, initargs=(available_cores,)
        ) as pool:
            cur_flops = self.calc_flops()
            while (
                self.netadapt_max_rounds == -1 or self.iteration <= self.netadapt_max_rounds
            ) and cur_flops > self.budget:
                log.info(f"Start iteration {self.iteration}")

                # Acquire target FLOPs for the current ratio
                target_flops = cur_flops - self.budget_reduce_rate_init * cur_flops * (
                    self.budget_reduce_rate_decay ** (self.iteration - 1)
                )

                # Regenerate modifier, graph and center nodes before sending to child processes
                if self.iteration != 1:
                    self.reset()

                # Prepare jobs for child processes
                num_subgraphs = len(self.graph_modifier.sub_graphs)
                results = pool.map(
                    self.get_pruned_subgraph_info,
                    [self.iteration] * num_subgraphs,
                    range(num_subgraphs),
                    [cur_flops] * num_subgraphs,
                    [target_flops] * num_subgraphs,
                )

                # Only subgraphs that found a plan (non-empty sparsity) may be chosen, so the
                # `(-1, {}, -1)` failure sentinel can never beat or tie a real result.
                valid_results = [(idx, result) for idx, result in enumerate(results) if len(result[1]) > 0]
                if not valid_results:
                    log.error('All subgraphs yield invalid result, stopping')
                    break

                # Get the step we should take by comparing the best accuracy among all sub jobs
                max_idx, (best_acc, best_sparsity, best_flops) = max(valid_results, key=lambda x: x[1][0])

                # Sync model back to the main process
                self.prune_subgraph(best_sparsity)
                load_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}_subgraph_{max_idx}.pth')
                self.model.load_state_dict(torch.load(load_path, map_location='cpu'))

                global_sparsity = self.get_sparsity_state()
                global_sparsity['best_flops'] = best_flops
                global_sparsity['iteration'] = self.iteration

                sparsity_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}_sparsity.yml')
                super(OneShotChannelPruner, self).generate_config(sparsity_path, global_sparsity)

                save_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}.pth')
                shutil.copy(load_path, save_path)

                # Print info and update current flops
                log.info(f'Iter: {self.iteration}, Acc: {best_acc}, FLOPS: {best_flops}, Subgraph: {max_idx}')
                cur_flops = best_flops
                self.iteration += 1

    def restore(self, iteration: int) -> torch.nn.Module:
        """Restores the model checkpointed by :meth:`prune` at a specific iteration.

        Args:
            iteration: the NetAdapt round to restore

        Returns:
            The pruned model with the saved weights loaded. ``self.iteration`` is set so a
            subsequent :meth:`prune` continues with the next round.

        Raises:
            FileNotFoundError: if the sparsity yml or the weights file for ``iteration`` is missing
        """
        sparsity_path = os.path.join(self.netadapt_dir, f'iter_{iteration}_sparsity.yml')
        weights_path = os.path.join(self.netadapt_dir, f'iter_{iteration}.pth')
        for path in (sparsity_path, weights_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f'{path} is required for restoring the model')

        config = self.load_config(sparsity_path)
        # Drop the bookkeeping entries so only per-node sparsity values are applied
        sparsity = {key: value for key, value in config.items() if key not in self._metadata_keys}
        self.prune_subgraph(sparsity)
        self.model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.iteration = config['iteration'] + 1

        log.info(f"Restored from iteration {iteration}")
        return self.model