forked from eladhoffer/quantized.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathchannel_selection.py
94 lines (80 loc) · 4.31 KB
/
channel_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch as th
import numpy as np
import torch.nn.functional as F
class ChannelSelect:
    """Callable that keeps only a fixed set of indices along dimension 1.

    Calling an instance with ``x`` returns ``x[:, ids]`` for the ``ids``
    given at construction time.
    """

    def __init__(self, ids):
        # Index object used on dimension 1; stored exactly as passed in.
        self._ids = ids

    def __call__(self, x):
        index = (slice(None), self._ids)
        return x[index]
def _extractNormalizedQuants(layer_name, tracker_dict_per_class):
layer_quants = []
min_len_perc = None
for class_id in range(0, len(tracker_dict_per_class)):
percentiles, quantiles = tracker_dict_per_class[class_id][layer_name].get_distribution_histogram()
layer_quants.append(quantiles)
if min_len_perc is None or len(percentiles) < min_len_perc:
min_len_perc = len(percentiles)
for id, q in enumerate(layer_quants):
num_p = len(q)
if num_p > min_len_perc:
slice_ = (num_p - min_len_perc) // 2
layer_quants[id] = layer_quants[id][slice_:-slice_]
quants = th.stack(layer_quants, -1)
quants_norm = quants
quants_norm = F.normalize(quants_norm, dim=[0, 2], p=2)
return quants_norm
def output_select_channels_class_dependent(tracker_dict_per_class):
    """Map each class to its own index on every layer whose name contains 'output'.

    For class ``c`` and every such layer, the selection is the single-element
    list ``[c]``.  The tracker's ``mean`` for that layer must have more than
    ``c`` entries along dimension 0 (checked with an ``assert``).

    Args:
        tracker_dict_per_class: Indexable sequence (one entry per class) of
            dicts mapping layer names to tracker objects exposing ``.mean``.

    Returns:
        List with one entry per class.  Each entry is a dict
        ``{layer_name: [class_id]}``, or ``None`` when no layer name contains
        'output'.
    """
    num_classes = len(tracker_dict_per_class)
    selection = [None] * num_classes
    output_layers = [name for name in tracker_dict_per_class[0].keys() if 'output' in name]
    for name in output_layers:
        for cls in range(num_classes):
            if selection[cls] is None:
                selection[cls] = {}
            width = tracker_dict_per_class[cls][name].mean.shape[0]
            assert width > cls
            selection[cls][name] = [cls]
    return selection
def find_most_seperable_channels_class_dependent(tracker_dict_per_class, relative_cut=0.05):
    """For each class and layer, pick the channels that best separate it from the others.

    For every layer, ``_extractNormalizedQuants`` produces normalized
    quantiles with classes on the last axis.  For each "base" class, each
    channel is scored as the squared difference between the base class's
    quantiles and every class's quantiles, summed over classes (axis 2) and
    quantile rows (axis 0).  Axis 1 is treated as the channel axis.  The
    ``ceil(num_channels * relative_cut)`` highest-scoring channels are kept.

    Args:
        tracker_dict_per_class: Indexable sequence (one entry per class) of
            dicts mapping layer names to trackers.
        relative_cut: Fraction of each layer's channels to keep per class.

    Returns:
        List with one entry per class.  Each entry is a dict mapping layer
        name to a tensor of channel indices, in ascending order of score.
    """
    num_classes = len(tracker_dict_per_class)
    all_class_channel_dict = [None] * num_classes
    for layer_name in tracker_dict_per_class[0].keys():
        quants_norm = _extractNormalizedQuants(layer_name, tracker_dict_per_class)
        for base_class in range(num_classes):
            # Broadcast the base class's quantiles (trailing dim 1) against all classes.
            diff = quants_norm - quants_norm[:, :, base_class].unsqueeze(-1)
            var_per_channel = (diff ** 2).sum(2).sum(0)
            _, ranked_channels = th.sort(var_per_channel)  # ascending score
            nchannels = int(np.ceil(len(ranked_channels) * relative_cut))
            # Take the top-`nchannels` tail.  The previous slice
            # [-(n + 1):-1] silently dropped the highest-scoring channel and
            # returned nothing for single-channel layers.
            start = max(len(ranked_channels) - nchannels, 0)
            if all_class_channel_dict[base_class] is None:
                all_class_channel_dict[base_class] = {}
            all_class_channel_dict[base_class][layer_name] = ranked_channels[start:]
    return all_class_channel_dict
def find_most_seperable_channels(tracker_dict_per_class, max_channels_per_class=5):
    """Union, over all classes, of each class's most separating channels per layer.

    For every layer, normalized per-class quantiles come from
    ``_extractNormalizedQuants`` (classes on the last axis).  For each base
    class, channels (axis 1) are scored by the summed squared difference
    between the base class's quantiles and every class's quantiles.  The top
    ``max_channels_per_class`` channels per base class are collected, and
    their union is returned per layer.

    Args:
        tracker_dict_per_class: Indexable sequence (one entry per class) of
            dicts mapping layer names to trackers.
        max_channels_per_class: Number of top-scoring channels taken for
            each base class.

    Returns:
        Dict mapping each layer name (in the order of
        ``tracker_dict_per_class[0]``) to a sorted tensor of unique channel
        indices (``th.unique``).
    """
    num_classes = len(tracker_dict_per_class)
    layer_dict = {}
    for layer_name in tracker_dict_per_class[0].keys():
        quants_norm = _extractNormalizedQuants(layer_name, tracker_dict_per_class)
        chosen_channels = []
        for base_class in range(num_classes):
            diff = quants_norm - quants_norm[:, :, base_class].unsqueeze(-1)
            var_per_channel = (diff ** 2).sum(2).sum(0)
            _, ranked_channels = th.sort(var_per_channel)  # ascending score
            # Top-k tail.  The previous [-(k + 1):-1] slice excluded the
            # single most separating channel.
            start = max(len(ranked_channels) - max_channels_per_class, 0)
            chosen_channels.append(ranked_channels[start:])
        layer_dict[layer_name] = th.unique(th.cat(chosen_channels))
    return layer_dict
def sample_random_channels(tracker_dict_per_class, relative_cut=0.05, seed=0):
    """Pick a seeded random subset of channel indices for every tracked layer.

    Layers are taken from ``tracker_dict_per_class[0]``.  For each one, the
    channel count is ``tracker.mean.shape[0]``, and the first
    ``ceil(count * relative_cut)`` entries of a random permutation are kept.
    One generator seeded with ``seed`` is shared across layers in dict order,
    so results are reproducible.

    Returns:
        Dict mapping layer name to an index tensor on the same device as
        that layer's ``mean``.  Asserts that at least one channel was picked.
    """
    rng = th.Generator().manual_seed(seed)
    selection = {}
    for layer_name, tracker in tracker_dict_per_class[0].items():
        total = tracker.mean.shape[0]
        keep = np.ceil(total * relative_cut).astype(np.int32)
        picked = th.randperm(total, generator=rng)[:keep]
        assert len(picked) >= 1, picked
        selection[layer_name] = picked.to(tracker.mean.device)
    return selection