-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
train_loop_mixin.py
309 lines (241 loc) · 11.4 KB
/
train_loop_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import numpy as np
import tqdm
try:
from apex import amp
APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
class TrainerTrainLoopMixin(object):
def train(self):
# run all epochs
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch_nb)
# get model
model = self.get_model()
# update training progress in trainer and model
model.current_epoch = epoch_nb
self.current_epoch = epoch_nb
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.nb_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
# total batches includes multiple val checks
self.total_batches = (self.nb_training_batches +
self.nb_val_batches * val_checks_per_epoch)
self.batch_loss_value = 0 # accumulated grads
if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
nb_iterations = 2
elif self.is_iterable_train_dataloader:
# for iterable train loader, the progress bar never ends
nb_iterations = None
else:
nb_iterations = self.total_batches
# reset progress bar
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(nb_iterations)
desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)
# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
# -----------------
# RUN TNG EPOCH
# -----------------
self.run_training_epoch()
# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step(self.current_epoch)
# early stopping
met_min_epochs = epoch_nb > self.min_nb_epochs
if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
logs=self.callback_metrics)
# stop training
stop = should_stop and met_min_epochs
if stop:
self.main_progress_bar.close()
return
self.main_progress_bar.close()
if self.logger is not None:
self.logger.finalize("success")
def run_training_epoch(self):
# before epoch hook
if self.is_function_implemented('on_epoch_start'):
model = self.get_model()
model.on_epoch_start()
# run epoch
for batch_nb, batch in enumerate(self.get_train_dataloader()):
self.batch_nb = batch_nb
model = self.get_model()
model.global_step = self.global_step
# ---------------
# RUN TRAIN STEP
# ---------------
output = self.run_training_batch(batch, batch_nb)
batch_result, grad_norm_dic, batch_step_metrics = output
# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1
# ---------------
# RUN VAL STEP
# ---------------
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)
# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
self.run_evaluation(test=self.testing)
# when logs should be saved
should_save_log = (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if self.proc_rank == 0 and self.logger is not None:
self.logger.save()
# when metrics should be logged
should_log_metrics = batch_nb % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)
self.global_step += 1
self.total_batch_nb += 1
# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
if early_stop_epoch or self.fast_dev_run:
break
# stop epoch if we limited nb batches
met_batch_limit = batch_nb >= self.nb_training_batches
if met_batch_limit:
break
# epoch end hook
if self.is_function_implemented('on_epoch_end'):
model = self.get_model()
model.on_epoch_end()
def run_training_batch(self, batch, batch_nb):
# track grad norms
grad_norm_dic = {}
# track all metrics for callbacks
all_callback_metrics = []
# track metrics to log
all_log_metrics = []
if batch is None:
return 0, grad_norm_dic, all_log_metrics
# hook
if self.is_function_implemented('on_batch_start'):
model_ref = self.get_model()
response = model_ref.on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, all_log_metrics
splits = [batch]
if self.truncated_bptt_steps is not None:
model_ref = self.get_model()
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
self.hiddens = None
for split_nb, split_batch in enumerate(splits):
self.split_nb = split_nb
# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
# wrap the forward step in a closure so second order methods work
def optimizer_closure():
# forward pass
output = self.training_forward(
split_batch, batch_nb, opt_idx, self.hiddens)
closure_loss = output[0]
progress_bar_metrics = output[1]
log_metrics = output[2]
callback_metrics = output[3]
self.hiddens = output[4]
# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = closure_loss / self.accumulate_grad_batches
# backward pass
model_ref = self.get_model()
model_ref.backward(self.use_amp, closure_loss, optimizer)
# track metrics for callbacks
all_callback_metrics.append(callback_metrics)
# track progress bar metrics
self.add_tqdm_metrics(progress_bar_metrics)
all_log_metrics.append(log_metrics)
# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
model_ref.on_after_backward()
return closure_loss
# calculate loss
loss = optimizer_closure()
# nan grads
if self.print_nan_grads:
self.print_nan_gradients()
# track total loss for logging (avoid mem leaks)
self.batch_loss_value += loss.item()
# gradient update with accumulated gradients
if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
# track gradient norms when requested
if batch_nb % self.row_log_interval == 0:
if self.track_grad_norm > 0:
model = self.get_model()
grad_norm_dic = model.grad_norm(
self.track_grad_norm)
# clip gradients
self.clip_gradients()
# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.get_model()
model.optimizer_step(self.current_epoch, batch_nb,
optimizer, opt_idx, optimizer_closure)
# calculate running loss for display
self.running_loss.append(self.batch_loss_value)
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])
# activate batch end hook
if self.is_function_implemented('on_batch_end'):
model = self.get_model()
model.on_batch_end()
# update progress bar
self.main_progress_bar.update(1)
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
# track all metrics for callbacks
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
return 0, grad_norm_dic, all_log_metrics
def training_forward(self, batch, batch_nb, opt_idx, hiddens):
"""
Handle forward for each training case (distributed, single gpu, etc...)
:param batch:
:param batch_nb:
:return:
"""
# ---------------
# FORWARD
# ---------------
# enable not needing to add opt_idx to training_step
args = [batch, batch_nb]
if len(self.optimizers) > 1:
args.append(opt_idx)
# pass hiddens if using tbptt
if self.truncated_bptt_steps is not None:
args.append(hiddens)
# distributed forward
if self.use_ddp or self.use_ddp2 or self.use_dp:
output = self.model(*args)
# single GPU forward
elif self.single_gpu:
gpu_id = 0
if type(self.data_parallel_device_ids) is list:
gpu_id = self.data_parallel_device_ids[0]
batch = self.transfer_batch_to_gpu(batch, gpu_id)
args[0] = batch
output = self.model.training_step(*args)
# CPU forward
else:
output = self.model.training_step(*args)
# allow any mode to define training_end
if self.is_overriden('training_end'):
model_ref = self.get_model()
output = model_ref.training_end(output)
# format and reduce outputs accordingly
output = self.process_output(output, train=True)
return output