Skip to content

Commit

Permalink
Merge branch 'master' into hyperparameters_for_datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored Jul 9, 2021
2 parents 43c75fe + 3102922 commit d3d5cf7
Show file tree
Hide file tree
Showing 28 changed files with 539 additions and 44 deletions.
13 changes: 11 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))


- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


Expand Down Expand Up @@ -149,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))


- Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))


### Changed


Expand Down Expand Up @@ -247,10 +253,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
Expand All @@ -262,6 +268,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `PrecisionPlugin.backward` hooks no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- Added the `on_before_backward` hook ([#7865](https://github.com/PyTorchLightning/pytorch-lightning/pull/7865))


- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


Expand Down
14 changes: 14 additions & 0 deletions docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,11 @@ for more information.
on_before_zero_grad()
optimizer_zero_grad()
on_before_backward()
backward()
on_after_backward()
on_before_optimizer_step()
optimizer_step()
on_train_batch_end()
Expand Down Expand Up @@ -1246,6 +1248,12 @@ get_progress_bar_dict
.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
:noindex:

on_before_backward
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -1444,6 +1452,12 @@ on_test_model_train
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

on_before_optimizer_step
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
:noindex:

optimizer_step
~~~~~~~~~~~~~~

Expand Down
12 changes: 12 additions & 0 deletions docs/source/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,24 @@ on_load_checkpoint
.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:

on_before_backward
^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
:noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/source/extensions/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ Cluster Environments

ClusterEnvironment
LightningEnvironment
LSFEnvironment
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment
20 changes: 20 additions & 0 deletions docs/source/guides/speed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ This by default comes with a performance hit, and can be disabled in most cases.
plugins=DDPPlugin(find_unused_parameters=False),
)
When using DDP on a multi-node cluster, set NCCL parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`NCCL <https://developer.nvidia.com/nccl>`__ is the NVIDIA Collective Communications Library which is used under the hood by PyTorch to handle communication across nodes and GPUs. There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues/7179>`__. In the issue we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2.

NCCL parameters can be adjusted via environment variables.

.. note::

AWS and GCP already set default values for these on their clusters. This is typically useful for custom cluster setups.

* `NCCL_NSOCKS_PERTHREAD <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nsocks-perthread>`__
* `NCCL_SOCKET_NTHREADS <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-nthreads>`__
* `NCCL_MIN_NCHANNELS <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-min-nchannels>`__

.. code-block:: bash
export NCCL_NSOCKS_PERTHREAD=4
export NCCL_SOCKET_NTHREADS=2
Dataloaders
^^^^^^^^^^^
When building your DataLoader set ``num_workers > 0`` and ``pin_memory=True`` (only for GPUs).
Expand Down
13 changes: 12 additions & 1 deletion pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import abc
from typing import Any, Dict, List, Optional

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
Expand Down Expand Up @@ -296,8 +297,18 @@ def on_load_checkpoint(
"""
pass

def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', loss: torch.Tensor) -> None:
"""Called before ``loss.backward()``."""
pass

def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass

def on_before_optimizer_step(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
) -> None:
"""Called before ``optimizer.step()``."""
pass

def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def __init__(
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
on_predict_end: Optional[Callable] = None,
Expand Down
34 changes: 30 additions & 4 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,21 +295,47 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
optimizer: The optimizer for which grads should be zeroed.
"""

def on_before_backward(self, loss: torch.Tensor) -> None:
"""
Called before ``loss.backward()``.
Args:
loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
"""
pass

def on_after_backward(self) -> None:
"""
Called in the training loop after loss.backward() and before optimizers do anything.
This is the ideal place to inspect or log gradient information.
Called after ``loss.backward()`` and before optimizers are stepped.
Note:
If using native AMP, the gradients will not be unscaled at this point.
Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
"""

def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Called before ``optimizer.step()``.
The hook is only called if gradients do not need to be accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
If using native AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
for more information on the scaling of gradients.
Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.
Example::
def on_after_backward(self):
def on_before_optimizer_step(self, optimizer, optimizer_idx):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
for k, v in self.named_parameters():
self.logger.experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)
"""

def on_post_move_to_device(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
160 changes: 160 additions & 0 deletions pytorch_lightning/plugins/environments/lsf_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments import ClusterEnvironment


class LSFEnvironment(ClusterEnvironment):
"""
An environment for running on clusters managed by the LSF resource manager.
It is expected that any execution using this ClusterEnvironment was executed
using the Job Step Manager i.e. ``jsrun``.
This plugin expects the following environment variables.
LSB_JOBID:
The LSF assigned job ID
LSB_HOSTS:
The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."
JSM_NAMESPACE_LOCAL_RANK:
The node local rank for the task. This environment variable is set by jsrun
JSM_NAMESPACE_SIZE:
The world size for the task. This environment variable is set by jsrun
"""

def __init__(self):
self._master_address = self._get_master_address()
self._master_port = self._get_master_port()
log.debug(f"MASTER_ADDR: {self._master_address}")
log.debug(f"MASTER_PORT: {self._master_port}")

@staticmethod
def is_using_lsf() -> bool:
""" Returns ``True`` if the current process was launched using the jsrun command. """
required_env_vars = (
"LSB_JOBID",
"LSB_HOSTS",
"JSM_NAMESPACE_LOCAL_RANK",
"JSM_NAMESPACE_SIZE",
)
return all(v in os.environ for v in required_env_vars)

def creates_children(self) -> bool:
return True

def master_address(self):
""" The master address is read from a list of hosts contained in the environment variable `LSB_HOSTS`. """
return self._master_address

def master_port(self):
""" THe master port gets calculated from the LSF job ID. """
return self._master_port

def world_size(self):
""" The world size is read from the environment variable `JSM_NAMESPACE_SIZE`. """
var = "JSM_NAMESPACE_SIZE"
world_size = os.environ.get(var)
if world_size is None:
raise ValueError(
f"Cannot determine world size from environment variable {var}."
" Make sure you run your executable with `jsrun`"
)
return int(world_size)

def set_world_size(self, size: int) -> None:
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self):
""" The world size is read from the environment variable `JSM_NAMESPACE_RANK`. """
var = "JSM_NAMESPACE_RANK"
global_rank = os.environ.get(var)
if global_rank is None:
raise ValueError(
f"Cannot determine global rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
)
return int(global_rank)

def set_global_rank(self, rank: int) -> None:
log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self):
""" The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`. """
var = "JSM_NAMESPACE_LOCAL_RANK"
local_rank = os.environ.get(var)
if local_rank is None:
raise ValueError(
f"Cannot determine local rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
)
return int(local_rank)

def node_rank(self):
"""
The node rank is determined by the position of the current hostname in the list of hosts stored in
the environment variable `LSB_HOSTS`.
"""
hosts = self._read_hosts()
count = dict()
for host in hosts:
if "batch" in host or "login" in host:
continue
if host not in count:
count[host] = len(count)
return count[socket.gethostname()]

@staticmethod
def _read_hosts():
hosts = os.environ.get("LSB_HOSTS")
if not hosts:
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
hosts = hosts.split()
if len(hosts) < 2:
raise ValueError(
"Cannot parse hosts from LSB_HOSTS environment variable."
" Expected format: \"batch <rank_0_host> ...\""
)
return hosts

def _get_master_address(self):
hosts = self._read_hosts()
return hosts[1]

@staticmethod
def _get_master_port():
"""
A helper function for accessing the master port.
Uses the LSF job ID so all ranks can compute the master port.
"""
# check for user-specified master port
port = os.environ.get("MASTER_PORT")
if not port:
jobid = os.environ.get("LSB_JOBID")
if not jobid:
raise ValueError("Could not find job id in environment variable LSB_JOBID")
port = int(jobid)
# all ports should be in the 10k+ range
port = int(port) % 1000 + 10000
log.debug(f"calculated LSF master port: {port}")
else:
log.debug(f"using externally specified master port: {port}")
return int(port)
Loading

0 comments on commit d3d5cf7

Please sign in to comment.