optimize_sampling_greedy_roulette.py
import pdb
from copy import deepcopy

import numpy as np
from tensorflow import flags

FLAGS = flags.FLAGS


def euclidean_proj_simplex(v, s=1):
  '''Project v onto the simplex {w : w >= 0, sum(w) = s}.'''
  assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
  n, = v.shape  # will raise ValueError if v is not 1-D
  # check if we are already on the simplex
  if v.sum() == s and np.all(v >= 0):
    # best projection: itself!
    return v
  # get the array of cumulative sums of a sorted (decreasing) copy of v
  u = np.sort(v)[::-1]
  cssv = np.cumsum(u)
  # get the number of > 0 components of the optimal solution
  rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
  # compute the Lagrange multiplier associated to the simplex constraint
  theta = (cssv[rho] - s) / (rho + 1.0)
  # compute the projection by thresholding v using theta
  w = (v - theta).clip(min=0)
  return w
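
# Sanity-check sketch for the projection (values verified by hand): any input
# is mapped to a nonnegative vector summing to s, e.g.
#   >>> euclidean_proj_simplex(np.array([1.5, -0.3, 0.2]))
#   array([1., 0., 0.])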


def optimize_q(c, sq_norms):
  '''Return q whose tail probabilities are propto sqrt(sq_norms / c_diffs).'''
  # Marginal compute of evaluating each additional index.
  c_diffs = np.array([c[i] - (0. if i < 1 else c[i-1]) for i in range(len(c))])
  c_diffs[np.logical_not(np.isfinite(c_diffs))] = 0.0
  p_j_geq_i_unnormalized = np.sqrt(sq_norms / (c_diffs + 1e-14))
  # Enforce that the tail probabilities P(j >= i) are nonincreasing in i by
  # sweeping backward and lifting any entry below its right neighbor.
  for i in range(len(p_j_geq_i_unnormalized) - 1):
    if p_j_geq_i_unnormalized[-(i+2)] < p_j_geq_i_unnormalized[-(i+1)]:
      p_j_geq_i_unnormalized[-(i+2)] = p_j_geq_i_unnormalized[-(i+1)]
  # Recover per-index mass: q_i = P(j >= i) - P(j >= i+1).
  q_unnormalized = np.array(
      [p - (0. if i + 1 == len(p_j_geq_i_unnormalized)
            else p_j_geq_i_unnormalized[i+1])
       for i, p in enumerate(p_j_geq_i_unnormalized)])
  q_unnormalized = np.maximum(0., q_unnormalized)
  q = q_unnormalized / q_unnormalized.sum()
  return q
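
# Worked sketch: with costs c = [1, 2, 3] (so c_diffs = [1, 1, 1]) and
# sq_norms = [9, 4, 1], the unnormalized tail probabilities are
# sqrt([9, 4, 1]) = [3, 2, 1], already nonincreasing, so q is proportional to
# the differences [3-2, 2-1, 1-0] = [1, 1, 1] and optimize_q returns
# [1/3, 1/3, 1/3].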


def compute_and_variance(q, c, sq_norms):
  '''Return the expected compute and expected variance under distribution q.'''
  # weights[i] is the inverse tail probability 1 / P(j >= i) that keeps the
  # truncated estimator unbiased.
  weights = [1. / q[i:].sum() for i in range(len(q))]
  wn = np.array(weights) * sq_norms
  wn_sum = np.cumsum(wn)
  expected_variance = (q * wn_sum).sum()
  expected_compute = (q * c).sum()
  return expected_compute, expected_variance
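
# Quick check of the arithmetic (illustrative): q = [0.5, 0.5], c = [1, 2],
# sq_norms = [1, 1] gives weights = [1/1.0, 1/0.5] = [1, 2], wn_sum = [1, 3],
# so expected_compute = 0.5*1 + 0.5*2 = 1.5 and
# expected_variance = 0.5*1 + 0.5*3 = 2.0.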


def get_sq_norm_seq(sq_norms_matrix, idxs):
  '''Gather squared norms of the successive differences along idxs.'''
  sq_norms = []
  for i in range(len(idxs)):
    idx1 = 0 if i < 1 else idxs[i-1] + 1
    idx2 = idxs[i]
    v = sq_norms_matrix[idx1, idx2]
    sq_norms.append(v)
  return np.array(sq_norms)
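
# Example: with idxs = [1, 3] this gathers sq_norms_matrix[0, 1] (||g_1||^2)
# and sq_norms_matrix[2, 3] (||g_3 - g_1||^2), following the matrix layout
# documented in optimize_greedy_roulette below.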


def get_c_seq(c, idxs):
  '''Gather per-index compute costs, cumulative if a compute penalty applies.'''
  c_seq = []
  try:
    compute_penalty = FLAGS.partial_update or FLAGS.compute_penalty
  except Exception:
    compute_penalty = False
  if compute_penalty:
    for i in range(len(idxs)):
      c_seq.append(sum(c[idxs[j]] for j in range(i + 1)))
  else:
    c_seq = [c[i] for i in idxs]
  return np.array(c_seq)
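
# Example: with c = [1, 2, 4] and idxs = [0, 2], the penalized (cumulative)
# sequence is [1, 1 + 4] = [1, 5]; without a compute penalty it is [1, 4].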


def cost(sq_norm_matrix, c, idxs, return_cv=False):
  '''Scalar objective: expected compute times variance ** variance_weight.'''
  sq_norms = get_sq_norm_seq(sq_norm_matrix, idxs)
  c = get_c_seq(c, idxs)
  q = optimize_q(c, sq_norms)
  cval, vval = compute_and_variance(q, c, sq_norms)
  costval = cval * (vval ** FLAGS.variance_weight)
  if not np.isfinite(costval):
    # Drop into the debugger if the objective diverges.
    pdb.set_trace()
  if return_cv:
    return costval, cval, vval
  return costval
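
# Usage sketch (assumes FLAGS.variance_weight has been defined and parsed by
# the surrounding program; with variance_weight = 1 the objective is the
# plain compute-variance product):
#   costval, cval, vval = cost(sq_norm_matrix, c, [0, 2, 4], return_cv=True)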


def get_q(sq_norm_matrix, c, idxs):
  sq_norms = get_sq_norm_seq(sq_norm_matrix, idxs)
  c = get_c_seq(c, idxs)
  return optimize_q(c, sq_norms)


def optimize_remove(sq_norm_matrix, c, idxs, verbose=False):
  '''Greedily remove indices from idxs while removal lowers the cost.'''
  idxs = deepcopy(idxs)
  baseline = cost(sq_norm_matrix, c, idxs)
  converged = False
  while not converged and len(idxs) > 1:
    if verbose:
      print("Not yet converged")
    converged = True
    # Start at the second-to-last position so the final (deepest) index is
    # never removed.
    i = len(idxs) - 2
    while i >= 0:
      # Try eliminating every intermediate value
      idxs_minus_i = idxs[:i] + idxs[i+1:]
      cost_minus_i = cost(sq_norm_matrix, c, idxs_minus_i)
      if cost_minus_i < baseline:
        if verbose:
          print("{}, trial cost {} under baseline {}".format(
              i, cost_minus_i, baseline))
          print("removing idx {}, remaining {}".format(idxs[i], idxs_minus_i))
        baseline = cost_minus_i
        idxs = idxs_minus_i
        converged = False
        break
      else:
        if verbose:
          print("{}, trial cost {} not under baseline {}".format(
              i, cost_minus_i, baseline))
      i -= 1
  q = get_q(sq_norm_matrix, c, idxs)
  if verbose:
    print("Converged. Final idxs: {}. Final ps: {}".format(idxs, q))
  return idxs, q


def idxs_from_negative(negative_idxs, idxs):
  '''Return the elements of idxs not in the excluded set negative_idxs.'''
  return [i for i in idxs if i not in negative_idxs]


def optimize_add(sq_norm_matrix, c, idxs, verbose=False, logger=None):
  '''Greedily add indices back, starting from the final index alone, while
  adding lowers the cost.'''
  idxs = deepcopy(idxs)
  # negative_idxs is the excluded set; initially everything but the last idx.
  negative_idxs = idxs[:-1]
  baseline = cost(sq_norm_matrix, c, idxs_from_negative(negative_idxs, idxs))
  converged = False
  while not converged and len(negative_idxs) > 0:
    if verbose:
      print("Not yet converged")
    converged = True
    i = 0
    while i <= len(negative_idxs) - 1:
      idxs_minus_i = negative_idxs[:i] + negative_idxs[i+1:]
      cost_minus_i = cost(sq_norm_matrix, c,
                          idxs_from_negative(idxs_minus_i, idxs))
      if cost_minus_i < baseline:
        if verbose:
          print("{}, trial cost {} under baseline {}".format(
              i, cost_minus_i, baseline))
          print("adding idx {}, giving {}".format(
              negative_idxs[i],
              idxs_from_negative(idxs_minus_i, idxs)))
        baseline = cost_minus_i
        negative_idxs = idxs_minus_i
        converged = False
        break
      else:
        if verbose:
          print("{}, trial cost {} not under baseline {}".format(
              i, cost_minus_i, baseline))
      i += 1
  idxs = idxs_from_negative(negative_idxs, idxs)
  q = get_q(sq_norm_matrix, c, idxs)
  if verbose:
    print("Converged. Final idxs: {}. Final ps: {}".format(idxs, q))
  return idxs, q


def optimize_greedy_roulette(sq_norm_matrix, c,
                             idxs, verbose=False, logger=None):
  '''Greedily optimize a RT sampler.

  Args:
    sq_norm_matrix: N+1 x N array.
      entries [0, j]: sq norm of g_j
      entries [i+1, j]: sq norm of g_j - g_i
    c: per-node compute costs.
    idxs: all remaining nodes under consideration.
  '''
  # Baseline: the deterministic estimator that always evaluates the last idx.
  base_cost, base_c, base_v = cost(sq_norm_matrix, c, [idxs[-1]],
                                   return_cv=True)
  try:
    force_all_idxs = FLAGS.force_all_idxs
  except Exception:
    force_all_idxs = False
  if force_all_idxs:
    if verbose:
      print("Forcing using all idxs")
    q = get_q(sq_norm_matrix, c, idxs)
  else:
    # Try greedily optimizing idxs from both directions: removing from the
    # full set and adding to the singleton set; keep the cheaper solution.
    idxs_remove, q_remove = optimize_remove(sq_norm_matrix, c, idxs, verbose)
    idxs_add, q_add = optimize_add(sq_norm_matrix, c, idxs, verbose)
    cost_remove = cost(sq_norm_matrix, c, idxs_remove)
    cost_add = cost(sq_norm_matrix, c, idxs_add)
    if cost_remove < cost_add:
      if verbose:
        print("Greedy remove cost {} < greedy add cost {}.".format(
            cost_remove, cost_add))
        print("Returning greedy remove idxs {} instead of greedy add "
              "idxs {}".format(idxs_remove, idxs_add))
      idxs = idxs_remove
      q = q_remove
    else:
      if verbose:
        print("Greedy add cost {} <= greedy remove cost {}.".format(
            cost_add, cost_remove))
        print("Returning greedy add idxs {} instead of greedy remove "
              "idxs {}".format(idxs_add, idxs_remove))
      idxs = idxs_add
      q = q_add
  costval, cval, vval = cost(sq_norm_matrix, c, idxs, return_cv=True)
  if logger:
    logger.info("Optimized RT. idxs: {}. q: {}".format(idxs, q))
    logger.info("RT estimator has cost "
                "{:.2f}, compute: {:.2f}, variance: {:.2f}".format(
                    costval, cval, vval))
    logger.info("Deterministic estimator has "
                "cost: {:.2f}, compute {:.2f}, variance {:.2f}".format(
                    base_cost, base_c, base_v))
    logger.info("Change factors are "
                "cost {:.2f}, compute {:.2f}, variance {:.2f}.".format(
                    costval / base_cost, cval / base_c, vval / base_v))
  return idxs, q
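

# Minimal end-to-end sketch (not part of the original module): exercises the
# optimizer on random data. It assumes the TF1-style flags API; the
# variance_weight flag is defined and parsed here only so the demo is
# self-contained, whereas in normal use the surrounding program owns it.
if __name__ == "__main__":
  flags.DEFINE_float("variance_weight", 1.0,
                     "Exponent on variance in the cost objective.")
  FLAGS(["optimize_sampling_greedy_roulette"])  # parse flag defaults

  rng = np.random.RandomState(0)
  n = 5
  # sq_norm_matrix[0, j] = ||g_j||^2; sq_norm_matrix[i+1, j] = ||g_j - g_i||^2.
  sq_norm_matrix = np.abs(rng.randn(n + 1, n))
  c = np.arange(1., n + 1.)  # per-node compute costs
  idxs, q = optimize_greedy_roulette(
      sq_norm_matrix, c, list(range(n)), verbose=True)
  print("chosen idxs: {}, sampling distribution q: {}".format(idxs, q))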