Skip to content

Commit

Permalink
Merge branch 'master' into feat_artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma authored Mar 12, 2021
2 parents b59fdf1 + 680e83a commit 13a730b
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 71 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))

Expand Down Expand Up @@ -58,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))

Expand Down
26 changes: 17 additions & 9 deletions docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ Under the hood, Lightning does the following (pseudocode):
loss = training_step(batch)
losses.append(loss.detach())
# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# apply and clear grads
# update parameters
optimizer.step()
optimizer.zero_grad()
Training epoch-level metrics
Expand Down Expand Up @@ -212,12 +214,14 @@ Here's the pseudocode of what it does under the hood:
# forward
out = training_step(val_batch)
# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# apply and clear grads
# update parameters
optimizer.step()
optimizer.zero_grad()
epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
Expand Down Expand Up @@ -247,12 +251,14 @@ The matching pseudocode is:
# forward
out = training_step(val_batch)
# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# apply and clear grads
# update parameters
optimizer.step()
optimizer.zero_grad()
training_epoch_end(outs)
Expand Down Expand Up @@ -946,9 +952,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi
opt = self.optimizers(use_pl_optimizer=True)
loss = ...
opt.zero_grad()
self.manual_backward(loss)
opt.step()
opt.zero_grad()
This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.

Expand Down Expand Up @@ -1048,11 +1054,13 @@ This is the pseudocode to describe how all the hooks are called during a call to
loss = out.loss
on_before_zero_grad()
optimizer_zero_grad()
backward()
on_after_backward()
optimizer_step()
on_before_zero_grad()
optimizer_zero_grad()
on_train_batch_end(out)
Expand Down
6 changes: 4 additions & 2 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ Here's the pseudocode for what the trainer does under the hood (showing the trai
# train step
loss = training_step(batch)
# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# apply and clear grads
# update parameters
optimizer.step()
optimizer.zero_grad()
losses.append(loss)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/06-mnist-tpu-training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"id": "AYGWh10lRaF1"
},
"source": [
"! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl"
"! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl"
],
"execution_count": null,
"outputs": []
Expand Down
129 changes: 105 additions & 24 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,25 @@ class ModelCheckpoint(Callback):
save_weights_only: if ``True``, then only the model's weights will be
saved (``model.save_weights(filepath)``), else the full model
is saved (``model.save(filepath)``).
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative.
This must be mutually exclusive with ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps``.
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
period: Interval (number of epochs) between checkpoints.
.. warning::
This argument has been deprecated in v1.3 and will be removed in v1.5.
Use ``every_n_val_epochs`` instead.
Note:
For extra customization, ModelCheckpoint includes the following attributes:
Expand Down Expand Up @@ -166,16 +183,17 @@ def __init__(
save_top_k: Optional[int] = None,
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
auto_insert_metric_name: bool = True
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
period: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
self.verbose = verbose
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
Expand All @@ -189,6 +207,7 @@ def __init__(

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__validate_init_configuration()

def on_pretrain_routine_start(self, trainer, pl_module):
Expand All @@ -198,10 +217,26 @@ def on_pretrain_routine_start(self, trainer, pl_module):
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

def on_validation_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
""" Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
if self._should_skip_saving_checkpoint(trainer):
return
step = trainer.global_step
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
if skip_batch:
return
self.save_checkpoint(trainer)

def on_validation_end(self, trainer, *args, **kwargs) -> None:
"""
checkpoints can be saved at the end of the val loop
"""
skip = (
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
)
if skip:
return
self.save_checkpoint(trainer)

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -229,20 +264,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
" has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
)

epoch = trainer.current_epoch
global_step = trainer.global_step

from pytorch_lightning.trainer.states import TrainerState
if (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or self._last_global_step_saved == global_step # already saved at the last step
):
return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

Expand All @@ -265,9 +288,32 @@ def save_checkpoint(self, trainer, unused: Optional = None):
if trainer.is_global_zero and trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'):
trainer.logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self._every_n_train_steps < 0:
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
)
if self._every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
)
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
raise MisconfigurationException(
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
' and every_n_val_epochs={self._every_n_val_epochs}.'
' Both cannot be enabled at the same time.'
)
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in (None, -1, 0):
Expand Down Expand Up @@ -314,6 +360,46 @@ def __init_monitor_mode(self, monitor, mode):

self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
) -> None:

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
if every_n_train_steps is None and every_n_val_epochs is None:
self._every_n_val_epochs = 1
self._every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
else:
self._every_n_val_epochs = every_n_val_epochs or 0
self._every_n_train_steps = every_n_train_steps or 0

# period takes precedence over every_n_val_epochs for backwards compatibility
if period is not None:
rank_zero_warn(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
self._every_n_val_epochs = period

self._period = self._every_n_val_epochs

@property
def period(self) -> Optional[int]:
rank_zero_warn(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
return self._period

@period.setter
def period(self, value: Optional[int]) -> None:
rank_zero_warn(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
self._period = value

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
Expand Down Expand Up @@ -427,11 +513,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
"""
filename = self._format_checkpoint_name(
self.filename,
epoch,
step,
metrics,
auto_insert_metric_name=self.auto_insert_metric_name)
self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name
)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
Expand Down Expand Up @@ -586,9 +669,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
self._save_model(trainer, filepath)

if (
self.save_top_k is None
and self.best_model_path
and self.best_model_path != filepath
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and trainer.is_global_zero
):
self._del_model(self.best_model_path)
Expand Down
14 changes: 5 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
import platform
from abc import ABC
from copy import deepcopy
from typing import Callable, Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand All @@ -36,8 +37,6 @@ class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
global_rank: int
shown_warnings:...
val_check_interval: float
tpu_local_core_rank: int
train_dataloader: DataLoader
Expand All @@ -48,13 +47,10 @@ class TrainerDataLoadingMixin(ABC):
test_dataloaders: List[DataLoader]
num_test_batches: List[Union[int, float]]
limit_train_batches: Union[int, float]
limit_val_batches: Union[int, float]
limit_test_batches: Union[int, float]
replace_sampler_ddp: bool
overfit_batches: Union[int, float]
distributed_sampler_kwargs: dict
accelerator: Accelerator
num_nodes: int
num_processes: int
distributed_backend: Optional[str]
accelerator_connector: AcceleratorConnector
dev_debugger: InternalDebugger

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

class DeprecatedDistDeviceAttributes:

_distrib_type: DistributedType
_device_type: DeviceType
num_gpus: int
accelerator_connector: AcceleratorConnector

Expand Down Expand Up @@ -135,7 +133,7 @@ def use_single_gpu(self, val: bool) -> None:
class DeprecatedTrainerAttributes:

accelerator: Accelerator
lightning_module = LightningModule
lightning_module: LightningModule
sanity_checking: bool

@property
Expand Down
Loading

0 comments on commit 13a730b

Please sign in to comment.