Call any trainer function from the LightningCLI (#7508)
carmocca authored Aug 28, 2021
1 parent 045c879 commit 0dfc6a1
Showing 14 changed files with 597 additions and 250 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/gpu-tests.yml
@@ -108,6 +108,8 @@ jobs:
bash pl_examples/run_examples.sh --trainer.gpus=1
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp --trainer.precision=16
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp --trainer.precision=16
env:
PL_USE_MOCKED_MNIST: "1"
displayName: 'Testing: examples'
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -44,7 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))


- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
- `LightningCLI` additions:
* Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
* Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))


- Fault-tolerant training:
186 changes: 116 additions & 70 deletions docs/source/common/lightning_cli.rst

Large diffs are not rendered by default.
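Since the docs/source/common/lightning_cli.rst diff is not rendered, the sketch below illustrates the subcommand interface this commit adds (see the CHANGELOG entry above). It is only a rough sketch: MyModel and MyDataModule are placeholder names, and the example invocations in the comments are assumptions rather than part of this diff.

# main.py -- hypothetical script relying on the new subcommand dispatch
from pytorch_lightning.utilities.cli import LightningCLI

from my_project import MyDataModule, MyModel  # placeholder imports


def cli_main():
    # With run=True (the default), the CLI parses a subcommand naming a trainer
    # function, instantiates the model/datamodule/trainer, and dispatches to it,
    # e.g. (assumed invocations):
    #   python main.py fit --trainer.max_epochs=1
    #   python main.py test --ckpt_path=path/to/checkpoint.ckpt
    LightningCLI(MyModel, MyDataModule)


if __name__ == "__main__":
    cli_main()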

9 changes: 6 additions & 3 deletions pl_examples/basic_examples/autoencoder.py
@@ -109,9 +109,12 @@ def predict_dataloader(self):


def cli_main():
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(
LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")
predictions = cli.trainer.predict(ckpt_path="best")
print(predictions[0])


7 changes: 4 additions & 3 deletions pl_examples/basic_examples/backbone_image_classifier.py
@@ -124,9 +124,10 @@ def predict_dataloader(self):


def cli_main():
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")
predictions = cli.trainer.predict(ckpt_path="best")
print(predictions[0])


5 changes: 3 additions & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
@@ -198,8 +198,9 @@ def cli_main():
if not _DALI_AVAILABLE:
return

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")


if __name__ == "__main__":
7 changes: 5 additions & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
@@ -72,8 +72,11 @@ def configure_optimizers(self):


def cli_main():
cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
cli = LightningCLI(
LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.test(ckpt_path="best")


if __name__ == "__main__":
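The example scripts above now pass run=False, so the LightningCLI only parses arguments and instantiates the classes while the script calls the trainer functions itself. Below is a short annotated sketch of that pattern; class names are placeholders, and the comment about ckpt_path="best" reflects an assumption about Trainer's checkpoint handling rather than anything shown in this diff.

from pytorch_lightning.utilities.cli import LightningCLI

from my_project import MyDataModule, MyModel  # placeholder imports


def cli_main():
    # run=False: no subcommand is parsed or executed; the CLI only builds
    # cli.model, cli.datamodule and cli.trainer from the command line / config.
    cli = LightningCLI(MyModel, MyDataModule, run=False)
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # Assumption: after fit, the trainer still holds the datamodule and the
    # checkpoint callback, so ckpt_path="best" can reload the best checkpoint
    # without re-passing the model or datamodule.
    cli.trainer.test(ckpt_path="best")


if __name__ == "__main__":
    cli_main()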
6 changes: 3 additions & 3 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser):
}
)

def instantiate_trainer(self):
finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"])
def instantiate_trainer(self, *args):
finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
self.trainer_defaults["callbacks"] = [finetuning_callback]
super().instantiate_trainer()
return super().instantiate_trainer(*args)


def cli_main():
9 changes: 8 additions & 1 deletion pl_examples/run_examples.sh
mode changed: 100644 → 100755
@@ -2,7 +2,14 @@
set -ex

dir_path=$(dirname "${BASH_SOURCE[0]}")
args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"
args="
--data.batch_size=32
--trainer.max_epochs=1
--trainer.limit_train_batches=2
--trainer.limit_val_batches=2
--trainer.limit_test_batches=2
--trainer.limit_predict_batches=2
"

python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"
51 changes: 4 additions & 47 deletions pl_examples/test_examples.py
@@ -11,70 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import platform
from unittest import mock

import pytest
import torch

from pl_examples import _DALI_AVAILABLE
from tests.helpers.runif import RunIf

ARGS_DEFAULT = (
"--trainer.default_root_dir %(tmpdir)s "
"--trainer.max_epochs 1 "
"--trainer.limit_train_batches 2 "
"--trainer.limit_val_batches 2 "
"--trainer.limit_test_batches 2 "
"--trainer.limit_predict_batches 2 "
"--data.batch_size 32 "
)
ARGS_GPU = ARGS_DEFAULT + "--trainer.gpus 1 "
ARGS_DP = ARGS_DEFAULT + "--trainer.gpus 2 --trainer.accelerator dp "
ARGS_AMP = "--trainer.precision 16 "


@pytest.mark.parametrize(
"import_cli",
[
"pl_examples.basic_examples.simple_image_classifier",
"pl_examples.basic_examples.backbone_image_classifier",
"pl_examples.basic_examples.autoencoder",
],
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("cli_args", [ARGS_DP, ARGS_DP + ARGS_AMP])
def test_examples_dp(tmpdir, import_cli, cli_args):

module = importlib.import_module(import_cli)
# update the temp dir
cli_args = cli_args % {"tmpdir": tmpdir}

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
module.cli_main()


@pytest.mark.parametrize(
"import_cli",
[
"pl_examples.basic_examples.simple_image_classifier",
"pl_examples.basic_examples.backbone_image_classifier",
"pl_examples.basic_examples.autoencoder",
],
)
@pytest.mark.parametrize("cli_args", [ARGS_DEFAULT])
def test_examples_cpu(tmpdir, import_cli, cli_args):

module = importlib.import_module(import_cli)
# update the temp dir
cli_args = cli_args % {"tmpdir": tmpdir}

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
module.cli_main()


@pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(platform.system() != "Linux", reason="Only applies to Linux platform.")
@RunIf(min_gpus=1, skip_windows=True)
@pytest.mark.parametrize("cli_args", [ARGS_GPU])
def test_examples_mnist_dali(tmpdir, cli_args):
from pl_examples.basic_examples.dali_image_classifier import cli_main
(remaining file diffs not rendered)
