-
Notifications
You must be signed in to change notification settings - Fork 72
/
curvature.py
503 lines (433 loc) · 17.8 KB
/
curvature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
from __future__ import annotations
from typing import Any, Callable, MutableMapping
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from laplace.utils import Kron, Likelihood
class CurvatureInterface:
    """Interface to access curvature for a model and corresponding likelihood.

    A curvature backend must inherit from this base class and implement the
    necessary functions `jacobians`, `full`, `kron`, and `diag`.
    The interface might be extended in the future to account for other curvature
    structures, for example, a block-diagonal one.

    Parameters
    ----------
    model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
        torch model (neural network)
    likelihood : {'classification', 'regression'}
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.LongTensor, default=None
        indices of the vectorized model parameters that define the subnetwork
        to apply the Laplace approximation over
    dict_key_x: str, default='input_ids'
        The dictionary key under which the input tensor `x` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.
    dict_key_y: str, default='labels'
        The dictionary key under which the target tensor `y` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.

    Attributes
    ----------
    lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
    factor : float
        conversion factor between torch losses and base likelihoods
        For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss.
    """

    def __init__(
        self,
        model: nn.Module,
        likelihood: Likelihood | str,
        last_layer: bool = False,
        subnetwork_indices: torch.LongTensor | None = None,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
    ):
        assert likelihood in [Likelihood.REGRESSION, Likelihood.CLASSIFICATION]
        self.likelihood: Likelihood | str = likelihood
        self.model: nn.Module = model
        self.last_layer: bool = last_layer
        self.subnetwork_indices: torch.LongTensor | None = subnetwork_indices
        self.dict_key_x = dict_key_x
        self.dict_key_y = dict_key_y

        # Summed torch losses; `factor` rescales them to the base likelihood
        # (MSELoss sums (f - y)^2, i.e. twice the N(f, 1) negative log-lik up to a constant).
        is_regression = likelihood == "regression"
        self.lossfunc: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (
            MSELoss(reduction="sum")
            if is_regression
            else CrossEntropyLoss(reduction="sum")
        )
        self.factor: float = 0.5 if is_regression else 1.0

        # Only trainable parameters of the (possibly last-layer-only) model take part.
        self.params: list[nn.Parameter] = [
            p for p in self._model.parameters() if p.requires_grad
        ]
        self.params_dict: dict[str, nn.Parameter] = {
            k: v for k, v in self._model.named_parameters() if v.requires_grad
        }
        self.buffers_dict: dict[str, torch.Tensor] = dict(self.model.named_buffers())

    @property
    def _model(self) -> nn.Module:
        """The module whose parameters are considered: the last layer or the full model."""
        return self.model.last_layer if self.last_layer else self.model

    def jacobians(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        enable_backprop: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
        via torch.func.

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)` on compatible device with model.
        enable_backprop : bool, default = False
            whether to enable backprop through the Js and f w.r.t. x

        Returns
        -------
        Js : torch.Tensor
            Jacobians `(batch, outputs, parameters)`
        f : torch.Tensor
            output function `(batch, outputs)`
        """

        def model_fn_params_only(params_dict, buffers_dict):
            out = torch.func.functional_call(self.model, (params_dict, buffers_dict), x)
            # Returned twice: differentiated output and auxiliary value, so that f
            # comes from the same forward pass as the Jacobian.
            return out, out

        Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
            self.params_dict, self.buffers_dict
        )

        # Each per-parameter Jacobian has the output dims followed by p's dims.
        # Collapse p's dims into one axis of size p.numel(); unlike
        # flatten(start_dim=-p.dim()) this also works for 0-dim parameters
        # (start_dim=-0 == 0 would flatten the batch/output dims as well).
        flat_Js = []
        for name, p in self.params_dict.items():
            j = Js[name]
            lead_shape = j.shape[: j.dim() - p.dim()]
            flat_Js.append(j.reshape(*lead_shape, p.numel()))
        Js = torch.cat(flat_Js, dim=-1)

        if self.subnetwork_indices is not None:
            # Parameters live on the last axis.
            Js = Js[..., self.subnetwork_indices]

        return (Js, f) if enable_backprop else (Js.detach(), f.detach())

    def last_layer_jacobians(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        enable_backprop: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
        only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

        Parameters
        ----------
        x : torch.Tensor
        enable_backprop : bool, default=False

        Returns
        -------
        Js : torch.Tensor
            Jacobians `(batch, outputs, last-layer-parameters)`
        f : torch.Tensor
            output function `(batch, outputs)`
        """
        f, phi = self.model.forward_with_features(x)
        bsize = phi.shape[0]
        output_size = f.numel() // bsize

        # Build the identity in phi's dtype/device so Js keeps the model's precision.
        identity = (
            torch.eye(output_size, device=phi.device, dtype=phi.dtype)
            .unsqueeze(0)
            .tile(bsize, 1, 1)
        )
        # Weight Jacobian is phi (x) I, laid out as batch x output x params.
        Js = torch.einsum("kp,kij->kijp", phi, identity).reshape(bsize, output_size, -1)
        if self.model.last_layer.bias is not None:
            # The bias Jacobian is the identity.
            Js = torch.cat([Js, identity], dim=2)

        return (Js, f) if enable_backprop else (Js.detach(), f.detach())

    def gradients(
        self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta), y)\\) at
        current parameter \\(\\theta\\).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)` on compatible device with model.
        y : torch.Tensor

        Returns
        -------
        Gs : torch.Tensor
            gradients `(batch, parameters)`
        loss : torch.Tensor
        """

        def loss_single(x, y, params_dict, buffers_dict):
            """Compute the loss for a single sample (returned also as aux)."""
            x, y = x.unsqueeze(0), y.unsqueeze(0)  # vmap removes the batch dimension
            output = torch.func.functional_call(
                self.model, (params_dict, buffers_dict), x
            )
            loss = torch.func.functional_call(self.lossfunc, {}, (output, y))
            return loss, loss

        grad_fn = torch.func.grad(loss_single, argnums=2, has_aux=True)
        batch_grad_fn = torch.func.vmap(grad_fn, in_dims=(0, 0, None, None))
        batch_grad, batch_loss = batch_grad_fn(
            x, y, self.params_dict, self.buffers_dict
        )

        # Per-sample grads have shape (batch, *p.shape); reshape with p.numel()
        # so 0-dim parameters become (batch, 1) instead of failing in flatten.
        Gs = torch.cat(
            [
                batch_grad[name].reshape(batch_grad[name].shape[0], p.numel())
                for name, p in self.params_dict.items()
            ],
            dim=1,
        )
        if self.subnetwork_indices is not None:
            Gs = Gs[:, self.subnetwork_indices]

        loss = batch_loss.sum(0)
        return Gs, loss

    def full(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ):
        """Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix
        \\(H\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
        H : torch.Tensor
            Hessian approximation `(parameters, parameters)`
        """
        raise NotImplementedError

    def kron(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        N: int,
        **kwargs: dict[str, Any],
    ) -> tuple[torch.Tensor, Kron]:
        """Compute a Kronecker factored curvature approximation (such as KFAC).

        The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q, H\\),
        i.e., \\(H \\approx Q \\otimes H\\) for each Module in the neural network permitting
        such curvature.
        \\(Q\\) is quadratic in the input-dimension of a module \\(p_{in} \\times p_{in}\\)
        and \\(H\\) in the output-dimension \\(p_{out} \\times p_{out}\\).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`
        N : int
            total number of data points

        Returns
        -------
        loss : torch.Tensor
        H : `laplace.utils.matrix.Kron`
            Kronecker factored Hessian approximation.
        """
        raise NotImplementedError

    def diag(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ):
        """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
        vector of the dimensionality of parameters \\(\\theta\\).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
        H : torch.Tensor
            vector representing the diagonal of H
        """
        raise NotImplementedError

    # Backward-compatible alias.
    functorch_jacobians = jacobians
class GGNInterface(CurvatureInterface):
    """Generalized Gauss-Newton or Fisher Curvature Interface.

    The GGN is equal to the Fisher information for the available likelihoods.
    In addition to `CurvatureInterface`, methods for Jacobians are required by subclasses.

    Parameters
    ----------
    model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
        torch model (neural network)
    likelihood : {'classification', 'regression'}
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.Tensor, default=None
        indices of the vectorized model parameters that define the subnetwork
        to apply the Laplace approximation over
    dict_key_x: str, default='input_ids'
        The dictionary key under which the input tensor `x` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.
    dict_key_y: str, default='labels'
        The dictionary key under which the target tensor `y` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.
    stochastic : bool, default=False
        Fisher if stochastic else GGN
    num_samples: int, default=1
        Number of samples used to approximate the stochastic Fisher
    """

    def __init__(
        self,
        model: nn.Module,
        likelihood: Likelihood | str,
        last_layer: bool = False,
        subnetwork_indices: torch.LongTensor | None = None,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
        stochastic: bool = False,
        num_samples: int = 1,
    ) -> None:
        self.stochastic: bool = stochastic
        self.num_samples: int = num_samples

        super().__init__(
            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
        )

    def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:
        """Approximate the Fisher's middle matrix (expected outer product of the functional gradient)
        using MC integral with `self.num_samples` many samples.

        Parameters
        ----------
        f : torch.Tensor
            model outputs; the einsum below treats them as `(batch, outputs)`.

        Returns
        -------
        F : torch.Tensor
            `(batch, outputs, outputs)` averaged outer products of sampled
            functional gradients.
        """
        F = 0

        for _ in range(self.num_samples):
            if self.likelihood == "regression":
                # y ~ N(y | f, 1); randn_like keeps f's dtype and device so F matches Js.
                y_sample = f + torch.randn_like(f)
                grad_sample = f - y_sample  # functional MSE grad
            else:  # classification with softmax
                # Multinomial with the default total_count=1 yields one-hot samples.
                y_sample = torch.distributions.Multinomial(logits=f).sample()
                # First functional derivative of the negative loglik is p - y
                p = torch.softmax(f, dim=-1)
                grad_sample = p - y_sample

            F += (
                1
                / self.num_samples
                * torch.einsum("bc,bk->bck", grad_sample, grad_sample)
            )

        return F

    def _get_functional_hessian(self, f: torch.Tensor) -> torch.Tensor | None:
        """Hessian of the negative log-likelihood w.r.t. the outputs `f`.

        Returns None for regression, where it is the identity and callers use
        the cheaper J^T J contraction instead.
        """
        if self.likelihood == "regression":
            return None
        else:
            # Second derivative of the softmax-CE negative log lik is diag(p) - pp^T
            ps = torch.softmax(f, dim=-1)
            G = torch.diag_embed(ps) - torch.einsum("mk,mc->mck", ps, ps)
            return G

    def full(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation
        \\(H_{ggn}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\).
        For last-layer, reduced to \\(\\theta_{last}\\)

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
        H : torch.Tensor
            GGN `(parameters, parameters)`
        """
        Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
        H_lik = (
            self._get_mc_functional_fisher(f)
            if self.stochastic
            else self._get_functional_hessian(f)
        )

        if H_lik is not None:
            # sum_b J_b^T H_b J_b
            H = torch.einsum("bcp,bck,bkq->pq", Js, H_lik, Js)
        else:  # The case of exact GGN for regression
            H = torch.einsum("bcp,bcq->pq", Js, Js)
        loss = self.factor * self.lossfunc(f, y)

        return loss.detach(), H.detach()

    def diag(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the diagonal of the GGN (or MC Fisher when `stochastic`).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
        H : torch.Tensor
            vector `(parameters,)` holding the diagonal of the GGN
        """
        Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
        loss = self.factor * self.lossfunc(f, y)

        H_lik = (
            self._get_mc_functional_fisher(f)
            if self.stochastic
            else self._get_functional_hessian(f)
        )

        if H_lik is not None:
            # Only the p == q entries of sum_b J_b^T H_b J_b
            H = torch.einsum("bcp,bck,bkp->p", Js, H_lik, Js)
        else:  # The case of exact GGN for regression
            H = torch.einsum("bcp,bcp->p", Js, Js)

        return loss.detach(), H.detach()
class EFInterface(CurvatureInterface):
    """Interface for Empirical Fisher as Hessian approximation.

    In addition to `CurvatureInterface`, methods for gradients are required by subclasses.

    Parameters
    ----------
    model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
        torch model (neural network)
    likelihood : {'classification', 'regression'}
    last_layer : bool, default=False
        only consider curvature of last layer
    subnetwork_indices : torch.Tensor, default=None
        indices of the vectorized model parameters that define the subnetwork
        to apply the Laplace approximation over
    dict_key_x: str, default='input_ids'
        The dictionary key under which the input tensor `x` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.
    dict_key_y: str, default='labels'
        The dictionary key under which the target tensor `y` is stored. Only has effect
        when the model takes a `MutableMapping` as the input. Useful for Huggingface
        LLM models.

    Attributes
    ----------
    lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
    factor : float
        conversion factor between torch losses and base likelihoods
        For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss.
    """

    # NOTE(review): for regression, `gradients` differentiates
    # MSELoss(reduction="sum"), whose per-sample gradient is 2(f - y)J. Scaling
    # G G^T by `factor` (0.5) therefore yields 2(f - y)^2 J J^T rather than the
    # (f - y)^2 J J^T of the N(f, 1) log-likelihood. Left unchanged here —
    # confirm the intended scaling.

    def full(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the dense empirical Fisher \\(H_{ef}\\) of size \\(P \\times P\\)
        over the parameters \\(\\theta \\in \\mathbb{R}^P\\) (only
        \\(\\theta_{last}\\) in the last-layer setting).

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
            summed loss scaled by `factor`, detached
        H_ef : torch.Tensor
            EF `(parameters, parameters)`, i.e. `factor` times the sum of
            per-sample gradient outer products
        """
        per_sample_grads, total_loss = self.gradients(x, y)
        per_sample_grads = per_sample_grads.detach()
        total_loss = total_loss.detach()

        # sum_b g_b g_b^T
        outer_sum = torch.einsum("bp,bq->pq", per_sample_grads, per_sample_grads)

        scaled_loss = self.factor * total_loss
        return scaled_loss, self.factor * outer_sum

    def diag(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        y: torch.Tensor,
        **kwargs: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Diagonal of the empirical Fisher, one entry per parameter.

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch, input_shape)`
        y : torch.Tensor
            labels `(batch, label_shape)`

        Returns
        -------
        loss : torch.Tensor
            summed loss scaled by `factor`, detached
        H : torch.Tensor
            `(parameters,)` vector: `factor` times the sum over the batch of
            squared per-sample gradients
        """
        # Per-sample gradients are (batchsize, n_params).
        per_sample_grads, total_loss = self.gradients(x, y)
        per_sample_grads = per_sample_grads.detach()
        total_loss = total_loss.detach()

        squared_sum = torch.einsum("bp,bp->p", per_sample_grads, per_sample_grads)
        return self.factor * total_loss, self.factor * squared_sum