Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call any trainer function from the LightningCLI #7508

Merged
merged 74 commits into from
Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from 73 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
6d95128
Default `seed_everything(workers=True)` in the `LightningCLI`
carmocca May 12, 2021
1937103
Update CHANGELOG
carmocca May 12, 2021
76e814e
Add support for all trainer functions to the `LightningCLI`
carmocca May 12, 2021
5c52e89
Update default
carmocca May 12, 2021
97a6628
Lowercase `TrainerFn`
carmocca May 12, 2021
532b81d
Update docs
carmocca May 12, 2021
9259a32
run_kwargs
carmocca May 12, 2021
0b22b7a
Update tests
carmocca May 12, 2021
ef91e77
Update CHANGELOG
carmocca May 12, 2021
91715b9
Use proper subcommands
carmocca May 12, 2021
c004f41
Revert "Lowercase `TrainerFn`"
carmocca May 12, 2021
3738a6b
Dynamic subcommand calling
carmocca May 12, 2021
8889b11
Add core arguments to the base parser
carmocca May 12, 2021
a5e90f9
TODO
carmocca May 12, 2021
37c30f3
Fix some tests
carmocca May 12, 2021
897c9ae
Address comments
carmocca May 13, 2021
03d82df
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca May 13, 2021
717707c
Fix imports
carmocca May 13, 2021
e475f2d
Fix test
carmocca May 13, 2021
869a7b4
Re-structure
carmocca May 14, 2021
22d1aac
Return None
carmocca May 14, 2021
12537c0
Minor changes
carmocca May 14, 2021
f993d35
Add commands to subparser
carmocca May 14, 2021
7370c57
Merge master - to be fixed
carmocca Jul 29, 2021
4543b0b
Improvements
carmocca Jul 29, 2021
493823d
Improvements
carmocca Jul 29, 2021
05bffa2
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 3, 2021
fcc80a4
Progress
carmocca Aug 4, 2021
9e5d6cd
Shorter name
carmocca Aug 4, 2021
47425a9
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 4, 2021
37e55ec
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 10, 2021
dae877f
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 10, 2021
fabb8b2
Bad merge
carmocca Aug 10, 2021
e5cf1eb
Fix and add tests
carmocca Aug 10, 2021
b3d1fbb
Fix most tests
carmocca Aug 10, 2021
22701f2
Fix test
carmocca Aug 10, 2021
e229844
Fix config
carmocca Aug 10, 2021
4e0edaa
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 18, 2021
31df189
Fix optimizer tests
carmocca Aug 18, 2021
6bf983a
Add config tests
carmocca Aug 18, 2021
13521b6
Minor fix
carmocca Aug 18, 2021
348df1c
Undo docs changes
carmocca Aug 18, 2021
17ee5ae
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 23, 2021
570308a
Fix tests
carmocca Aug 24, 2021
9ae188d
Fix mypy
carmocca Aug 24, 2021
8cdb04f
Set failing test
carmocca Aug 24, 2021
96737ce
Fix doctests
carmocca Aug 24, 2021
f48ed61
Simplify
carmocca Aug 24, 2021
d65472c
Fix `parser_kwargs` with subcommands
carmocca Aug 24, 2021
856aecd
Add parser_kwargs and multiple config tests
carmocca Aug 24, 2021
d5c0bd7
Update docs
carmocca Aug 24, 2021
286c32b
Silence mypy
carmocca Aug 24, 2021
36f6c85
Fix mypy for unused imports
kaushikb11 Aug 24, 2021
e6d026c
Try different python version for docs
carmocca Aug 24, 2021
b96e0d7
Undo python version change
carmocca Aug 24, 2021
92d79d5
Debug - revert me
carmocca Aug 24, 2021
15afa3a
Add extra import
carmocca Aug 24, 2021
87325d7
Revert last 2 commits
carmocca Aug 24, 2021
fc8ced8
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 24, 2021
039bebd
Merge remote-tracking branch 'origin/fix/mypy' into feat/lightning-cl…
carmocca Aug 24, 2021
c1a64d2
Point CLI summary to docs
carmocca Aug 24, 2021
03c16f8
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 24, 2021
e149ce2
Minor docstring change
carmocca Aug 24, 2021
fde4e1d
Fix tests
carmocca Aug 24, 2021
4c94e8e
Fix docstring_parser import breaking make test due to mocked imports
carmocca Aug 24, 2021
3aeef28
Waiting for 3.19
carmocca Aug 24, 2021
e9a33e5
Address comments
carmocca Aug 25, 2021
0ec656e
Typo
carmocca Aug 25, 2021
65eeea5
Avoid Python 3.6 bug where `Union[int, bool]` becomes `int`
carmocca Aug 26, 2021
d07cee8
Skip tests due to bpo-17185
carmocca Aug 27, 2021
bb79dee
Update pl_examples
carmocca Aug 27, 2021
338f267
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 27, 2021
8d1c423
Fix bash string
carmocca Aug 27, 2021
620703b
Deduplicate tests
carmocca Aug 28, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))


- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
- `LightningCLI` additions:
* Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
* Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))


- Fault-tolerant training:
Expand Down
186 changes: 116 additions & 70 deletions docs/source/common/lightning_cli.rst

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,12 @@ def predict_dataloader(self):


def cli_main():
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(
LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")
predictions = cli.trainer.predict(ckpt_path="best")
print(predictions[0])


Expand Down
7 changes: 4 additions & 3 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def predict_dataloader(self):


def cli_main():
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")
predictions = cli.trainer.predict(ckpt_path="best")
print(predictions[0])


Expand Down
5 changes: 3 additions & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def cli_main():
if not _DALI_AVAILABLE:
return

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")


if __name__ == "__main__":
Expand Down
7 changes: 5 additions & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ def configure_optimizers(self):


def cli_main():
cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(
LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser):
}
)

def instantiate_trainer(self):
finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"])
def instantiate_trainer(self, *args):
finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
self.trainer_defaults["callbacks"] = [finetuning_callback]
super().instantiate_trainer()
return super().instantiate_trainer(*args)


def cli_main():
Expand Down
9 changes: 8 additions & 1 deletion pl_examples/run_examples.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
set -ex

dir_path=$(dirname "${BASH_SOURCE[0]}")
args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"
args="
--data.batch_size=32
--trainer.max_epochs=1
--trainer.limit_train_batches=2
--trainer.limit_val_batches=2
--trainer.limit_test_batches=2
--trainer.limit_predict_batches=2
"

python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"
Expand Down
199 changes: 137 additions & 62 deletions pytorch_lightning/utilities/cli.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ torchtext>=0.7
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.17.0
jsonargparse[signatures]>=3.19.0
gcsfs>=2021.5.0
2 changes: 1 addition & 1 deletion tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fi
# report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"

# test that a user can manually launch individual processes
args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1"
args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1"
MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &
MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/simple_image_classifier.py ${args}
report+="Ran\tmanual ddp launch test\n"
Expand Down
Loading