fast_dev_run can be int #4629

Merged
merged 22 commits into from
Dec 8, 2020
Changes from 11 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))


- Updated `fast_dev_run` to accept integer representing num_batches ([#4629](https://github.com/PyTorchLightning/pytorch-lightning/pull/4629))


- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))


Expand Down Expand Up @@ -89,6 +92,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added lambda closure to `manual_optimizer_step` ([#4618](https://github.com/PyTorchLightning/pytorch-lightning/pull/4618))


### Changed

- Change Metrics `persistent` default mode to `False` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685))
Expand Down
11 changes: 7 additions & 4 deletions docs/source/debugging.rst
Expand Up @@ -21,17 +21,20 @@ The following are flags that make debugging much easier.

fast_dev_run
------------
This flag runs a "unit test" by running 1 training batch and 1 validation batch.
The point is to detect any bugs in the training/validation loop without having to wait for
a full epoch to crash.
This flag runs a "unit test" by running ``n`` training and validation batch(es) if set to an integer ``n``, or 1 of each if set to ``True``.
The point is to detect any bugs in the training/validation loop without having to wait for a full epoch to crash.

(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::


# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)

----------------

Inspect gradient norms
Expand Down
28 changes: 15 additions & 13 deletions docs/source/trainer.rst
Expand Up @@ -245,7 +245,7 @@ Example::
# ddp2 = DistributedDataParallel + dp
trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2')

.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
.. note:: This option does not apply to TPU. TPUs use ```ddp``` by default (over each core)

You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.

Expand Down Expand Up @@ -623,17 +623,10 @@ fast_dev_run

|

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/fast_dev_run.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/fast_dev_run.mp4"></video>

|
Runs ``n`` batch(es) of train, val and test if set to an integer ``n``, or 1 batch of each if set to ``True``,
to find any bugs (i.e. a sort of unit test).

Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

Under the hood the pseudocode looks like this:
Under the hood the pseudocode looks like this when running *fast_dev_run* with a single batch:

.. code-block:: python

Expand All @@ -658,6 +651,16 @@ Under the hood the pseudocode looks like this:
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)

.. note::

This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument
disables the tuner as well as logger callbacks like ``LearningRateLogger``, and limits the run to 1 epoch. It should be
used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and does not
disable anything.
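
The difference described in the note above can be sketched in plain Python. This is only an illustrative model and not Lightning code: ``RunPlan`` and ``plan_run`` are hypothetical names, a single ``limit_batches`` value stands in for the three ``limit_*_batches`` arguments, and the default values are assumptions chosen for the example.

```python
from dataclasses import dataclass


@dataclass
class RunPlan:
    """Hypothetical container for the settings a run would end up with."""
    limit_train_batches: float
    limit_val_batches: float
    limit_test_batches: float
    max_epochs: int
    num_sanity_val_steps: int
    tuner_enabled: bool


def plan_run(fast_dev_run=False, limit_batches=1.0, max_epochs=1000, num_sanity_val_steps=2):
    """Model how fast_dev_run overrides other settings (assumed, simplified)."""
    n = int(fast_dev_run)  # True -> 1, False -> 0, n -> n
    if n:
        # fast_dev_run wins: n batches everywhere, one epoch, no sanity check, no tuner
        return RunPlan(n, n, n, max_epochs=1, num_sanity_val_steps=0, tuner_enabled=False)
    # limit_*_batches only caps the number of batches and leaves everything else alone
    return RunPlan(limit_batches, limit_batches, limit_batches,
                   max_epochs, num_sanity_val_steps, tuner_enabled=True)


debug_plan = plan_run(fast_dev_run=7)
limited_plan = plan_run(limit_batches=7)
```

Both plans cap each loop at 7 batches, but only ``debug_plan`` forces a single epoch and switches the tuner off, which is why the note recommends ``fast_dev_run`` for debugging only.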

gpus
^^^^

Expand Down Expand Up @@ -1199,8 +1202,7 @@ Orders the progress bar. Useful when running multiple trainers on the same node.
# default used by the Trainer
trainer = Trainer(process_position=0)

Note:
This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
.. note:: This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.

profiler
^^^^^^^^
Expand Down
28 changes: 23 additions & 5 deletions pytorch_lightning/trainer/connectors/debugging_connector.py
Expand Up @@ -31,16 +31,34 @@ def on_init_start(
overfit_batches,
fast_dev_run
):
if not isinstance(fast_dev_run, (bool, int)):
raise MisconfigurationException(
f'fast_dev_run={fast_dev_run} is not a valid configuration.'
' It should be either a bool or an int >= 0'
)

if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
raise MisconfigurationException(
f'fast_dev_run={fast_dev_run} is not a'
' valid configuration. It should be >= 0.'
)

self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
fast_dev_run = int(fast_dev_run)

# set fast_dev_run=True when it is 1, used while logging
if fast_dev_run == 1:
self.trainer.fast_dev_run = True

if fast_dev_run:
limit_train_batches = fast_dev_run
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
'Running in fast_dev_run mode: will run a full train,'
f' val and test loop using {fast_dev_run} batch(es)'
)

self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
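
The validation and normalization steps in this hunk can be tried out in isolation with a small standalone sketch. It does not import Lightning: ``MisconfigurationError`` is a stand-in for ``MisconfigurationException`` and ``normalize_fast_dev_run`` is a hypothetical helper name. Since ``bool`` is a subclass of ``int`` in Python, a single ``isinstance(value, int)`` check accepts both types.

```python
class MisconfigurationError(Exception):
    """Stand-in for Lightning's MisconfigurationException (hypothetical)."""


def normalize_fast_dev_run(value):
    """Return (stored_flag, num_batches) for a fast_dev_run value."""
    # bool is a subclass of int, so this accepts True/False as well as integers
    if not isinstance(value, int):
        raise MisconfigurationError(
            f'fast_dev_run={value} is not a valid configuration.'
            ' It should be either a bool or an int >= 0'
        )
    if value < 0:
        raise MisconfigurationError(
            f'fast_dev_run={value} is not a valid configuration. It should be >= 0.'
        )
    num_batches = int(value)  # True -> 1, False -> 0
    # a value of 1 is stored as True, mirroring the special case in the diff
    stored_flag = True if num_batches == 1 else value
    return stored_flag, num_batches


def is_rejected(value):
    """Helper for the usage example: True if the value fails validation."""
    try:
        normalize_fast_dev_run(value)
    except MisconfigurationError:
        return True
    return False
```

With this sketch, ``True`` and ``1`` both normalize to a stored flag of ``True`` with one batch, ``3`` stays ``3``, and values such as ``'temp'`` or ``-1`` are rejected, which matches the cases parametrized in the test file of this PR.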
Expand Down
Expand Up @@ -614,7 +614,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

def log_train_step_metrics(self, batch_output):
# when metrics should be logged
if self.should_update_logs or self.trainer.fast_dev_run:
if self.should_update_logs or self.trainer.fast_dev_run is True:
# logs user requested information to logger
metrics = self.cached_results.get_latest_batch_log_metrics()
grad_norm_dic = batch_output.grad_norm_dic
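
The switch from a truthiness check to ``is True`` in this hunk matters once ``fast_dev_run`` can hold an integer. A minimal sketch of the two kinds of check, using hypothetical function names and assuming the stored value is ``True`` for a single-batch run and the integer ``n`` otherwise:

```python
def forces_logging(fast_dev_run):
    # identity check, as in `self.trainer.fast_dev_run is True`
    return fast_dev_run is True


def skips_tuner(fast_dev_run):
    # truthiness check, as in `if trainer.fast_dev_run:`
    return bool(fast_dev_run)


# stored flag values: True (single batch), 7 (multi-batch), False (disabled)
results = {flag: (forces_logging(flag), skips_tuner(flag)) for flag in (True, 7, False)}
```

Under that assumption, a multi-batch run (``7``) still skips the tuner but no longer forces logging on every step, whereas a single-batch run (``True``) does both.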
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/properties.py
Expand Up @@ -37,7 +37,7 @@ class TrainerProperties(ABC):
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: bool
fast_dev_run: Union[int, bool]
use_dp: bool
use_ddp: bool
use_ddp2: bool
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Expand Up @@ -104,7 +104,7 @@ def __init__(
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: bool = False,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
max_epochs: int = 1000,
min_epochs: int = 1,
Expand Down Expand Up @@ -191,7 +191,8 @@ def __init__(

distributed_backend: deprecated. Please use 'accelerator'

fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
fast_dev_run: runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or 1 batch of each
if set to ``True``, to find any bugs (i.e. a sort of unit test).

flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Expand Up @@ -873,7 +873,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
def save_loggers_on_train_batch_end(self):
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs or self.trainer.fast_dev_run:
if should_flush_logs or self.trainer.fast_dev_run is True:
if self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
Expand Up @@ -70,7 +70,7 @@ def scale_batch_size(trainer,
or datamodule.
"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
return

if not lightning_hasattr(model, batch_arg_name):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
Expand Up @@ -137,7 +137,7 @@ def lr_find(

"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
return

save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/flags/test_fast_dev_run.py
Expand Up @@ -4,7 +4,7 @@


@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
""" Test that tuner algorithms are skipped if fast dev run is enabled """

hparams = EvalModelTemplate.get_default_hparams()
Expand All @@ -16,6 +16,6 @@ def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
fast_dev_run=True
)
expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.'
with pytest.warns(UserWarning, match=expected_message):
trainer.tune(model)
57 changes: 40 additions & 17 deletions tests/trainer/test_dataloaders.py
Expand Up @@ -434,9 +434,11 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_with_fast_dev_run(tmpdir):
"""Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True"""

@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp'])
def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
"""
Verify num_batches for train, val & test dataloaders passed with fast_dev_run
"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
Expand All @@ -445,26 +447,47 @@ def test_dataloaders_with_fast_dev_run(tmpdir):
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test dataloaders passed with fast_dev_run = True
trainer = Trainer(
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=2,
fast_dev_run=True,
fast_dev_run=fast_dev_run,
)
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0

trainer.fit(model)
assert not trainer.disable_validation
assert trainer.num_training_batches == 1
assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)
if fast_dev_run == 'temp':
with pytest.raises(MisconfigurationException, match='either a bool or an int'):
trainer = Trainer(**trainer_options)
elif fast_dev_run == -1:
with pytest.raises(MisconfigurationException, match='should be >= 0'):
trainer = Trainer(**trainer_options)
else:
trainer = Trainer(**trainer_options)

trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders)
# fast_dev_run is set to True when it is 1
if fast_dev_run == 1:
fast_dev_run = True

# verify sanity check batches match as expected
num_val_dataloaders = len(model.val_dataloader())
assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders
assert trainer.fast_dev_run is fast_dev_run

if fast_dev_run is True:
fast_dev_run = 1

assert trainer.limit_train_batches == fast_dev_run
assert trainer.limit_val_batches == fast_dev_run
assert trainer.limit_test_batches == fast_dev_run
assert trainer.num_sanity_val_steps == 0
assert trainer.max_epochs == 1

trainer.fit(model)
assert not trainer.disable_validation
assert trainer.num_training_batches == fast_dev_run
assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders)

trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders)

# verify sanity check batches match as expected
num_val_dataloaders = len(model.val_dataloader())
assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
Expand Down