Skip to content

Commit

Permalink
Default seed_everything(workers=True) in the LightningCLI (#7504)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored May 13, 2021
1 parent dd1a17b commit a584196
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
- Default `seed_everything(workers=True)` in the `LightningCLI` ([#7504](https://github.com/PyTorchLightning/pytorch-lightning/pull/7504))


- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([#7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))


- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))


### Deprecated
Expand Down
4 changes: 2 additions & 2 deletions docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ practice to create a configuration file and provide this to the tool. A way to d
The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line
and config file options, instantiating the classes, setting up a callback to save the config in the log directory and
finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used for instance to get the result of fit,
i.e., :code:`cli.fit_result`.
finally running the trainer. The resulting object :code:`cli` can be used for example to get the instance of the
model, (:code:`cli.model`).

After multiple trainings with different configurations, each run will have in its respective log directory a
:code:`config.yaml` file. This file can be used for reference to know in detail all the settings that were used for each
Expand Down
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def test_dataloader(self):

def cli_main():
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def test_dataloader(self):

def cli_main():
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,7 @@ def cli_main():
return

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def configure_optimizers(self):

def cli_main():
cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
Expand Down
13 changes: 7 additions & 6 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,14 @@ def __init__(
.. warning:: ``LightningCLI`` is in beta and subject to change.
Args:
model_class: The LightningModule class to train on.
datamodule_class: An optional LightningDataModule class.
model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
save_config_callback: A callback class to save the training config.
trainer_class: An optional extension of the Trainer class.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
seed_everything_default: Default value for seed_everything argument.
description: Description of the tool shown when running --help.
seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
seed argument.
description: Description of the tool shown when running ``--help``.
env_prefix: Prefix for environment variables.
env_parse: Whether environment variable parsing is enabled.
parser_kwargs: Additional arguments to instantiate LightningArgumentParser.
Expand Down Expand Up @@ -165,7 +166,7 @@ def __init__(
self.add_arguments_to_parser(self.parser)
self.parse_arguments()
if self.config['seed_everything'] is not None:
seed_everything(self.config['seed_everything'])
seed_everything(self.config['seed_everything'], workers=True)
self.before_instantiate_classes()
self.instantiate_classes()
self.prepare_fit_kwargs()
Expand Down

0 comments on commit a584196

Please sign in to comment.