-
Notifications
You must be signed in to change notification settings - Fork 70
/
subnetmask.py
436 lines (358 loc) · 15.5 KB
/
subnetmask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
from __future__ import annotations
from copy import deepcopy
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader
import laplace
from laplace.utils.enums import Likelihood
from laplace.utils.feature_extractor import FeatureExtractor
from laplace.utils.swag import fit_diagonal_swag_var
# Public names exported by ``from <this module> import *``.
# NOTE(review): ScoreBasedSubnetMask is defined in this module but is not
# listed here, so star-imports do not expose it.
__all__ = [
"SubnetMask",
"RandomSubnetMask",
"LargestMagnitudeSubnetMask",
"LargestVarianceDiagLaplaceSubnetMask",
"LargestVarianceSWAGSubnetMask",
"ParamNameSubnetMask",
"ModuleNameSubnetMask",
"LastLayerSubnetMask",
]
class SubnetMask:
    """Baseclass for all subnetwork masks in this library (for subnetwork Laplace).

    A subnetwork is a set of positions in the flattened parameter vector
    ``torch.nn.utils.parameters_to_vector(model.parameters())``. Subclasses
    implement :meth:`get_subnet_mask`; :meth:`select` converts the returned
    binary mask into indices and stores them.

    Parameters
    ----------
    model : torch.nn.Module
    """

    def __init__(self, model: nn.Module) -> None:
        self.model: nn.Module = model
        # Flattened, autograd-detached copy of all model parameters.
        self.parameter_vector: torch.Tensor = parameters_to_vector(
            self.model.parameters()
        ).detach()
        self._n_params: int = len(self.parameter_vector)
        self._indices: torch.LongTensor | None = None
        self._n_params_subnet: int | None = None

    def _check_select(self) -> None:
        """Raise ``AttributeError`` if :meth:`select` has not been run yet."""
        if self._indices is None:
            raise AttributeError("Subnetwork mask not selected. Run select() first.")

    @property
    def _device(self) -> torch.device:
        """Device of the model's first parameter."""
        return next(self.model.parameters()).device

    @property
    def indices(self) -> torch.LongTensor:
        """Selected subnetwork indices (raises ``AttributeError`` before ``select``)."""
        self._check_select()
        return self._indices

    @property
    def n_params_subnet(self) -> int:
        """Number of parameters in the subnetwork; cached after first access."""
        if self._n_params_subnet is None:
            self._check_select()
            self._n_params_subnet = len(self._indices)
        return self._n_params_subnet

    def convert_subnet_mask_to_indices(
        self, subnet_mask: torch.Tensor
    ) -> torch.LongTensor:
        """Converts a subnetwork mask into subnetwork indices.

        Parameters
        ----------
        subnet_mask : torch.Tensor
            a binary vector of size (n_params) where 1s locate the subnetwork parameters
            within the vectorized model parameters
            (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)

        Returns
        -------
        subnet_mask_indices : torch.LongTensor
            a vector of indices of the vectorized model parameters
            (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
            that define the subnetwork

        Raises
        ------
        ValueError
            if `subnet_mask` is not a tensor, is not a 1-d integral/boolean
            tensor, has the wrong length, or contains values other than 0/1
        """
        if not isinstance(subnet_mask, torch.Tensor):
            raise ValueError("Subnetwork mask needs to be torch.Tensor!")
        allowed_dtypes = (
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        )
        if subnet_mask.dtype not in allowed_dtypes or subnet_mask.dim() != 1:
            raise ValueError(
                "Subnetwork mask needs to be 1-dimensional integral or boolean tensor!"
            )
        # Every entry must be exactly 0 or 1 (short-circuits on a length mismatch).
        if len(subnet_mask) != self._n_params or not bool(
            ((subnet_mask == 0) | (subnet_mask == 1)).all()
        ):
            raise ValueError(
                "Subnetwork mask needs to be a binary vector of "
                "size (n_params) where 1s locate the subnetwork "
                "parameters within the vectorized model parameters "
                "(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!"
            )
        return subnet_mask.nonzero(as_tuple=True)[0]

    def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor:
        """Select the subnetwork mask.

        Parameters
        ----------
        train_loader : torch.data.utils.DataLoader, default=None
            each iterate is a training batch (X, y);
            `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

        Returns
        -------
        subnet_mask_indices : torch.LongTensor
            a vector of indices of the vectorized model parameters
            (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
            that define the subnetwork

        Raises
        ------
        ValueError
            if a subnetwork has already been selected on this instance
        """
        if self._indices is not None:
            raise ValueError("Subnetwork mask already selected.")
        subnet_mask = self.get_subnet_mask(train_loader)
        self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
        return self._indices

    def get_subnet_mask(self, train_loader: DataLoader | None) -> torch.Tensor:
        """Get the subnetwork mask.

        Parameters
        ----------
        train_loader : torch.data.utils.DataLoader
            each iterate is a training batch (X, y);
            `train_loader.dataset` needs to be set to access \\(N\\), size of the data set

        Returns
        -------
        subnet_mask: torch.Tensor
            a binary vector of size (n_params) where 1s locate the subnetwork parameters
            within the vectorized model parameters
            (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
        """
        raise NotImplementedError
class ScoreBasedSubnetMask(SubnetMask):
    """Baseclass for subnetwork masks defined by selecting
    the top-scoring parameters according to some criterion.

    Subclasses implement :meth:`compute_param_scores`, returning one score per
    entry of ``parameter_vector``; the ``n_params_subnet`` highest-scoring
    entries form the subnetwork.

    Parameters
    ----------
    model : torch.nn.Module
    n_params_subnet : int
        number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)

    Raises
    ------
    ValueError
        if `n_params_subnet` is None, not positive, or larger than the model
    """

    def __init__(self, model: nn.Module, n_params_subnet: int) -> None:
        super().__init__(model)
        if n_params_subnet is None:
            raise ValueError(
                "Need to pass number of subnetwork parameters when using subnetwork Laplace."
            )
        # A non-positive size would make the `[:n]` slice below select an
        # empty set or (for negative n) "all but the last |n|" parameters.
        if n_params_subnet <= 0:
            raise ValueError(
                f"Subnetwork size needs to be positive, got {n_params_subnet}."
            )
        if n_params_subnet > self._n_params:
            raise ValueError(
                f"Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params})."
            )
        self._n_params_subnet = n_params_subnet
        self._param_scores: torch.Tensor | None = None

    def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor:
        """Return one score per entry of ``parameter_vector`` (subclass hook)."""
        raise NotImplementedError

    def _check_param_scores(self) -> None:
        """Validate that scores exist and match ``parameter_vector``'s shape."""
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if self._param_scores is None:
            raise ValueError("Parameter scores have not been computed.")
        if self._param_scores.shape != self.parameter_vector.shape:
            raise ValueError(
                "Parameter scores need to be of same shape as parameter vector."
            )

    def get_subnet_mask(self, train_loader: DataLoader | None) -> torch.Tensor:
        """Get the subnetwork mask by (descendingly) ranking parameters based on their scores.

        Scores are computed once and cached on the instance.
        """
        if self._param_scores is None:
            self._param_scores = self.compute_param_scores(train_loader)
        self._check_param_scores()
        top_idx = torch.argsort(self._param_scores, descending=True)[
            : self._n_params_subnet
        ]
        # Order of `top_idx` is irrelevant: only the resulting mask is returned.
        subnet_mask = torch.zeros_like(self.parameter_vector, dtype=torch.bool)
        subnet_mask[top_idx] = True
        return subnet_mask
class RandomSubnetMask(ScoreBasedSubnetMask):
    """Subnetwork mask of parameters sampled uniformly at random.

    Each parameter is scored with an independent draw from ``torch.rand_like``,
    so keeping the top scores picks a random subset of the requested size.
    """

    def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor:
        # The data is not needed; scores are pure noise shaped like the params.
        reference = self.parameter_vector
        scores = torch.rand_like(reference)
        return scores
class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask):
    """Subnetwork mask identifying the parameters with the largest magnitude.

    The score of each parameter is its absolute value.
    """

    def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor:
        # train_loader is unused: magnitudes depend only on the current weights.
        magnitudes = torch.abs(self.parameter_vector)
        return magnitudes
class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask):
    """Subnetwork mask identifying the parameters with the largest marginal variances
    (estimated using a diagonal Laplace approximation over all model parameters).

    The supplied diagonal Laplace model is fitted on the training data and its
    ``posterior_variance`` is used as the per-parameter score.

    Parameters
    ----------
    model : torch.nn.Module
    n_params_subnet : int
        number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
    diag_laplace_model : `laplace.baselaplace.DiagLaplace`
        diagonal Laplace model to use for variance estimation
    """

    def __init__(
        self,
        model: nn.Module,
        n_params_subnet: int,
        diag_laplace_model: laplace.baselaplace.DiagLaplace,
    ):
        super().__init__(model, n_params_subnet)
        self.diag_laplace_model: laplace.baselaplace.DiagLaplace = diag_laplace_model

    def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor:
        # Fitting needs data, so a missing loader is rejected up front.
        loader_missing = train_loader is None
        if loader_missing:
            raise ValueError("Need to pass train loader for subnet selection.")
        diag_la = self.diag_laplace_model
        diag_la.fit(train_loader)
        return diag_la.posterior_variance
class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask):
    """Subnetwork mask identifying the parameters with the largest marginal variances
    (estimated using diagonal SWAG over all model parameters).

    Scores are the per-parameter variances returned by
    ``fit_diagonal_swag_var``, trained with cross-entropy for classification
    and mean-squared error for regression.

    Parameters
    ----------
    model : torch.nn.Module
    n_params_subnet : int
        number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
    likelihood : str, default='classification'
        'classification' or 'regression'
    swag_n_snapshots : int, default=40
        number of model snapshots to collect for SWAG
    swag_snapshot_freq : int, default=1
        SWAG snapshot collection frequency (in epochs)
    swag_lr : float, default=0.01
        learning rate for SWAG snapshot collection
    """

    def __init__(
        self,
        model: nn.Module,
        n_params_subnet: int,
        likelihood: Likelihood | str = Likelihood.CLASSIFICATION,
        swag_n_snapshots: int = 40,
        swag_snapshot_freq: int = 1,
        swag_lr: float = 0.01,
    ):
        # Validate the likelihood before doing any work in the base class.
        supported = (Likelihood.CLASSIFICATION, Likelihood.REGRESSION)
        if likelihood not in supported:
            raise ValueError("Only available for classification and regression!")
        super().__init__(model, n_params_subnet)
        self.likelihood: Likelihood | str = likelihood
        self.swag_n_snapshots: int = swag_n_snapshots
        self.swag_snapshot_freq: int = swag_snapshot_freq
        self.swag_lr: float = swag_lr

    def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor:
        if train_loader is None:
            raise ValueError("Need to pass train loader for subnet selection.")
        is_classification = self.likelihood == Likelihood.CLASSIFICATION
        criterion = (
            CrossEntropyLoss(reduction="mean")
            if is_classification
            else MSELoss(reduction="mean")
        )
        return fit_diagonal_swag_var(
            self.model,
            train_loader,
            criterion,
            n_snapshots_total=self.swag_n_snapshots,
            snapshot_freq=self.swag_snapshot_freq,
            lr=self.swag_lr,
        )
class ParamNameSubnetMask(SubnetMask):
    """Subnetwork mask corresponding to the specified parameters of the neural network.

    Parameters
    ----------
    model : torch.nn.Module
    parameter_names: List[str]
        list of names of the parameters (as in `model.named_parameters()`)
        that define the subnetwork
    """

    def __init__(self, model: nn.Module, parameter_names: list[str]) -> None:
        super().__init__(model)
        self._parameter_names: list[str] = parameter_names

    def _check_param_names(self) -> None:
        """Raise ``ValueError`` if the name list is empty or names an unknown parameter."""
        if len(self._parameter_names) == 0:
            raise ValueError("Parameter name list cannot be empty.")
        model_param_names = {name for name, _ in self.model.named_parameters()}
        # dict.fromkeys de-duplicates while keeping the caller's order, so a
        # name listed twice is not misreported as missing.
        missing = [
            name
            for name in dict.fromkeys(self._parameter_names)
            if name not in model_param_names
        ]
        if len(missing) > 0:
            raise ValueError(f"Parameters {missing} do not exist in model.")

    def get_subnet_mask(self, train_loader: DataLoader | None) -> torch.Tensor:
        """Get the subnetwork mask identifying the specified parameters.

        `train_loader` is unused. The mask follows the order of
        `model.named_parameters()`, which matches `parameter_vector`.
        """
        self._check_param_names()
        selected = set(self._parameter_names)
        # Size each chunk by numel(): this also handles 0-d (scalar) and empty
        # parameters, which iterating the tensor via parameters_to_vector cannot.
        return torch.cat(
            [
                torch.full(
                    (param.numel(),),
                    name in selected,
                    dtype=torch.bool,
                    device=param.device,
                )
                for name, param in self.model.named_parameters()
            ]
        )
class ModuleNameSubnetMask(SubnetMask):
    """Subnetwork mask corresponding to the specified modules of the neural network.

    Parameters
    ----------
    model : torch.nn.Module
    module_names: List[str]
        list of names of the modules (as in `model.named_modules()`) that define the subnetwork;
        the modules cannot have children, i.e. need to be leaf modules
    """

    def __init__(self, model: nn.Module, module_names: list[str]):
        super().__init__(model)
        self._module_names: list[str] = module_names

    def _check_module_names(self) -> None:
        """Validate that every listed module exists, is a leaf, and has parameters.

        Raises
        ------
        ValueError
            if the list is empty, a listed module has children or no
            parameters, or a listed name is not found in `model.named_modules()`
        """
        if len(self._module_names) == 0:
            raise ValueError("Module name list cannot be empty.")
        # dict.fromkeys de-duplicates while keeping order, so a name listed
        # twice is not misreported as missing.
        remaining = dict.fromkeys(self._module_names)
        for name, module in self.model.named_modules():
            if name not in remaining:
                continue
            if len(list(module.children())) > 0:
                raise ValueError(
                    f'Module "{name}" has children, which is not supported.'
                )
            if len(list(module.parameters())) == 0:
                raise ValueError(f'Module "{name}" does not have any parameters.')
            del remaining[name]
        if len(remaining) > 0:
            raise ValueError(f"Modules {list(remaining)} do not exist in model.")

    def get_subnet_mask(self, train_loader: DataLoader | None) -> torch.Tensor:
        """Get the subnetwork mask identifying the specified modules.

        `train_loader` is unused. The mask is built by walking
        `model.parameters()`, the same sequence `parameter_vector` was
        flattened from. Its length therefore always matches, even when a
        non-leaf module owns parameters directly or parameters are shared.
        """
        self._check_module_names()
        wanted = set(self._module_names)
        selected_ids = {
            id(param)
            for name, module in self.model.named_modules()
            if name in wanted
            for param in module.parameters()
        }
        return torch.cat(
            [
                torch.full(
                    (param.numel(),),
                    id(param) in selected_ids,
                    dtype=torch.bool,
                    device=param.device,
                )
                for param in self.model.parameters()
            ]
        )
class LastLayerSubnetMask(ModuleNameSubnetMask):
    """Subnetwork mask corresponding to the last layer of the neural network.

    Parameters
    ----------
    model : torch.nn.Module
    last_layer_name: str, default=None
        name of the model's last layer, if None it will be determined automatically
        by a forward pass on one example from the training data
    """

    def __init__(self, model: nn.Module, last_layer_name: str | None = None):
        super().__init__(model, [])
        self._feature_extractor: FeatureExtractor = FeatureExtractor(
            self.model, last_layer_name=last_layer_name
        )

    def get_subnet_mask(self, train_loader: DataLoader | None) -> torch.Tensor:
        """Get the subnetwork mask identifying the last layer.

        A train loader is only needed when the last layer is not yet known; in
        that case the first example of the first batch is fed through the
        feature extractor to locate it.

        Raises
        ------
        ValueError
            if the last layer is unknown and `train_loader` is None
        """
        needs_search = self._feature_extractor.last_layer is None
        if needs_search and train_loader is None:
            raise ValueError("Need to pass train loader for subnet selection.")
        # NOTE(review): presumably also switches the wrapped model to eval mode.
        self._feature_extractor.eval()
        if needs_search:
            # assumes batch[0] is a sliceable input tensor — TODO confirm
            X = next(iter(train_loader))[0]
            with torch.no_grad():
                self._feature_extractor.find_last_layer(X[:1].to(self._device))
        self._module_names = [self._feature_extractor._last_layer_name]
        return super().get_subnet_mask(train_loader)