Call any trainer function from the LightningCLI #7508

Merged 74 commits into master from feat/lightning-cli-trainer-fn on Aug 28, 2021
Changes from 9 commits
Commits (74)
6d95128
Default `seed_everything(workers=True)` in the `LightningCLI`
carmocca May 12, 2021
1937103
Update CHANGELOG
carmocca May 12, 2021
76e814e
Add support for all trainer functions to the `LightningCLI`
carmocca May 12, 2021
5c52e89
Update default
carmocca May 12, 2021
97a6628
Lowercase `TrainerFn`
carmocca May 12, 2021
532b81d
Update docs
carmocca May 12, 2021
9259a32
run_kwargs
carmocca May 12, 2021
0b22b7a
Update tests
carmocca May 12, 2021
ef91e77
Update CHANGELOG
carmocca May 12, 2021
91715b9
Use proper subcommands
carmocca May 12, 2021
c004f41
Revert "Lowercase `TrainerFn`"
carmocca May 12, 2021
3738a6b
Dynamic subcommand calling
carmocca May 12, 2021
8889b11
Add core arguments to the base parser
carmocca May 12, 2021
a5e90f9
TODO
carmocca May 12, 2021
37c30f3
Fix some tests
carmocca May 12, 2021
897c9ae
Address comments
carmocca May 13, 2021
03d82df
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca May 13, 2021
717707c
Fix imports
carmocca May 13, 2021
e475f2d
Fix test
carmocca May 13, 2021
869a7b4
Re-structure
carmocca May 14, 2021
22d1aac
Return None
carmocca May 14, 2021
12537c0
Minor changes
carmocca May 14, 2021
f993d35
Add commands to subparser
carmocca May 14, 2021
7370c57
Merge master - to be fixed
carmocca Jul 29, 2021
4543b0b
Improvements
carmocca Jul 29, 2021
493823d
Improvements
carmocca Jul 29, 2021
05bffa2
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 3, 2021
fcc80a4
Progress
carmocca Aug 4, 2021
9e5d6cd
Shorter name
carmocca Aug 4, 2021
47425a9
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 4, 2021
37e55ec
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 10, 2021
dae877f
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 10, 2021
fabb8b2
Bad merge
carmocca Aug 10, 2021
e5cf1eb
Fix and add tests
carmocca Aug 10, 2021
b3d1fbb
Fix most tests
carmocca Aug 10, 2021
22701f2
Fix test
carmocca Aug 10, 2021
e229844
Fix config
carmocca Aug 10, 2021
4e0edaa
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 18, 2021
31df189
Fix optimizer tests
carmocca Aug 18, 2021
6bf983a
Add config tests
carmocca Aug 18, 2021
13521b6
Minor fix
carmocca Aug 18, 2021
348df1c
Undo docs changes
carmocca Aug 18, 2021
17ee5ae
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 23, 2021
570308a
Fix tests
carmocca Aug 24, 2021
9ae188d
Fix mypy
carmocca Aug 24, 2021
8cdb04f
Set failing test
carmocca Aug 24, 2021
96737ce
Fix doctests
carmocca Aug 24, 2021
f48ed61
Simplify
carmocca Aug 24, 2021
d65472c
Fix `parser_kwargs` with subcommands
carmocca Aug 24, 2021
856aecd
Add parser_kwargs and multiple config tests
carmocca Aug 24, 2021
d5c0bd7
Update docs
carmocca Aug 24, 2021
286c32b
Silence mypy
carmocca Aug 24, 2021
36f6c85
Fix mypy for unused imports
kaushikb11 Aug 24, 2021
e6d026c
Try different python version for docs
carmocca Aug 24, 2021
b96e0d7
Undo python version change
carmocca Aug 24, 2021
92d79d5
Debug - revert me
carmocca Aug 24, 2021
15afa3a
Add extra import
carmocca Aug 24, 2021
87325d7
Revert last 2 commits
carmocca Aug 24, 2021
fc8ced8
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 24, 2021
039bebd
Merge remote-tracking branch 'origin/fix/mypy' into feat/lightning-cl…
carmocca Aug 24, 2021
c1a64d2
Point CLI summary to docs
carmocca Aug 24, 2021
03c16f8
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 24, 2021
e149ce2
Minor docstring change
carmocca Aug 24, 2021
fde4e1d
Fix tests
carmocca Aug 24, 2021
4c94e8e
Fix docstring_parser import breaking make test due to mocked imports
carmocca Aug 24, 2021
3aeef28
Waiting for 3.19
carmocca Aug 24, 2021
e9a33e5
Address comments
carmocca Aug 25, 2021
0ec656e
Typo
carmocca Aug 25, 2021
65eeea5
Avoid Python 3.6 bug where `Union[int, bool]` becomes `int`
carmocca Aug 26, 2021
d07cee8
Skip tests due to bpo-17185
carmocca Aug 27, 2021
bb79dee
Update pl_examples
carmocca Aug 27, 2021
338f267
Merge branch 'master' into feat/lightning-cli-trainer-fn
carmocca Aug 27, 2021
8d1c423
Fix bash string
carmocca Aug 27, 2021
620703b
Deduplicate tests
carmocca Aug 28, 2021
12 changes: 9 additions & 3 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added support to call any trainer function from the `LightningCLI` ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))


- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


@@ -35,13 +38,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


- Default `seed_everything(workers=True)` in the `LightningCLI` ([#7504](https://github.com/PyTorchLightning/pytorch-lightning/pull/7504))


- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([#7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))


- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))


### Deprecated
41 changes: 28 additions & 13 deletions docs/source/common/lightning_cli.rst
@@ -7,8 +7,8 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.cli import LightningCLI

original_fit = LightningCLI.fit
LightningCLI.fit = lambda self: None
original_run = LightningCLI.run
LightningCLI.run = lambda self: None

class MyModel(LightningModule):
def __init__(
@@ -37,7 +37,7 @@

.. testcleanup:: *

LightningCLI.fit = original_fit
LightningCLI.run = original_run
mock_argv.stop()


@@ -91,8 +91,22 @@ practice to create a configuration file and provide this to the tool. A way to d

The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line
and config file options, instantiating the classes, setting up a callback to save the config in the log directory and
finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used for instance to get the result of fit,
i.e., :code:`cli.fit_result`.
finally running the trainer. The resulting object :code:`cli` can be used for instance to get the instance of the
model, i.e., :code:`cli.model`.
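
For example, a minimal sketch (``MyModel`` being the model class defined earlier in this page):

.. code-block:: python

    cli = LightningCLI(MyModel)
    # After the selected trainer function has finished, the instantiated
    # objects remain accessible on the ``cli`` object:
    trained_model = cli.model
    trainer = cli.trainer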

The :class:`~pytorch_lightning.utilities.cli.LightningCLI` is configured to run
:meth:`~pytorch_lightning.trainer.Trainer.fit` by default. This can be changed either by setting
``LightningCLI(trainer_fn="test")`` or by passing the argument through command line positionally, e.g.:

.. code-block:: bash

python trainer.py test --trainer.limit_test_batches=10

.. tip::

You can override :meth:`~pytorch_lightning.utilities.cli.LightningCLI.prepare_run_kwargs` to pass any extra
arguments to the trainer function to run.
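
A rough sketch combining the two options above; the exact signature of ``prepare_run_kwargs`` is an
assumption here (taken to return the keyword arguments passed to the selected trainer function):

.. code-block:: python

    class MyTestCLI(LightningCLI):

        def prepare_run_kwargs(self):
            # Assumed hook: return the kwargs forwarded to the trainer function.
            kwargs = super().prepare_run_kwargs()
            kwargs["verbose"] = False  # e.g. an extra ``trainer.test`` argument
            return kwargs


    # Runs ``trainer.test`` instead of the default ``trainer.fit``
    cli = MyTestCLI(MyModel, trainer_fn="test")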


After multiple trainings with different configurations, each run will have in its respective log directory a
:code:`config.yaml` file. This file can be used for reference to know in detail all the settings that were used for each
@@ -341,9 +355,9 @@ The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before
and after :code:`trainer.fit` is executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic example for these would be to send an email
before and after the execution of fit. The code would be something like:
and after the trainer runs: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_run` and
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_run`. A realistic example for these would be to send an email
before and after the execution. The code would be something like:

.. testcode::

@@ -354,23 +368,24 @@
def add_arguments_to_parser(self, parser):
parser.add_argument('--notification_email', default='will@email.com')

def before_fit(self):
def before_run(self):
send_email(
address=self.config['notification_email'],
message='trainer.fit starting'
message='Trainer running'
)

def after_fit(self):
def after_run(self):
send_email(
address=self.config['notification_email'],
message='trainer.fit finished'
message='Trainer finished'
)

cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
has the same structure as the yaml format described previously. This means for instance that the parameters used for
instantiating the trainer class can be found in :code:`self.config['trainer']`.
instantiating the trainer class can be found in :code:`self.config['trainer']`. You can also access the trainer function
that is meant to run with :code:`self.config["trainer_fn"].value`.
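
Continuing the example, a hook on the ``MyLightningCLI`` subclass could read these values directly
(the ``'max_epochs'`` entry is just an assumed key of the trainer group):

.. code-block:: python

    class MyLoggingCLI(MyLightningCLI):

        def before_run(self):
            fn_name = self.config["trainer_fn"].value          # e.g. 'fit' or 'test'
            max_epochs = self.config["trainer"]["max_epochs"]  # assumed trainer option
            print(f"About to run trainer.{fn_name} with max_epochs={max_epochs}")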

Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the
model and data module depend on a common parameter. For example in some cases both classes require to know the
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/autoencoder.py
@@ -116,8 +116,7 @@ def test_dataloader(self):

def cli_main():
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/backbone_image_classifier.py
@@ -129,8 +129,7 @@ def test_dataloader(self):

def cli_main():
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
@@ -222,8 +222,7 @@ def cli_main():
return

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
@@ -77,8 +77,7 @@ def configure_optimizers(self):

def cli_main():
cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -369,7 +369,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn not in (TrainerFn.fit, TrainerFn.tune):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -159,7 +159,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:

def _should_skip_check(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
return trainer.state.fn != TrainerFn.fit or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -305,7 +305,7 @@ def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.state.fn != TrainerFn.fit # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -280,7 +280,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -298,7 +298,7 @@ def __recover_child_process_weights(self, best_path, last_path):
# todo, pass also best score

# load last weights
if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.fit:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -532,7 +532,7 @@ def restore_model_state_from_ckpt_path(
if not self.save_full_weights and self.world_size > 1:
# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerFn
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.fit
save_dir = self._filepath_to_dir(ckpt_path)

if self.zero_stage_3:
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization

"""
return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.FITTING
return torch_distrib.is_initialized() and self.lightning_module.trainer.state.fn != TrainerFn.fit

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
if self.lightning_module.trainer.state.fn == TrainerFn.fit:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
if self.lightning_module.trainer.state.fn == TrainerFn.fit:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
@@ -66,7 +66,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
if self.model.trainer.state.fn != TrainerFn.FITTING:
if self.model.trainer.state.fn != TrainerFn.fit:
return
self._reinit_optimizers_with_oss()

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -53,7 +53,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
if self.model.trainer.state.fn != TrainerFn.FITTING:
if self.model.trainer.state.fn != TrainerFn.fit:
return
self._reinit_optimizers_with_oss()

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -190,7 +190,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state.fn == TrainerFn.FITTING and best_model_path is not None
self.lightning_module.trainer.state.fn == TrainerFn.fit and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,14 +31,14 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
model: The model to check the configuration.

"""
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.trainer.state.fn in (TrainerFn.fit, TrainerFn.tune):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state.fn == TrainerFn.VALIDATING:
elif self.trainer.state.fn == TrainerFn.validate:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state.fn == TrainerFn.TESTING:
elif self.trainer.state.fn == TrainerFn.test:
self.__verify_eval_loop_configuration(model, 'test')
elif self.trainer.state.fn == TrainerFn.PREDICTING:
elif self.trainer.state.fn == TrainerFn.predict:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)

@@ -355,7 +355,7 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

# TODO(carmocca): when we implement flushing the logger connector metrics after
# the trainer.state changes, this should check trainer.evaluating instead
if self.trainer.state.fn in (TrainerFn.TESTING, TrainerFn.VALIDATING):
if self.trainer.state.fn in (TrainerFn.test, TrainerFn.validate):
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
@@ -282,7 +282,7 @@ def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT:

# log results of evaluation
if (
self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
self.trainer.state.fn != TrainerFn.fit and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -101,7 +101,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook('on_validation_end', *args, **kwargs)

if self.trainer.state.fn != TrainerFn.FITTING:
if self.trainer.state.fn != TrainerFn.fit:
# summarize profile results
self.trainer.profiler.describe()

24 changes: 12 additions & 12 deletions pytorch_lightning/trainer/states.py
@@ -35,20 +35,20 @@ class TrainerFn(LightningEnum):
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
"""
FITTING = 'fit'
VALIDATING = 'validate'
TESTING = 'test'
PREDICTING = 'predict'
TUNING = 'tune'
fit = 'fit'
validate = 'validate'
test = 'test'
predict = 'predict'
tune = 'tune'

@property
def _setup_fn(self) -> 'TrainerFn':
"""
``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
``fit`` is used instead of ``tune`` as there are no "tune" dataloaders.

This is used for the ``setup()`` and ``teardown()`` hooks
"""
return TrainerFn.FITTING if self == TrainerFn.TUNING else self
return TrainerFn.fit if self == TrainerFn.tune else self


class RunningStage(LightningEnum):
@@ -58,11 +58,11 @@ class RunningStage(LightningEnum):
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
More than one running stage value can be set while a :class:`TrainerFn` is running:

- ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
- ``TrainerFn.fit`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
- ``TrainerFn.validate`` - ``RunningStage.VALIDATING``
- ``TrainerFn.test`` - ``RunningStage.TESTING``
- ``TrainerFn.predict`` - ``RunningStage.PREDICTING``
- ``TrainerFn.tune`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""
TRAINING = 'train'
SANITY_CHECKING = 'sanity_check'
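
A minimal illustration of the pattern used by the checks above, assuming a user-defined callback:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback
    from pytorch_lightning.trainer.states import TrainerFn


    class FitOnlyPrinter(Callback):
        """Acts only while ``Trainer.fit`` is running."""

        def on_train_epoch_end(self, trainer, pl_module):
            # Skip when the trainer was entered via validate/test/predict/tune
            if trainer.state.fn != TrainerFn.fit:
                return
            print(f"Finished epoch {trainer.current_epoch} during fitting")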