Default seed_everything(workers=True) in the LightningCLI #7504

Merged: 4 commits, May 13, 2021
Changes from 2 commits
9 changes: 6 additions & 3 deletions CHANGELOG.md
@@ -35,13 +35,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
 
 
-- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
 
 
-- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
+- Default `seed_everything(workers=True)` in the `LightningCLI` ([#7504](https://github.com/PyTorchLightning/pytorch-lightning/pull/7504))
 
 
-- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
+- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([#7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
 
 
+- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
+
+
 ### Deprecated
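The `#7504` entry above is the behavioural change of this PR: when a seed is configured, the `LightningCLI` now also arranges for DataLoader workers to be seeded. A minimal sketch of the equivalent direct call (illustrative only, using the seed value 1234 that the example scripts below pass as `seed_everything_default`):

    # What the CLI now effectively runs when a seed is configured; previously the
    # same call was made without workers=True.
    from pytorch_lightning import seed_everything

    seed_everything(1234, workers=True)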
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/autoencoder.py
@@ -116,8 +116,7 @@ def test_dataloader(self):
 
 def cli_main():
     cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
-    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    print(result)
+    cli.trainer.test(cli.model, datamodule=cli.datamodule)
 
 
 if __name__ == '__main__':
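The same simplification is applied in each example script in this PR: the return value of `trainer.test` is no longer captured and printed, since the Trainer already reports the test metrics. Scripts that do need the metrics can still keep the return value; a short sketch, assuming the `cli` object created in the example above:

    # trainer.test returns a list with one metrics dict per test dataloader, so the
    # value can still be used programmatically instead of being printed.
    results = cli.trainer.test(cli.model, datamodule=cli.datamodule)
    test_metrics = results[0]  # e.g. {'test_loss': ...}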
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/backbone_image_classifier.py
@@ -129,8 +129,7 @@ def test_dataloader(self):
 
 def cli_main():
     cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
-    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    print(result)
+    cli.trainer.test(cli.model, datamodule=cli.datamodule)
 
 
 if __name__ == '__main__':
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
@@ -222,8 +222,7 @@ def cli_main():
         return
 
     cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
-    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    print(result)
+    cli.trainer.test(cli.model, datamodule=cli.datamodule)
 
 
 if __name__ == "__main__":
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
@@ -77,8 +77,7 @@ def configure_optimizers(self):
 
 def cli_main():
     cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234)
-    result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    print(result)
+    cli.trainer.test(cli.model, datamodule=cli.datamodule)
 
 
 if __name__ == '__main__':
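For background on what `workers=True` adds, the sketch below shows the idea in plain PyTorch terms rather than Lightning code: each DataLoader worker process gets a reproducible seed for NumPy and the stdlib `random` module via a `worker_init_fn`. `seed_everything(workers=True)` arranges an equivalent mechanism automatically; this helper is only an illustration, not Lightning's implementation.

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def seed_worker(worker_id: int) -> None:
        # torch assigns each worker a distinct initial seed derived from the base seed;
        # reuse it so numpy and random are reproducible inside the worker as well.
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)


    dataset = TensorDataset(torch.randn(8, 3))
    loader = DataLoader(dataset, batch_size=4, num_workers=2, worker_init_fn=seed_worker)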
13 changes: 7 additions & 6 deletions pytorch_lightning/utilities/cli.py
@@ -128,13 +128,14 @@ def __init__(
         .. warning:: ``LightningCLI`` is in beta and subject to change.
 
         Args:
-            model_class: The LightningModule class to train on.
-            datamodule_class: An optional LightningDataModule class.
+            model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on.
+            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class.
             save_config_callback: A callback class to save the training config.
-            trainer_class: An optional extension of the Trainer class.
+            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class.
             trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
-            seed_everything_default: Default value for seed_everything argument.
-            description: Description of the tool shown when running --help.
+            seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
+                seed argument.
+            description: Description of the tool shown when running ``--help``.
             env_prefix: Prefix for environment variables.
             env_parse: Whether environment variable parsing is enabled.
             parser_kwargs: Additional arguments to instantiate LightningArgumentParser.
@@ -165,7 +166,7 @@ def __init__(
         self.add_arguments_to_parser(self.parser)
         self.parse_arguments()
         if self.config['seed_everything'] is not None:
-            seed_everything(self.config['seed_everything'])
+            seed_everything(self.config['seed_everything'], workers=True)
         self.before_instantiate_classes()
         self.instantiate_classes()
         self.prepare_fit_kwargs()
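The hunk above contains the core change: seeding still only happens when the resolved `seed_everything` value is not `None` (so `seed_everything_default=None` keeps it disabled), but the call now passes `workers=True`. A hedged sketch of that guard as a standalone function, illustrative rather than Lightning's implementation:

    from typing import Optional

    from pytorch_lightning import seed_everything


    def maybe_seed(seed: Optional[int]) -> None:
        # Mirrors the guard in LightningCLI.__init__: only seed when a value is
        # configured, and forward workers=True as of this PR.
        if seed is not None:
            seed_everything(seed, workers=True)


    maybe_seed(1234)  # seeds the stdlib random, numpy and torch RNGs (workers included)
    maybe_seed(None)  # no-op: seeding stays disabled when no seed is configured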