diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc0cf47e1f9ef..3585b3695be34 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## Unreleased
 
+### Added
+
+- Added `all_gather` method to `LightningModule` which allows gradient-based tensor synchronizations for use cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012))
+
 ### Fixed
 
 - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 11af9e4d8f91e..5a10f21d21a1b 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -172,6 +172,20 @@ def sync_tensor(self,
         """
         raise NotImplementedError()
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        raise NotImplementedError()
+
     def optimizer_state(self, optimizer: Optimizer) -> dict:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index f47b389faf436..2e3e39a5cd4d4 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -25,7 +25,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -234,6 +234,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
 
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index d3d4c1fa1b766..942d66bc029e9 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
@@ -333,6 +334,20 @@ def sync_tensor(self,
         """
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
 
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 50bd1b7ab9051..f109f555f575e 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -31,6 +31,7 @@
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
+    all_gather_ddp_if_available,
 )
 
 if HYDRA_AVAILABLE:
@@ -261,6 +262,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
 
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index 50267afa525dc..5f09189e8b42c 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -25,7 +25,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -225,6 +225,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
 
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index 7db2d5a309d9c..d768c3b6fbdc3 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -34,6 +34,7 @@
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
+    all_gather_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import seed_everything
 
@@ -293,6 +294,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 671084cb2fac7..5acec2b86722c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch):
 
         return on_epoch
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        r"""
+        Allows users to call ``self.all_gather()`` from the LightningModule, thus making
+        the ``all_gather`` operation accelerator agnostic.
+
+        ``all_gather`` is a function provided by accelerators to gather a tensor from several
+        distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
+
     def forward(self, *args, **kwargs):
         r"""
         Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 3a04a325905a9..7869690dea98b 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -21,6 +21,7 @@
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import AllGatherGrad
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
 from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
 
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index ffa1be87cd3ca..9724f05247c00 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -22,10 +22,14 @@
 
 if torch.distributed.is_available():
     from torch.distributed import ReduceOp
+    from torch.distributed import group
 else:
     class ReduceOp:
         SUM = None
 
+    class group:
+        WORLD = None
+
 
 def rank_zero_only(fn):
 
@@ -155,3 +159,54 @@ def sync_ddp(
         result = result / torch.distributed.get_world_size(group)
 
     return result
+
+
+class AllGatherGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, group=group.WORLD):
+        ctx.group = group
+
+        gathered_tensor = [
+            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
+        ]
+
+        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
+        gathered_tensor = torch.stack(gathered_tensor, dim=0)
+
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_output = torch.cat(grad_output)
+
+        torch.distributed.all_reduce(
+            grad_output,
+            op=torch.distributed.ReduceOp.SUM,
+            async_op=False,
+            group=ctx.group
+        )
+
+        return grad_output[torch.distributed.get_rank()]
+
+
+def all_gather_ddp_if_available(
+    tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
+) -> torch.Tensor:
+    """
+    Function to gather a tensor from several distributed processes
+
+    Args:
+        tensor: tensor of shape (batch, ...)
+        group: the process group to gather results from. Defaults to all processes (world)
+        sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+    Return:
+        A tensor of shape (world_size, batch, ...)
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if sync_grads:
+            return AllGatherGrad.apply(tensor, group)
+        else:
+            with torch.no_grad():
+                return AllGatherGrad.apply(tensor, group)
+    return tensor
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
new file mode 100644
index 0000000000000..66e50776edd3f
--- /dev/null
+++ b/tests/utilities/test_all_gather_grad.py
@@ -0,0 +1,44 @@
+import os
+import pytest
+import sys
+import torch
+import torch.nn as nn
+
+from pytorch_lightning.utilities import AllGatherGrad
+
+
+def setup_ddp(rank, world_size):
+    """ Setup ddp environment """
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "8088"
+
+    if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
+        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def _test_all_gather_ddp(rank, world_size):
+    setup_ddp(rank, world_size)
+
+    tensor1 = torch.ones(8, requires_grad=True)
+    tensor2 = torch.ones((8, 16, 32), requires_grad=True)
+
+    tensor1_gathered = AllGatherGrad.apply(tensor1)
+    tensor2_gathered = AllGatherGrad.apply(tensor2)
+
+    tensor1_gathered = tensor1_gathered * rank
+    tensor2_gathered = tensor2_gathered * rank
+
+    tensor1_gathered.sum().backward()
+    tensor2_gathered.sum().backward()
+
+    grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float())
+    grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float())
+
+    assert torch.allclose(grad1, tensor1.grad)
+    assert torch.allclose(grad2, tensor2.grad)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_all_gather_ddp():
+    world_size = 3
+    torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
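For context, here is a minimal usage sketch of the new API (not part of the diff). The `RetrievalModel`, its two encoders, and the contrastive loss are illustrative assumptions; the only pieces taken from this PR are `self.all_gather(..., sync_grads=True)` and its `(world_size, batch, ...)` return shape, plus the pre-existing `self.global_rank` property.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class RetrievalModel(pl.LightningModule):
    """Hypothetical example model; the encoders are assumed to map inputs to (batch, dim) embeddings."""

    def __init__(self, query_encoder: torch.nn.Module, key_encoder: torch.nn.Module):
        super().__init__()
        self.query_encoder = query_encoder
        self.key_encoder = key_encoder

    def training_step(self, batch, batch_idx):
        queries, keys = batch
        q = self.query_encoder(queries)   # (batch, dim)
        k = self.key_encoder(keys)        # (batch, dim)

        # Gather keys from every process. sync_grads=True routes the op through
        # AllGatherGrad, so the result stays differentiable and gradients flow
        # back to the local `k`. Shape: (world_size, batch, dim); without an
        # initialized process group the tensor is returned unchanged as (batch, dim).
        all_k = self.all_gather(k, sync_grads=True)
        all_k = all_k.reshape(-1, all_k.shape[-1])

        # Every non-matching gathered key acts as a negative for each local query;
        # the positives for this process sit at offset global_rank * batch.
        logits = q @ all_k.t()
        offset = self.global_rank * q.shape[0] if all_k.shape[0] > q.shape[0] else 0
        labels = torch.arange(q.shape[0], device=q.device) + offset
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

With the default `sync_grads=False`, the gather runs under `torch.no_grad()`, so the returned tensor carries no autograd graph and the gathered negatives would not contribute gradients.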