-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_loss_land.py
81 lines (56 loc) · 2.5 KB
/
main_loss_land.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import torch
from torchvision import datasets, transforms
from utils import * # get the dataset
from pyhessian import hessian # Hessian computation
#from density_plot import get_esd_plot # ESD plot
from pytorchcv.model_provider import get_model as ptcv_get_model # model
import matplotlib.pyplot as plt
#%matplotlib inline
# Enable CUDA devices: select GPU 0, with devices enumerated in PCI bus order.
# Both variables are set here, ahead of the .cuda() calls further down.
import os

os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
# get the model
model = ptcv_get_model("resnet20_cifar10", pretrained=True)
# change the model to eval mode to disable running stats update
model.eval()
# create loss function
criterion = torch.nn.CrossEntropyLoss()
# get dataset
train_loader, test_loader = getData()
# for illustration, we only use one batch (the first one) to do the tutorial
inputs, targets = next(iter(train_loader))
# we use cuda to make the computation fast
model = model.cuda()
inputs, targets = inputs.cuda(), targets.cuda()
# create the hessian computation module
hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)

# Compute the top eigenvalue together with its eigenvector.
# `top_eigenvalues` / `top_eigenvector` keep this single-eigenpair result for
# the rest of the script (the perturbation code uses `top_eigenvector[0]`).
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues()
print("The top Hessian eigenvalue of this model is %.4f" % top_eigenvalues[-1])

# Now let's compute the top 2 eigenvalues of the Hessian. The result is only
# printed, so it is bound to separate names; this keeps the top eigenpair
# above intact and avoids running a third, redundant power iteration just to
# recover it.
top_two_eigenvalues, _ = hessian_comp.eigenvalues(top_n=2)
print(
    "The top two eigenvalues of this model are: %.4f %.4f"
    % (top_two_eigenvalues[-1], top_two_eigenvalues[-2])
)
# Simple helper that perturbs the model parameters along a given direction
# and hands back the perturbed model.
def get_params(model_orig, model_perb, direction, alpha):
    """Write ``model_orig``'s parameters shifted by ``alpha * direction`` into ``model_perb``.

    Args:
        model_orig: Model whose parameters are read; it is not modified.
        model_perb: Model whose parameter ``.data`` is overwritten.
        direction: Iterable of tensors, paired positionally with the
            parameters of both models.
        alpha: Scalar step size applied to every direction tensor.

    Returns:
        ``model_perb``, the same object that was passed in.
    """
    source_params = model_orig.parameters()
    target_params = model_perb.parameters()
    # zip stops at the shortest of the three iterables.
    for source, target, step in zip(source_params, target_params, direction):
        shifted = source.data + alpha * step
        target.data = shifted
    return model_perb
# lambda is a small scalar that we use to perturb the model parameters along
# the eigenvectors; sweep 21 evenly spaced values in [-0.5, 0.5]
lams = np.linspace(-0.5, 0.5, 21).astype(np.float32)
loss_list = []

# create a copy of the model that will receive the perturbed parameters
model_perb = ptcv_get_model("resnet20_cifar10", pretrained=True)
model_perb.eval()
model_perb = model_perb.cuda()

# Only forward passes are needed to evaluate the loss, so autograd is turned
# off: this avoids building (and holding in memory) a graph per evaluation.
with torch.no_grad():
    for lam in lams:
        # shift the original parameters by lam along the top Hessian eigenvector
        model_perb = get_params(model, model_perb, top_eigenvector[0], lam)
        loss_list.append(criterion(model_perb(inputs), targets).item())

plt.plot(lams, loss_list)
plt.ylabel('Loss')
plt.xlabel('Perturbation')
plt.title('Loss landscape perturbed based on top Hessian eigenvector')
# Without this the figure is built but never displayed when the file is run
# as a plain script (the notebook inline backend is commented out above).
plt.show()