-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
lightning.py
1491 lines (1115 loc) · 52.4 KB
/
lightning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import inspect
import logging as log
import os
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, Optional, Union
import torch
import torch.distributed as dist
from torch.optim import Adam
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
except ImportError:
XLA_AVAILABLE = False
class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
def __init__(self, *args, **kwargs):
    """Set up the bookkeeping attributes carried by every LightningModule.

    All positional and keyword arguments are forwarded unchanged to the
    parent initialisers; every attribute below starts at a neutral default.
    """
    super().__init__(*args, **kwargs)

    # ---- run state -------------------------------------------------------
    self.dtype = torch.FloatTensor  #: Current dtype
    self.exp_save_path = None
    self.current_epoch = 0  #: The current epoch
    self.global_step = 0  #: Total training batches seen across all epochs
    self.loaded_optimizer_states_dict = {}

    # ---- collaborators ---------------------------------------------------
    self.trainer = None  #: Pointer to the trainer object
    self.logger = None  #: Pointer to the logger object
    self.example_input_array = None

    # ---- hardware / distributed flags --------------------------------------
    #: True if your model is currently running on GPUs.
    #: Useful to set flags around the LightningModule for different CPU vs GPU behavior.
    self.on_gpu = False
    self.use_dp = False  #: True if using dp
    self.use_ddp = False  #: True if using ddp
    self.use_ddp2 = False  #: True if using ddp2
    self.use_amp = False  #: True if using amp

    self.hparams = None
def print(self, *args, **kwargs):
    r"""
    Log a message from process 0 only. Use this in any distributed mode to log only once.

    The positional arguments are converted with ``str`` and joined with single spaces
    (the way the builtin ``print`` lays them out), and the resulting text is emitted
    through ``logging.info``. Nothing is emitted when ``self.trainer.proc_rank`` is not 0.

    Args:
        *args (object): The things to print
        **kwargs: Forwarded unchanged to ``logging.info``

    Example
    -------

    .. code-block:: python

        # example if we were using this model as a feature extractor
        def forward(self, x):
            self.print(x, 'in loader')

    """
    if self.trainer.proc_rank == 0:
        # logging treats every positional after the first as a %-format argument,
        # so build one message up front instead of forwarding *args verbatim.
        message = ' '.join(str(arg) for arg in args)
        log.info(message, **kwargs)
@abstractmethod
def forward(self, *args, **kwargs):
    r"""
    Same as torch.nn.Module.forward(), however in Lightning you want this to define
    the operations you want to use for prediction (i.e. on a server or as a feature extractor).

    Normally you'd call self.forward() from your training_step() method. This makes it easy to write a complex
    system for training with the outputs you'd want in a prediction setting.

    This method is abstract: subclasses must provide their own implementation.

    Args:
        *args: Whatever you decide to define in the forward method (typically an input tensor ``x``)
        **kwargs: Whatever keyword inputs you decide to define in the forward method

    Return:
        Predicted output

    Example
    -------

    .. code-block:: python

        # example if we were using this model as a feature extractor
        def forward(self, x):
            feature_maps = self.convnet(x)
            return feature_maps

        def training_step(self, batch, batch_idx):
            x, y = batch
            feature_maps = self.forward(x)
            logits = self.classifier(feature_maps)

            # ...
            return loss

        # splitting it this way allows the model to be used as a feature extractor
        model = MyModelAbove()

        inputs = server.get_request()
        results = model(inputs)
        server.write_results(results)

        # -------------
        # This is in stark contrast to torch.nn.Module where normally you would have this:
        def forward(self, batch):
            x, y = batch
            feature_maps = self.convnet(x)
            logits = self.classifier(feature_maps)
            return logits

    """
def training_step(self, *args, **kwargs):
    r"""
    Return the loss, and optionally a dict with metrics for tqdm and the logger, for one training batch.

    The base implementation is empty; override it in your model.

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
            A tensor, tuple or list
        batch_idx (int): Integer displaying index of this batch
        optimizer_idx (int): If using multiple optimizers, this argument will also be present.
        hiddens (`Tensor <https://pytorch.org/docs/stable/tensors.html>`_): Passed in if
            truncated_bptt_steps > 0.

    Return:
        dict with loss key and optional log, progress keys.
        When implementing training_step, return whatever you need in that step:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something specific to your model.

    Example
    -------

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self.forward(x)
            loss = self.loss(out, y)

            logger_logs = {'training_loss': loss}  # optional (MUST ALL BE TENSORS)

            # if using TestTubeLogger or TensorBoardLogger you can nest scalars
            logger_logs = {'losses': logger_logs}  # optional (MUST ALL BE TENSORS)

            output = {
                'loss': loss,  # required
                'progress_bar': {'training_loss': loss},  # optional (MUST ALL BE TENSORS)
                'log': logger_logs
            }

            # return a dict
            return output

    If you define multiple optimizers, this step will also be called with an additional `optimizer_idx` param.

    .. code-block:: python

        # Multiple optimizers (ie: GANs)
        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 0:
                # do training_step with encoder
            if optimizer_idx == 1:
                # do training_step with decoder

    If you add truncated back propagation through time you will also get an additional
    argument with the hidden states of the previous step.

    .. code-block:: python

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # hiddens are the hiddens from the previous truncated backprop step
            ...
            out, hiddens = self.lstm(data, hiddens)
            ...

            return {
                "loss": ...,
                "hiddens": hiddens  # remember to detach() this
            }

    You can also return a -1 instead of a dict to stop the current loop. This is useful
    if you want to break out of the current training epoch early.
    """
def training_end(self, *args, **kwargs):
    """
    Deprecated hook kept for backward compatibility; the base implementation is empty.

    .. warning:: Deprecated in v0.7.0. Use training_step_end instead.
    """
def training_step_end(self, *args, **kwargs):
    """
    Use this when training with dp or ddp2 because training_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.
    In this case you should define training_step_end to perform those calculations.

    The base implementation is empty; override it in your model.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
        training_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `training_step` for each batch part.

    Return:
        dict: dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    Example
    -------

    .. code-block:: python

        # WITHOUT training_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

        # --------------
        # with training_step_end to do softmax over the full batch
        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            return {'out': out}

        def training_step_end(self, outputs):
            # this out is now the full size of the batch
            out = outputs['out']

            # this softmax now uses the full batch size
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

    .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
def validation_step(self, *args, **kwargs):
    r"""
    Operate on a single batch of data from the validation set.
    In this step you'd normally generate examples or calculate anything of interest such as accuracy.

    The base implementation is empty; override it in your model.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
            A tensor, tuple or list
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple
            val datasets used)

    Return:
        Dict or OrderedDict - passed to the validation_epoch_end

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx)

        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx)

    Example
    -------

    .. code-block:: python

        # CASE 1: A single validation dataset
        def validation_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self.forward(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # all optional...
            # return whatever you need for the collation function validation_epoch_end
            output = OrderedDict({
                'val_loss': loss,
                'val_acc': torch.tensor(val_acc),  # everything must be a tensor
            })

            # return an optional dict
            return output

    If you pass in multiple validation datasets, validation_step will have an additional argument.

    .. code-block:: python

        # CASE 2: multiple validation datasets
        def validation_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx tells you which dataset this is.

    .. note:: If you don't need to validate you don't need to implement this method.

    .. note:: When the validation_step is called, the model has been put in eval mode and PyTorch gradients
        have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
    """
def validation_step_end(self, *args, **kwargs):
    """
    Use this when validating with dp or ddp2 because validation_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.
    In this case you should define validation_step_end to perform those calculations.

    The base implementation is empty; override it in your model.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
        validation_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `validation_step` for each batch part.

    Return:
        dict: dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    Example
    -------

    .. code-block:: python

        # WITHOUT validation_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

        # --------------
        # with validation_step_end to do softmax over the full batch
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            return {'out': out}

        def validation_step_end(self, outputs):
            # this out is now the full size of the batch
            out = outputs['out']

            # this softmax now uses the full batch size
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

    .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
def validation_end(self, outputs):
    """
    Deprecated hook kept for backward compatibility; the base implementation is empty.

    .. warning:: Deprecated in v0.7.0. Use validation_epoch_end instead.
        Will be removed in 1.0.0.

    Args:
        outputs: Not used by the base implementation.
    """
def validation_epoch_end(self, outputs):
    """
    Called at end of validation epoch with the output of all validation_steps.

    The base implementation is empty; override it in your model.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        outputs (list): List of outputs you defined in validation_step, or if there are multiple dataloaders,
            a list containing a list of outputs for each dataloader

    Return:
        Dict or OrderedDict (dict): Dict has the following optional keys:
        progress_bar -> Dict for progress bar display. Must have only tensors
        log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    .. note:: If you didn't define a validation_step, this won't be called.

    - The outputs here are strictly for logging or progress bar.
    - If you don't need to display anything, don't return anything.
    - If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.

    Example
    -------

    With a single dataloader

    .. code-block:: python

        def validation_epoch_end(self, outputs):
            val_acc_mean = 0
            for output in outputs:
                val_acc_mean += output['val_acc']

            val_acc_mean /= len(outputs)
            tqdm_dict = {'val_acc': val_acc_mean.item()}

            # show val_acc in the progress bar and send it to the logger
            results = {
                'progress_bar': tqdm_dict,
                'log': {'val_acc': val_acc_mean.item()}
            }
            return results

    With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
    one entry per dataloader, while the inner list contains the individual outputs of
    each validation step for that dataloader.

    .. code-block:: python

        def validation_epoch_end(self, outputs):
            val_acc_mean = 0
            i = 0
            for dataloader_outputs in outputs:
                for output in dataloader_outputs:
                    val_acc_mean += output['val_acc']
                    i += 1

            val_acc_mean /= i
            tqdm_dict = {'val_acc': val_acc_mean.item()}

            # show val_acc in the progress bar and log it against the current epoch
            results = {
                'progress_bar': tqdm_dict,
                'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
            }
            return results
    """
def test_step(self, *args, **kwargs):
    r"""
    Operate on a single batch of data from the test set.
    In this step you'd normally generate examples or calculate anything of interest such as accuracy.

    The base implementation is empty; override it in your model.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
            A tensor, tuple or list
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple
            test datasets used)

    Return:
        Dict or OrderedDict - passed to the test_epoch_end

    .. code-block:: python

        # if you have one test dataloader:
        def test_step(self, batch, batch_idx)

        # if you have multiple test dataloaders:
        def test_step(self, batch, batch_idx, dataloader_idx)

    Example
    -------

    .. code-block:: python

        # CASE 1: A single test dataset
        def test_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self.forward(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # all optional...
            # return whatever you need for the collation function test_epoch_end
            output = OrderedDict({
                'test_loss': loss,
                'test_acc': torch.tensor(test_acc),  # everything must be a tensor
            })

            # return an optional dict
            return output

    If you pass in multiple test datasets, test_step will have an additional argument.

    .. code-block:: python

        # CASE 2: multiple test datasets
        def test_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx tells you which dataset this is.

    .. note:: If you don't need to test you don't need to implement this method.

    .. note:: When the test_step is called, the model has been put in eval mode and PyTorch gradients
        have been disabled. At the end of the test epoch, model goes back to training mode and gradients are enabled.
    """
def test_step_end(self, *args, **kwargs):
    """
    Use this when testing with dp or ddp2 because test_step will operate
    on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.
    In this case you should define test_step_end to perform those calculations.

    The base implementation is empty; override it in your model.

    .. note:: If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
        test_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in `test_step` for each batch part.

    Return:
        dict: dictionary with loss key and optional log, progress keys:

        - loss -> tensor scalar [REQUIRED]
        - progress_bar -> Dict for progress bar display. Must have only tensors
        - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    Example
    -------

    .. code-block:: python

        # WITHOUT test_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def test_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

        # --------------
        # with test_step_end to do softmax over the full batch
        def test_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.forward(x)
            return {'out': out}

        def test_step_end(self, outputs):
            # this out is now the full size of the batch
            out = outputs['out']

            # this softmax now uses the full batch size
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return {'loss': loss}

    .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
    """
def test_end(self, outputs):
    """
    Deprecated hook kept for backward compatibility; the base implementation is empty.

    .. warning:: Deprecated in v0.7.0. Use test_epoch_end instead. Will be removed in 1.0.0.

    Args:
        outputs: Not used by the base implementation.
    """
def test_epoch_end(self, outputs):
    """
    Called at end of test epoch with the output of all test_steps.

    The base implementation is empty; override it in your model.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        outputs (list): List of outputs you defined in test_step, or if there are multiple dataloaders,
            a list containing a list of outputs for each dataloader

    Return:
        Dict or OrderedDict (dict): Dict has the following optional keys:
        progress_bar -> Dict for progress bar display. Must have only tensors
        log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

    .. note:: If you didn't define a test_step, this won't be called.

    - The outputs here are strictly for logging or progress bar.
    - If you don't need to display anything, don't return anything.
    - If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.

    Example
    -------

    With a single dataloader

    .. code-block:: python

        def test_epoch_end(self, outputs):
            test_acc_mean = 0
            for output in outputs:
                test_acc_mean += output['test_acc']

            test_acc_mean /= len(outputs)
            tqdm_dict = {'test_acc': test_acc_mean.item()}

            # show test_acc in the progress bar and send it to the logger
            results = {
                'progress_bar': tqdm_dict,
                'log': {'test_acc': test_acc_mean.item()}
            }
            return results

    With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
    one entry per dataloader, while the inner list contains the individual outputs of
    each test step for that dataloader.

    .. code-block:: python

        def test_epoch_end(self, outputs):
            test_acc_mean = 0
            i = 0
            for dataloader_outputs in outputs:
                for output in dataloader_outputs:
                    test_acc_mean += output['test_acc']
                    i += 1

            test_acc_mean /= i
            tqdm_dict = {'test_acc': test_acc_mean.item()}

            # show test_acc in the progress bar and log it against the current epoch
            results = {
                'progress_bar': tqdm_dict,
                'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch}
            }
            return results
    """
def configure_ddp(self, model, device_ids):
    r"""
    Wrap ``model`` for distributed data parallel training.

    Override this to set up DDP in your own way or with your own wrapper.
    Whatever you return only has to route calls so that:

    1. On a validation batch the call goes to model.validation_step.
    2. On a training batch the call goes to model.training_step.
    3. On a testing batch, the call goes to model.test_step

    Args:
        model (:class:`.LightningModule`): the LightningModule currently being optimized
        device_ids (list): the list of GPU ids

    Return:
        DDP wrapped model

    Example
    -------

    .. code-block:: python

        # default implementation used in Trainer
        def configure_ddp(self, model, device_ids):
            # Lightning DDP simply routes to test_step, val_step, etc...
            model = LightningDistributedDataParallel(
                model,
                device_ids=device_ids,
                find_unused_parameters=True
            )
            return model

    """
    wrapper_options = {
        'device_ids': device_ids,
        'find_unused_parameters': True,
    }
    return LightningDistributedDataParallel(model, **wrapper_options)
def init_ddp_connection(self, proc_rank, world_size):
    r"""
    Prepare the environment for ``torch.distributed`` and join the 'nccl' process group.

    Override to define your custom way of setting up a distributed environment.

    The default implementation:

    - derives a port from the last four characters of ``SLURM_JOB_ID`` plus 15000
      (12910 when the variable is missing or not numeric) and stores it in ``MASTER_PORT``
      unless ``MASTER_PORT`` is already set;
    - takes the first space-separated entry of ``SLURM_NODELIST`` ('127.0.0.2' when the
      variable is missing), passes it through ``self.trainer.resolve_root_node_address``
      and stores the result in ``MASTER_ADDR``;
    - calls ``dist.init_process_group`` with the 'nccl' backend.

    Args:
        proc_rank (int): The current process rank within the node.
        world_size (int): Number of GPUs being use across all nodes. (num_nodes*nb_gpu_nodes).

    Example
    -------

    .. code-block:: python

        def init_ddp_connection(self, proc_rank, world_size):
            # use slurm job id for the port number
            # guarantees unique ports across jobs from same grid search
            try:
                # use the last 4 numbers in the job id as the id
                default_port = os.environ['SLURM_JOB_ID']
                default_port = default_port[-4:]

                # all ports should be in the 10k+ range
                default_port = int(default_port) + 15000

            except Exception:
                default_port = 12910

            # if user gave a port number, use that one instead
            try:
                default_port = os.environ['MASTER_PORT']
            except Exception:
                os.environ['MASTER_PORT'] = str(default_port)

            # figure out the root node addr
            try:
                root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
            except Exception:
                root_node = '127.0.0.2'

            root_node = self.trainer.resolve_root_node_address(root_node)
            os.environ['MASTER_ADDR'] = root_node
            dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)

    """
    fallback_port = 12910

    # A port derived from the slurm job id stays unique across jobs of one grid search.
    job_id = os.environ.get('SLURM_JOB_ID')
    if job_id is None:
        port = fallback_port
    else:
        try:
            # last four characters of the job id, shifted into the 10k+ range
            port = int(job_id[-4:]) + 15000
        except ValueError:
            port = fallback_port

    # A port supplied by the user wins; otherwise publish the one picked above.
    os.environ.setdefault('MASTER_PORT', str(port))

    # The first entry of the node list acts as the root node.
    node_list = os.environ.get('SLURM_NODELIST')
    first_node = '127.0.0.2' if node_list is None else node_list.split(' ')[0]

    os.environ['MASTER_ADDR'] = self.trainer.resolve_root_node_address(first_node)
    dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
def configure_apex(self, amp, model, optimizers, amp_level):
    r"""
    Hand the model and its optimizers to ``amp.initialize`` and return what it produces.

    Override to init AMP your own way.
    Must return a model and list of optimizers.

    Args:
        amp (object): pointer to amp library object
        model (:class:`.LightningModule`): pointer to current lightningModule
        optimizers (list): list of optimizers passed in configure_optimizers()
        amp_level (str): AMP mode chosen ('O1', 'O2', etc...)

    Return:
        Apex wrapped model and optimizers

    Example
    -------

    .. code-block:: python

        # Default implementation used by Trainer.
        def configure_apex(self, amp, model, optimizers, amp_level):
            model, optimizers = amp.initialize(
                model, optimizers, opt_level=amp_level,
            )

            return model, optimizers
    """
    initialized = amp.initialize(model, optimizers, opt_level=amp_level)
    # unpack so the caller always receives a (model, optimizers) pair
    wrapped_model, wrapped_optimizers = initialized
    return wrapped_model, wrapped_optimizers
def configure_optimizers(self):
    r"""
    Choose the optimizers and learning-rate schedulers used in your optimization.

    Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
    If you don't define this method Lightning will automatically use Adam(lr=1e-3)

    Return: any of these 3 options:
        - Single optimizer
        - List or Tuple - List of optimizers
        - Two lists - The first list has multiple optimizers, the second a list of learning-rate schedulers

    Example
    -------

    .. code-block:: python

        # most cases (default if not defined)
        def configure_optimizers(self):
            opt = Adam(self.parameters(), lr=1e-3)
            return opt

        # multiple optimizer case (eg: GAN)
        def configure_optimizers(self):
            generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
            discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
            return generator_opt, discriminator_opt

        # example with learning_rate schedulers
        def configure_optimizers(self):
            generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
            discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
            discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
            return [generator_opt, discriminator_opt], [discriminator_sched]

    .. note:: Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.

    .. note:: If you use 16-bit precision (use_amp=True), Lightning will automatically
        handle the optimizers for you.

    .. note:: If you use multiple optimizers, training_step will have an additional `optimizer_idx` parameter.

    .. note:: If you use LBFGS lightning handles the closure function automatically for you

    .. note:: If you use multiple optimizers, gradients will be calculated only
        for the parameters of current optimizer at each training step.

    .. note:: If you need to control how often those optimizers step or override the default .step() schedule,
        override the `optimizer_step` hook.

    """
    default_lr = 1e-3
    trainable = self.parameters()
    return Adam(trainable, lr=default_lr)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
r"""
Override this method to adjust the default way the Trainer calls each optimizer. By default, Lightning
calls .step() and zero_grad() as shown in the example once per optimizer.
Args:
epoch (int): Current epoch
batch_idx (int): Index of current batch
optimizer (torch.nn.Optimizer): A PyTorch optimizer
optimizer_idx (int): If you used multiple optimizers this indexes into that list
second_order_closure (int): closure for second order methods
Example
-------
.. code-block:: python
# DEFAULT
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
optimizer.step()
optimizer.zero_grad()
# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0 :
optimizer.step()
optimizer.zero_grad()
# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0 :
optimizer.step()
optimizer.zero_grad()
# ...
# add as many optimizers as you want
Here's another example showing how to use this for more advanced things such as learning-rate warm-up:
.. code-block:: python
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * self.hparams.learning_rate
# update params
optimizer.step()
optimizer.zero_grad()
"""
if self.trainer.use_tpu and XLA_AVAILABLE:
xm.optimizer_step(optimizer)
elif isinstance(optimizer, torch.optim.LBFGS):