-
Notifications
You must be signed in to change notification settings - Fork 2
/
cert_util.py
138 lines (103 loc) · 4.47 KB
/
cert_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from http.client import NOT_IMPLEMENTED
import torch.nn as nn
import gurobipy as gp
from gurobipy import GRB
import numpy as np
import torchvision.transforms as trans
from torch.utils.data import DataLoader
import os
import torchvision.datasets as dset
from torch.utils.data import sampler
import torchvision.datasets as dset
def load_data(data_dir: str = "./data", num_imgs: int = 25, random: bool = False, dataset: str = 'MNIST') -> tuple:
    """
    Loads one batch of images from the MNIST or CIFAR-10 test set.

    Args:
        data_dir:
            The directory to store the full dataset. It is created
            (including missing parent directories) if it does not exist.
        num_imgs:
            The number of images to extract from the test-set.
        random:
            If true, random image indices are used, otherwise the first images
            are used.
        dataset:
            Which test set to load: 'MNIST' or 'CIFAR' (CIFAR-10).
    Returns:
        The first batch produced by the DataLoader, unpackable as
        (images, labels).
    Raises:
        NotImplementedError:
            If dataset is neither 'MNIST' nor 'CIFAR'.
    """
    # Reject unsupported names before touching the filesystem or network.
    if dataset not in ('MNIST', 'CIFAR'):
        raise NotImplementedError(
            "Unsupported dataset {!r}; expected 'MNIST' or 'CIFAR'".format(dataset)
        )

    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(data_dir, exist_ok=True)

    to_tensor = trans.ToTensor()
    dataset_cls = dset.MNIST if dataset == 'MNIST' else dset.CIFAR10
    test_set = dataset_cls(data_dir, train=False, download=True, transform=to_tensor)

    if random:
        # Sample over the whole test set rather than a hard-coded size.
        random_sampler = sampler.SubsetRandomSampler(range(len(test_set)))
        loader_test = DataLoader(test_set, batch_size=num_imgs, sampler=random_sampler)
    else:
        loader_test = DataLoader(test_set, batch_size=num_imgs)

    # Only the first batch of num_imgs samples is needed.
    return next(iter(loader_test))
class DeltaWrapper(nn.Module):
    """Wraps a model so that a perturbation is added to the input before the forward pass."""

    def __init__(self, model):
        """Keeps a reference to the model that receives the perturbed input."""
        super().__init__()
        self.model = model

    def forward(self, x, delta):
        """Returns the wrapped model's output for the input x + delta."""
        return self.model(x + delta)
def min_correct_with_eps(alpha, beta, eps, label, number_class=10, verbose=False, dataset = 'MNIST'):
    """
    Solves a MILP for a single shared perturbation that minimises the number
    of samples whose selected linear score is positive.

    For every sample i and index j the model builds the linear score
    sum_k alpha[i][j][k] * delta[k] + beta[i][j]. A one-hot binary vector
    s[i, :] selects exactly one score per sample, and the binary q[i] is
    forced to 1 whenever that selected score is positive. The objective is
    the sum of q over all samples.

    Args:
        alpha:
            Coefficients indexed as alpha[i][j][k]; alpha.shape[-1] gives the
            number of perturbation variables. Presumably linear-bound
            coefficients per sample and non-target class - confirm with caller.
        beta:
            Offsets indexed as beta[i][j]. Must support beta.min(1)[0]
            (presumably a torch tensor, where that yields the per-row
            minimum - confirm with caller).
        eps:
            Each entry of the perturbation is bounded to [-eps, eps].
        label:
            Only len(label) is used; it defines the number of samples.
        number_class:
            Number of classes; number_class - 1 scores are modelled per sample.
        verbose:
            If false, Gurobi console logging is disabled.
        dataset:
            'MNIST' or 'CIFAR'; selects the shape of the returned
            perturbation, (1, 28, 28) or (3, 32, 32).
    Returns:
        A tuple (objective value, perturbation array). When the early-stop
        test on beta succeeds, (0, all-zero array) is returned without
        building the model.
    Raises:
        NotImplementedError:
            If dataset is neither 'MNIST' nor 'CIFAR'.
    """
    # Validate up front so an unsupported name fails before a (possibly
    # ten-minute) solve instead of after it.
    image_shapes = {'MNIST': (1, 28, 28), 'CIFAR': (3, 32, 32)}
    if dataset not in image_shapes:
        raise NotImplementedError(
            "Unsupported dataset {!r}; expected 'MNIST' or 'CIFAR'".format(dataset)
        )
    image_shape = image_shapes[dataset]

    # No row passes the positivity test, so the zero perturbation is returned
    # directly with objective 0.
    if (beta.min(1)[0]>0).sum() == 0:
        print('early stop as zeros suffice')
        return 0, np.zeros(image_shape)

    n_samples = len(label)
    n_scores = number_class - 1
    n_delta = alpha.shape[-1]

    # construct the MILP model
    m = gp.Model()
    m.setParam('TimeLimit', 10*60)
    if not verbose:
        m.Params.LogToConsole = 0

    # NOTE: 10e5 is 1e6 and -10e3 is -1e4; the values are kept as in the
    # original formulation.
    c = m.addVar(lb=-10e5, ub=10e5, vtype=GRB.CONTINUOUS, name='fooled-samples')
    delta = m.addVars(n_delta, lb=-eps, ub=eps, vtype=GRB.CONTINUOUS, name='delta')
    s = m.addVars(n_samples, n_scores, vtype=GRB.BINARY, name='s')
    q_ = m.addVars(n_samples, lb=-10e5, ub=10e5, vtype=GRB.CONTINUOUS, name='q_')
    q = m.addVars(n_samples, vtype=GRB.BINARY, name='q')
    aux1 = m.addVars(n_samples, n_scores, lb=-10e5, ub=10e5, vtype=GRB.CONTINUOUS, name='aux1')
    aux2 = m.addVars(n_samples, n_scores, lb=-10e5, ub=10e5, vtype=GRB.CONTINUOUS, name='aux2')
    auxsc = m.addVars(n_samples, n_scores, vtype=GRB.BINARY, name='sc')
    small_val = -10e3
    giant_val = 10e5

    m.setObjective(c, GRB.MINIMIZE)

    # aux1[i, j]: linear score of sample i, index j under the perturbation.
    m.addConstrs(aux1[i,j] == gp.quicksum(alpha[i][j][k] * delta[k] for k in range(n_delta)) + beta[i][j]
                 for j in range(n_scores)
                 for i in range(n_samples))
    # aux2[i, j]: the score if s[i, j] selects it, otherwise 0.
    m.addConstrs(aux2[i,j] == aux1[i,j] * s[i, j]
                 for j in range(n_scores)
                 for i in range(n_samples))

    # c counts the samples flagged by q.
    m.addConstr(c == gp.quicksum(q[i] for i in range(n_samples)))
    # q_[i]: the single selected score of sample i.
    m.addConstrs((q_[i] == gp.quicksum(aux2[i,j] for j in range(n_scores)))
                 for i in range(n_samples))
    # Big-M link: q_[i] > 0 forces q[i] = 1, q_[i] < 0 forces q[i] = 0.
    m.addConstrs(giant_val*q[i] >= q_[i] for i in range(n_samples))
    m.addConstrs(giant_val*(1-q[i]) >= -q_[i] for i in range(n_samples))
    # auxsc is the complement of the selector s.
    m.addConstrs(auxsc[i,j] == 1 - s[i,j]
                 for j in range(n_scores)
                 for i in range(n_samples))
    # Exactly one score is selected per sample.
    m.addConstrs(gp.quicksum(s[i, j] for j in range(n_scores)) == 1 for i in range(n_samples))
    # The original iterated over an extra unused index here, which only added
    # identical copies of each constraint; one copy per (i, j) is sufficient.
    m.addConstrs(aux2[i,j] + small_val * auxsc[i, j] <= aux1[i,j]
                 for j in range(n_scores)
                 for i in range(n_samples))

    m.update()
    m.optimize()

    # NOTE(review): if the time limit expires before any feasible solution is
    # found, reading .X / objVal below will raise - callers may want to check
    # the solution count first.
    # Read the perturbation straight from the delta variables, in index order.
    milp_delta = np.asarray([delta[k].X for k in range(n_delta)])
    milp_delta = np.reshape(milp_delta, image_shape)
    return m.objVal, milp_delta