ssl.py
import torch
import torch.nn as nn


def group_lasso(param_group):
    """Group LASSO penalty for one group: the L2 norm of the group's weights."""
    return torch.sqrt(torch.sum(param_group ** 2))
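
# Illustrative note (hypothetical shapes): for a 4-D conv weight `w` of shape
# (num_filters, num_channels, kH, kW), `group_lasso(w[0])` is the L2 norm of
# the first filter and `group_lasso(w[:, 0])` the L2 norm of the first channel
# group; these are exactly the groupings the losses below iterate over.
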
def cross_entropy_loss(outputs, targets):
    """Cross-entropy loss between model outputs and targets."""
    ce_loss_func = nn.CrossEntropyLoss()
    return ce_loss_func(outputs, targets)
def filter_and_channel_wise_ssl_loss(model, outputs, targets, params):
    """
    Penalize unimportant filters and channels.

    Params:
    - model: pretrained model
    - outputs: output tensor produced by the model's forward pass
    - targets: target tensor corresponding to the input (for the CE loss)
    - params: hyperparameter object given in `params.py`. Its `ssl_hyperparams`
      dict must contain entries for the weight decay, lambda_n (filter-wise
      group LASSO), and lambda_c (channel-wise group LASSO)

    Return:
    - Loss given by Eqn. 2 of "Learning Structured Sparsity in Deep Neural
      Networks"
    """
    # Compute cross-entropy loss
    ce_loss = cross_entropy_loss(outputs, targets)

    # Loss accumulators
    wgt_l2_norm = torch.Tensor([0.])
    filter_wise_loss = torch.Tensor([0.])
    channel_wise_loss = torch.Tensor([0.])
    if params.use_gpu:
        wgt_l2_norm = wgt_l2_norm.cuda()
        filter_wise_loss = filter_wise_loss.cuda()
        channel_wise_loss = channel_wise_loss.cuda()

    # Coefficient hyperparameters
    hyperparams = params.ssl_hyperparams
    wgt_decay = hyperparams["wgt_decay"]
    lambda_n = hyperparams["lambda_n"]
    lambda_c = hyperparams["lambda_c"]

    # Iterate over every parameter tensor in the model
    for param in model.parameters():
        # L2 norm of the whole tensor (non-structured regularization term)
        wgt_l2_norm += torch.norm(param)
        # Only 4-D conv weights get structured penalties; skip linear/bias params
        if len(param.size()) != 4:
            continue
        num_filters, num_channels = param.size()[0], param.size()[1]
        # Group LASSO over the filters of the current layer
        for filter_idx in range(num_filters):
            filter_wise_loss += group_lasso(param[filter_idx, :, :, :])
        # Group LASSO over the channels of the current layer
        for channel_idx in range(num_channels):
            channel_wise_loss += group_lasso(param[:, channel_idx, :, :])

    return (ce_loss + (wgt_decay * wgt_l2_norm)
            + (lambda_n * filter_wise_loss)
            + (lambda_c * channel_wise_loss))
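
# A quick worked example (hypothetical layer): for a single Conv2d(3, 8, 3)
# weight of shape (8, 3, 3, 3), the loop above forms 8 filter groups of shape
# (3, 3, 3) and 3 channel groups of shape (8, 3, 3), so the structured term
# adds 8 + 3 = 11 group norms to the objective.
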
def shape_fiber_ssl_loss(model, outputs, targets, params):
    """
    Learn the shapes of filters.

    Params:
    - model: pretrained model
    - outputs: output tensor produced by the model's forward pass
    - targets: target tensor corresponding to the input (for the CE loss)
    - params: hyperparameter object given in `params.py`. Its `ssl_hyperparams`
      dict must contain entries for the weight decay and lambda_s (shape-wise
      group LASSO)

    Return:
    - Loss given by Eqn. 3 of "Learning Structured Sparsity in Deep Neural
      Networks"
    """
    # Compute cross-entropy loss
    ce_loss = cross_entropy_loss(outputs, targets)

    # Loss accumulators
    wgt_l2_norm = torch.Tensor([0.])
    shape_wise_loss = torch.Tensor([0.])
    if params.use_gpu:
        wgt_l2_norm = wgt_l2_norm.cuda()
        shape_wise_loss = shape_wise_loss.cuda()

    # Coefficient hyperparameters
    hyperparams = params.ssl_hyperparams
    wgt_decay = hyperparams["wgt_decay"]
    lambda_s = hyperparams["lambda_s"]

    for param in model.parameters():
        # L2 norm of the whole tensor (non-structured regularization term)
        wgt_l2_norm += torch.norm(param)
        # Only 4-D conv weights get structured penalties; skip linear/bias params
        if len(param.size()) != 4:
            continue
        # Group LASSO over shape fibers: one group per (channel, height, width)
        # position, spanning all filters of the layer
        num_channels = param.size()[1]
        height = param.size()[2]
        width = param.size()[3]
        for channel_idx in range(num_channels):
            for height_idx in range(height):
                for width_idx in range(width):
                    shape_wise_loss += group_lasso(
                        param[:, channel_idx, height_idx, width_idx])

    return ce_loss + (wgt_decay * wgt_l2_norm) + (lambda_s * shape_wise_loss)
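

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original training pipeline. The
# SimpleNamespace below stands in for the real `params` object from
# `params.py`; the model, hyperparameter values, and data are hypothetical
# placeholders chosen only to make both losses runnable end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    params = SimpleNamespace(
        use_gpu=False,
        ssl_hyperparams={
            "wgt_decay": 1e-4,  # assumed value
            "lambda_n": 1e-5,   # assumed value
            "lambda_c": 1e-5,   # assumed value
            "lambda_s": 1e-5,   # assumed value
        },
    )

    # Toy model: one conv layer (penalized) and one linear layer (skipped)
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 30 * 30, 10),
    )
    inputs = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 10, (4,))

    outputs = model(inputs)
    loss = filter_and_channel_wise_ssl_loss(model, outputs, targets, params)
    loss.backward()
    print("filter/channel-wise SSL loss:", loss.item())

    outputs = model(inputs)
    loss = shape_fiber_ssl_loss(model, outputs, targets, params)
    loss.backward()
    print("shape-wise SSL loss:", loss.item())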