-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathflow_trainer.py
282 lines (222 loc) · 8.41 KB
/
flow_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
from functools import partial
import jax.numpy as jnp
import jax
from jax import random, jit, vmap
from nux.internal.flow import Flow
import nux.util as util
from typing import Optional, Mapping, Callable, Sequence, Tuple, Any
from haiku._src.typing import Params, State, PRNGKey
import optax
from optax._src import transform
GradientTransformation = transform.GradientTransformation
from abc import ABC, abstractmethod
from collections import namedtuple
from .trainer import Trainer
from .tester import Tester
# Public trainers exported by this module; the abstract FlowTrainer base is
# not listed.
__all__ = [
    "MaximumLikelihoodTrainer",
    "JointClassificationTrainer",
]
################################################################################################################
class FlowTrainer(ABC):
    """Base class pairing a ``Flow`` with a ``Trainer`` and a ``Tester``.

    Subclasses define the objective (``loss``) and how step outputs are
    rendered as text (``summarize_train_out`` / ``summarize_test_out``).
    After every train/test step this class copies the updated params/state
    back into the flow and into the other helper, and it handles
    checkpointing via ``save`` / ``load``.

    Args:
        flow: The flow whose ``params`` / ``state`` are trained.
        optimizer: Gradient transformation handed to the trainer.
        **kwargs: Accepted for subclass compatibility; not used here.
    """

    def __init__(self, flow: Flow, optimizer: GradientTransformation, **kwargs):
        self.flow = flow
        # Both helpers are seeded with the flow's current params/state and the
        # subclass's bound ``loss``.
        self.trainer = self.TrainerClass(self.params, self.state, self.loss, optimizer)
        self.tester = self.TesterClass(self.params, self.state, self.loss)
        # Training-step counts (``n_train_steps``) at which the test set was evaluated.
        self.test_eval_times = jnp.array([])

    @property
    def train_losses(self):
        """Losses recorded by the trainer."""
        return self.trainer.losses

    @property
    def n_train_steps(self):
        """Number of recorded training losses (``train_losses.size``)."""
        return self.train_losses.size

    @property
    def test_losses(self):
        """Losses recorded by the tester."""
        return self.tester.losses

    @property
    def params(self):
        """The flow's parameters (stored on ``self.flow``)."""
        return self.flow.params

    @params.setter
    def params(self, val):
        self.flow.params = val

    @property
    def state(self):
        """The flow's state (stored on ``self.flow``)."""
        return self.flow.state

    @state.setter
    def state(self, val):
        self.flow.state = val

    def update_params_and_state_from_trainer(self):
        """Copy the trainer's params/state into the flow and the tester."""
        self.params, self.state = self.trainer.params, self.trainer.state
        # Possibly redundant, but kept so the tester never holds stale
        # params/state.
        self.tester.params, self.tester.state = self.params, self.state

    def update_params_and_state_from_tester(self):
        """Copy the tester's params/state into the flow and the trainer."""
        self.params, self.state = self.tester.params, self.tester.state
        self.trainer.params, self.trainer.state = self.params, self.state

    @property
    def TrainerClass(self):
        """Class used to build ``self.trainer``; override to customise."""
        return Trainer

    @property
    def TesterClass(self):
        """Class used to build ``self.tester``; override to customise."""
        return Tester

    @abstractmethod
    def loss(self, params, state, key, inputs, **kwargs):
        """Return ``(scalar_loss, (aux, updated_state))`` for one batch."""

    @abstractmethod
    def summarize_train_out(self, out):
        """Return a short text summary of a training step's output."""

    @abstractmethod
    def summarize_test_out(self, out):
        """Return a short text summary of a test step's output."""

    def save(self, path: Optional[str] = None):
        """Write params, state, evaluation times and trainer/tester items to ``path``.

        Items returned by the trainer and then the tester are merged over the
        base items, so later sources win on key collisions.
        """
        save_items = {"params": self.params,
                      "state": self.state,
                      "test_eval_times": self.test_eval_times}
        save_items.update(self.trainer.save_items())
        save_items.update(self.tester.save_items())
        util.save_pytree(save_items, path, overwrite=True)

    def load(self, path: Optional[str] = None):
        """Restore everything written by ``save`` from ``path``."""
        loaded_items = util.load_pytree(path)
        self.params = loaded_items["params"]
        self.state = loaded_items["state"]
        self.test_eval_times = loaded_items["test_eval_times"]
        self.trainer.load_items(loaded_items)
        self.tester.load_items(loaded_items)

    def grad_step(self,
                  key: PRNGKey,
                  inputs: Mapping[str, jnp.ndarray],
                  **kwargs):
        """Run ``trainer.step`` and sync the updated params/state."""
        out = self.trainer.step(key, inputs, **kwargs)
        self.update_params_and_state_from_trainer()
        return out

    def grad_step_for_loop(self,
                           key: PRNGKey,
                           inputs: Mapping[str, jnp.ndarray],
                           **kwargs):
        """Run ``trainer.step_for_loop`` and sync the updated params/state."""
        out = self.trainer.step_for_loop(key, inputs, **kwargs)
        self.update_params_and_state_from_trainer()
        return out

    def grad_step_scan_loop(self,
                            key: PRNGKey,
                            inputs: Mapping[str, jnp.ndarray],
                            **kwargs):
        """Run ``trainer.step_scan_loop`` and sync the updated params/state."""
        out = self.trainer.step_scan_loop(key, inputs, **kwargs)
        self.update_params_and_state_from_trainer()
        return out

    def test_step(self,
                  key: PRNGKey,
                  inputs: Mapping[str, jnp.ndarray],
                  **kwargs):
        """Run ``tester.step`` and sync the tester's params/state."""
        out = self.tester.step(key, inputs, **kwargs)
        self.update_params_and_state_from_tester()
        return out

    def test_step_for_loop(self,
                           key: PRNGKey,
                           inputs: Mapping[str, jnp.ndarray],
                           **kwargs):
        """Run ``tester.step_for_loop`` and sync the tester's params/state."""
        out = self.tester.step_for_loop(key, inputs, **kwargs)
        self.update_params_and_state_from_tester()
        return out

    def test_step_scan_loop(self,
                            key: PRNGKey,
                            inputs: Mapping[str, jnp.ndarray],
                            **kwargs):
        """Run ``tester.step_scan_loop`` and sync the tester's params/state."""
        out = self.tester.step_scan_loop(key, inputs, **kwargs)
        self.update_params_and_state_from_tester()
        return out

    def evaluate_test_set(self,
                          key: PRNGKey,
                          input_iterator,
                          **kwargs):
        """Evaluate every batch from ``input_iterator`` with the tester.

        Each batch is run through ``test_step_scan_loop`` with ``update=False``
        and its own PRNG subkey. Per-batch outputs are concatenated leaf-wise
        along axis 0 (leaves that cannot be concatenated, e.g. 0-d arrays, are
        stacked instead). Their per-leaf means are passed to
        ``self.tester.update_outputs``, and the current ``n_train_steps`` is
        appended to ``test_eval_times``.

        Args:
            key: PRNG key, split once per batch.
            input_iterator: Iterable of input batches.
            **kwargs: Forwarded to ``test_step_scan_loop``.

        Returns:
            The concatenated per-batch outputs (not the means).

        Raises:
            ValueError: If ``input_iterator`` yields no batches.
        """
        def concat(*args):
            try:
                return jnp.concatenate(args, axis=0)
            except ValueError:
                # 0-d leaves (e.g. scalar metrics) cannot be concatenated.
                return jnp.array(args)

        outs = []
        for inputs in input_iterator:
            # Only the fresh subkey is consumed; ``key`` is reserved for the
            # next split so no key is ever used twice.
            key, test_key = random.split(key, 2)
            outs.append(self.test_step_scan_loop(test_key, inputs, update=False, **kwargs))

        if not outs:
            raise ValueError("input_iterator yielded no batches; nothing to evaluate.")

        outs = jax.tree_util.tree_map(concat, *outs)
        # Condense the outputs over the entire test set
        out = jax.tree_util.tree_map(jnp.mean, outs)

        # Mark when we evaluated the test set (only after evaluation succeeded)
        self.test_eval_times = jnp.hstack([self.test_eval_times, self.n_train_steps])
        self.tester.update_outputs(out)
        return outs
################################################################################################################
class MaximumLikelihoodTrainer(FlowTrainer):
    """Convenience class for training a flow with maximum likelihood.

    The training objective is ``-mean(log_pz + log_det)``, using the values
    the flow accumulates under those names.

    Args:
        flow: A ``Flow`` object.
        optimizer: Gradient transformation used by the trainer. Gradient
            clipping, learning-rate warmup and decay are NOT arguments of this
            class; build them into ``optimizer`` (e.g. by chaining optax
            transformations). Clipping is crucial for stable training!
        image: If True, summaries convert the loss with
            ``flow.to_bits_per_dim``.
        **kwargs: Forwarded to ``FlowTrainer.__init__``.
    """

    def __init__(self,
                 flow: Flow,
                 optimizer: GradientTransformation = None,
                 image: bool = False,
                 **kwargs):
        super().__init__(flow, optimizer=optimizer, **kwargs)
        self.image = image

    @property
    def accumulate_args(self):
        """Names of the quantities the flow should accumulate during a pass."""
        return ["log_pz", "log_det"]

    def loss(self, params, state, key, inputs, **kwargs):
        """Negative mean log-likelihood of ``inputs``.

        Returns:
            ``(nll, (aux, updated_state))`` where ``aux`` is an empty tuple.
        """
        outputs, updated_state = self.flow._apply_fun(params, state, key, inputs,
                                                       accumulate=self.accumulate_args,
                                                       **kwargs)
        log_px = outputs.get("log_pz", 0.0) + outputs.get("log_det", 0.0)
        aux = ()
        return -log_px.mean(), (aux, updated_state)

    def _summarize(self, out):
        """Format the mean loss in ``out``, converted to bits/dim if ``image``."""
        # ``out.loss`` already holds the negative log-likelihood.
        nll = out.loss.mean()
        if self.image:
            nll = self.flow.to_bits_per_dim(nll)
        return f"loss: {nll:.2f}"

    def summarize_train_out(self, out):
        """Return a one-line summary of a training step's output."""
        return self._summarize(out)

    def summarize_test_out(self, out):
        """Return a one-line summary of a test step's output."""
        return self._summarize(out)
################################################################################################################
class JointClassificationTrainer(FlowTrainer):
    """Trains a flow on the joint likelihood of inputs and labels.

    The objective is ``-mean(log_pz + log_det)``, which this trainer treats
    as ``-log p(x, y)``. It additionally reports the marginal negative
    log-likelihood ``-log p(x) = -(log p(x, y) - log p(y|x))`` and the
    classification accuracy.

    Args:
        flow: A ``Flow`` object; its outputs must include
            ``"prediction_one_hot"``.
        optimizer: Gradient transformation used by the trainer.
        image: If True, summaries convert the NLL with
            ``flow.to_bits_per_dim``.
        **kwargs: Forwarded to ``FlowTrainer.__init__``.
    """

    def __init__(self,
                 flow: Flow,
                 optimizer: GradientTransformation = None,
                 image: bool = False,
                 **kwargs):
        super().__init__(flow, optimizer=optimizer, **kwargs)
        self.image = image

    @property
    def accumulate_args(self):
        """Names of the quantities the flow should accumulate during a pass."""
        return ["log_pz", "log_det", "log_pygx"]

    def loss(self, params, state, key, inputs, **kwargs):
        """Negative mean joint log-likelihood of ``inputs``.

        Returns:
            ``(-mean log p(x, y), ((accuracy, -mean log p(x)), updated_state))``.
        """
        outputs, updated_state = self.flow._apply_fun(params, state, key, inputs,
                                                       accumulate=self.accumulate_args,
                                                       **kwargs)
        # TODO: Stop grouping p(x|y)p(y) into the prior and instead pass them out separately
        log_pyax = outputs.get("log_pz", 0.0) + outputs.get("log_det", 0.0)

        # Compute the data log likelihood: log p(x) = log p(x, y) - log p(y|x)
        log_px = log_pyax - outputs.get("log_pygx", 0.0)

        # Compute the accuracy; assumes inputs["y"] is one-hot along the last
        # axis, matching outputs["prediction_one_hot"] -- TODO confirm.
        y_one_hot = inputs["y"]
        acc = (outputs["prediction_one_hot"]*y_one_hot).sum(axis=-1).mean()

        aux = (acc, -log_px.mean())
        return -log_pyax.mean(), (aux, updated_state)

    def _summarize(self, out):
        """Format loss, NLL (bits/dim if ``image``) and accuracy from ``out``."""
        loss = out.loss.mean()
        accuracy, nll = jax.tree_util.tree_map(jnp.mean, out.aux)
        if self.image:
            nll = self.flow.to_bits_per_dim(nll)
        return f"loss: {loss:.2f}, nll: {nll:.2f}, acc: {accuracy:.2f}"

    def summarize_train_out(self, out):
        """Return a one-line summary of a training step's output."""
        return self._summarize(out)

    def summarize_test_out(self, out):
        """Return a one-line summary of a test step's output."""
        return self._summarize(out)