-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
174 lines (137 loc) · 7.39 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import importlib
import copy
import argparse
import measures
from torchvision import transforms, datasets
# train the model for one epoch on the given set
def train(args, model, device, train_loader, criterion, optimizer, epoch):
    """Run one training epoch over `train_loader`.

    Args:
        args: parsed CLI arguments (unused here, kept for a uniform call signature).
        model: network to train (moved to `device` by the caller).
        device: torch.device the batches are moved to.
        train_loader: DataLoader over the training set.
        criterion: loss function returning a batch-mean loss.
        optimizer: optimizer stepping `model`'s parameters.
        epoch: current epoch index (unused here, kept for a uniform call signature).

    Returns:
        (error_rate, avg_loss) over the whole training set, where error_rate
        is 1 - accuracy and avg_loss is the per-sample mean loss.
    """
    total_loss = 0.0
    total_correct = 0
    n_samples = len(train_loader.dataset)

    # switch to train mode (enables dropout / batch-norm update behavior)
    model.train()

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # forward pass
        output = model(data)
        loss = criterion(output, target)

        # batch statistics: criterion returns a batch mean, so re-weight by
        # the batch size to get a correct dataset-wide average later
        total_correct += output.max(1)[1].eq(target).sum().item()
        total_loss += loss.item() * len(data)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return 1 - total_correct / n_samples, total_loss / n_samples
# evaluate the model on the given set
def validate(args, model, device, val_loader, criterion):
    """Evaluate `model` on `val_loader` and compute a 5th-percentile margin.

    The margin of a sample is (target logit) - (largest non-target logit);
    it is positive iff the sample is classified correctly.

    Args:
        args: parsed CLI arguments (unused here, kept for a uniform call signature).
        model: network to evaluate.
        device: torch.device the batches are moved to.
        val_loader: DataLoader over the evaluation set.
        criterion: loss function returning a batch-mean loss.

    Returns:
        (error_rate, avg_loss, val_margin) where val_margin is the 5th
        percentile of the per-sample margins.
    """
    sum_loss, sum_correct = 0, 0
    margins = []

    # switch to evaluation mode (freezes dropout / batch-norm statistics)
    model.eval()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            # compute the output
            output = model(data)

            # compute the classification error and loss
            pred = output.max(1)[1]
            sum_correct += pred.eq(target).sum().item()
            sum_loss += len(data) * criterion(output, target).item()

            # compute the margin. Previously this built two BxB matrices and
            # took their diagonals (O(B^2) memory); gather/scatter gives the
            # same per-sample values in O(B).
            target_idx = target.unsqueeze(1)
            target_logit = output.gather(1, target_idx).squeeze(1)
            # replace each sample's target logit with its row minimum so a
            # plain row-max yields the largest *non-target* logit
            output_m = output.scatter(1, target_idx, output.min(1, keepdim=True)[0])
            margins.append(target_logit - output_m.max(1)[0])

    margin = torch.cat(margins, 0) if margins else torch.Tensor([]).to(device)
    val_margin = np.percentile(margin.cpu().numpy(), 5)

    return 1 - (sum_correct / len(val_loader.dataset)), sum_loss / len(val_loader.dataset), val_margin
# Load and Preprocess data.
# Loading: If the dataset is not in the given directory, it will be downloaded.
# Preprocessing: This includes normalizing each channel and data augmentation by random cropping and horizontal flipping
def load_data(split, dataset_name, datadir, nchannels):
    """Build a torchvision dataset for the requested split.

    Args:
        split: 'train' for the training split; anything else selects the
            test/validation split.
        dataset_name: name of a torchvision dataset class
            (e.g. 'MNIST', 'CIFAR10', 'CIFAR100', 'SVHN').
        datadir: directory the dataset lives in (downloaded there if absent).
        nchannels: number of input channels (unused here; normalization is
            chosen by dataset name instead).

    Returns:
        The constructed dataset with resize-to-32 / tensor / normalize
        transforms applied.
    """
    # per-channel normalization: single-channel stats for MNIST,
    # ImageNet-style RGB stats for everything else
    if dataset_name == 'MNIST':
        normalize = transforms.Normalize(mean=[0.131], std=[0.289])
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # train and eval pipelines are identical, so a single transform suffices
    transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), normalize])

    get_dataset = getattr(datasets, dataset_name)
    want_train = split == 'train'

    # SVHN selects its split via a string argument; the other torchvision
    # datasets use a boolean `train` flag
    if dataset_name == 'SVHN':
        return get_dataset(root=datadir, split='train' if want_train else 'test',
                           download=True, transform=transform)
    return get_dataset(root=datadir, train=want_train, download=True, transform=transform)
# This function trains a neural net on the given dataset and calculates various measures on the learned network.
def main():
    """Parse CLI args, train the chosen model, then report measures/bounds.

    Trains until either `--epochs` epochs have run or the training
    cross-entropy loss drops below `--stopcond`, then evaluates margins and
    the generalization measures from `measures.calculate`.
    """
    # settings
    parser = argparse.ArgumentParser(description='Training a fully connected NN with one hidden layer')
    parser.add_argument('--no-cuda', default=False, action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--datadir', default='datasets', type=str,
                        help='path to the directory that contains the datasets (default: datasets)')
    parser.add_argument('--dataset', default='CIFAR10', type=str,
                        help='name of the dataset (options: MNIST | CIFAR10 | CIFAR100 | SVHN, default: CIFAR10)')
    parser.add_argument('--model', default='vgg', type=str,
                        help='architecture (options: fc | vgg, default: vgg)')
    parser.add_argument('--epochs', default=1000, type=int,
                        help='number of epochs to train (default: 1000)')
    parser.add_argument('--stopcond', default=0.01, type=float,
                        help='stopping condtion based on the cross-entropy loss (default: 0.01)')
    parser.add_argument('--batchsize', default=64, type=int,
                        help='input batch size (default: 64)')
    parser.add_argument('--learningrate', default=0.01, type=float,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum (default: 0.9)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # dataset geometry defaults (3-channel, 10-class, 32x32), adjusted per dataset
    nchannels, nclasses, img_dim = 3, 10, 32
    if args.dataset == 'MNIST': nchannels = 1
    if args.dataset == 'CIFAR100': nclasses = 100

    # create an initial model
    model = getattr(importlib.import_module('models.{}'.format(args.model)), 'Network')(nchannels, nclasses)
    model = model.to(device)

    # create a copy of the initial model to be used later
    init_model = copy.deepcopy(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), args.learningrate, momentum=args.momentum)

    # loading data
    train_dataset = load_data('train', args.dataset, args.datadir, nchannels)
    val_dataset = load_data('val', args.dataset, args.datadir, nchannels)
    train_loader = DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = DataLoader(val_dataset, batch_size=args.batchsize, shuffle=False, **kwargs)

    # fix: previously `val_err` was first assigned inside the loop, so
    # running with --epochs 0 made the final print raise a NameError
    val_err = float('nan')

    # training the model
    for epoch in range(0, args.epochs):
        # train for one epoch
        tr_err, tr_loss = train(args, model, device, train_loader, criterion, optimizer, epoch)
        val_err, val_loss, val_margin = validate(args, model, device, val_loader, criterion)
        print(f'Epoch: {epoch + 1}/{args.epochs}\t Training loss: {tr_loss:.3f}\t',
              f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}')

        # stop training if the cross-entropy loss is less than the stopping condition
        if tr_loss < args.stopcond: break

    # calculate the training error and margin of the learned model
    tr_err, tr_loss, tr_margin = validate(args, model, device, train_loader, criterion)
    print(f'\nFinal: Training loss: {tr_loss:.3f}\t Training margin {tr_margin:.3f}\t ',
          f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}\n')

    # calculates various measures and bounds on the learned network
    measure_dict, bound_dict = measures.calculate(model, init_model, device, train_loader, tr_margin, nchannels, nclasses, img_dim)

    print('\n###### Measures')
    for key, value in measure_dict.items():
        print(f'{key.ljust(25):s}:{float(value):3.3}')

    print('\n###### Generalization Bounds')
    for key, value in bound_dict.items():
        print(f'{key.ljust(45):s}:{float(value):3.3}')
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()