All gather with grads (Lightning-AI#5012)
* all_gather

* ddp

* horovod

* grad tests

* fixed ddp

* ddp fixed, removed tpu, horovod for now

* changelog

* windows fix

* windows fix

* removed batch from ctx

* removed code duplication

* merge

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
ananyahjha93 and Borda authored Dec 8, 2020
1 parent ee9b3fe commit 127454a
Showing 11 changed files with 211 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Unreleased

### Added

- Added `all_gather` method to `LightningModule` which allows gradient-based tensor synchronization for use cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012))

### Fixed

- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
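To make the new API concrete, here is a minimal sketch of calling `all_gather` from a step method, assuming a trainer running on multiple processes; the module name, encoder layer, and sizes are illustrative assumptions, not part of this PR:

```python
import torch.nn as nn
import pytorch_lightning as pl


class GatherExample(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 16)

    def validation_step(self, batch, batch_idx):
        features = self.encoder(batch)        # shape (batch, 16) on each process
        gathered = self.all_gather(features)  # shape (world_size, batch, 16)
        # Collapse the world dimension to treat the result as one large batch.
        return gathered.reshape(-1, gathered.size(-1))
```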
14 changes: 14 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -172,6 +172,20 @@ def sync_tensor(self,
"""
raise NotImplementedError()

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
raise NotImplementedError()

def optimizer_state(self, optimizer: Optimizer) -> dict:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
16 changes: 15 additions & 1 deletion pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -25,7 +25,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available

if HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
@@ -234,6 +234,20 @@ def sync_tensor(self,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        return sync_ddp_if_available(tensor, group, reduce_op)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        """
        Function to gather a tensor from several distributed processes

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def get_reference_model(self, model) -> LightningModule:
        return self.ddp_plugin.get_model_from_plugin(model)

15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -29,6 +29,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -333,6 +334,20 @@ def sync_tensor(self,
"""
return sync_ddp_if_available(tensor, group, reduce_op)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
sync_grads: flag that allows users to synchronize gradients for all_gather op
Return:
A tensor of shape (world_size, batch, ...)
"""
return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)

15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -31,6 +31,7 @@
    rank_zero_only,
    rank_zero_warn,
    sync_ddp_if_available,
    all_gather_ddp_if_available,
)

if HYDRA_AVAILABLE:
@@ -261,6 +262,20 @@ def sync_tensor(self,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        return sync_ddp_if_available(tensor, group, reduce_op)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        """
        Function to gather a tensor from several distributed processes

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def get_reference_model(self, model) -> LightningModule:
        return self.ddp_plugin.get_model_from_plugin(model)

16 changes: 15 additions & 1 deletion pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -25,7 +25,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available

if HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
@@ -225,6 +225,20 @@ def sync_tensor(self,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        return sync_ddp_if_available(tensor, group, reduce_op)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        """
        Function to gather a tensor from several distributed processes

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def get_reference_model(self, model) -> LightningModule:
        return self.ddp_plugin.get_model_from_plugin(model)

15 changes: 15 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -34,6 +34,7 @@
    rank_zero_only,
    rank_zero_warn,
    sync_ddp_if_available,
    all_gather_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything

@@ -293,6 +294,20 @@ def sync_tensor(self,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        return sync_ddp_if_available(tensor, group, reduce_op)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        """
        Function to gather a tensor from several distributed processes

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def get_reference_model(self, model) -> LightningModule:
        return self.ddp_plugin.get_model_from_plugin(model)

18 changes: 18 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch):

        return on_epoch

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False):
        r"""
        Allows users to call ``self.all_gather()`` from the LightningModule, thus making
        the ``all_gather`` operation accelerator agnostic.

        ``all_gather`` is a function provided by accelerators to gather a tensor from several
        distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)

    def forward(self, *args, **kwargs):
        r"""
        Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
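Since the changelog names negative sampling as the motivating use case, here is a hedged sketch of the gradient-flow pattern this hook enables: gathering embeddings from every process with `sync_grads=True` so a contrastive-style loss can backpropagate through the gather. The module, encoder, and loss below are illustrative assumptions, not code from this PR:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class ContrastiveExample(pl.LightningModule):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(128, 64)

    def training_step(self, batch, batch_idx):
        emb = F.normalize(self.encoder(batch), dim=-1)   # (batch, 64) per process
        # sync_grads=True keeps the gather differentiable, so embeddings from all
        # processes can serve as negatives without cutting the autograd graph.
        all_emb = self.all_gather(emb, sync_grads=True)  # (world_size, batch, 64)
        all_emb = all_emb.reshape(-1, emb.size(-1))      # one global pool of samples
        logits = emb @ all_emb.t()                       # similarity to every sample
        # Each local sample's positive sits at offset rank * batch in the pool.
        labels = torch.arange(emb.size(0), device=emb.device)
        labels = labels + self.trainer.global_rank * emb.size(0)
        return F.cross_entropy(logits, labels)
```

With `sync_grads=False` (the default) the gather runs under `torch.no_grad()`, which is sufficient for metrics but would stop the loss above from training the encoder through other processes' negatives.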
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -21,6 +21,7 @@

from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import AllGatherGrad
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils

55 changes: 55 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -22,10 +22,14 @@

if torch.distributed.is_available():
    from torch.distributed import ReduceOp
    from torch.distributed import group
else:
    class ReduceOp:
        SUM = None

    class group:
        WORLD = None


def rank_zero_only(fn):

@@ -155,3 +159,54 @@ def sync_ddp(
        result = result / torch.distributed.get_world_size(group)

    return result


class AllGatherGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group=group.WORLD):
        ctx.group = group

        # Allocate one buffer per process in the group and gather everyone's tensor.
        gathered_tensor = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size(group=group))
        ]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, *grad_output):
        grad_output = torch.cat(grad_output)

        # Sum the gradients from every process so each rank sees the contributions
        # of all losses, then return only the slice matching this rank's input.
        torch.distributed.all_reduce(
            grad_output,
            op=torch.distributed.ReduceOp.SUM,
            async_op=False,
            group=ctx.group
        )

        # ``forward`` takes two inputs (tensor, group), so ``backward`` returns a
        # gradient for each; the process group is not differentiable.
        return grad_output[torch.distributed.get_rank(group=ctx.group)], None


def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> torch.Tensor:
    """
    Function to gather a tensor from several distributed processes

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather operation

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        else:
            with torch.no_grad():  # no_grad is a class; it must be instantiated
                return AllGatherGrad.apply(tensor, group)
    return tensor
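For intuition about the gradient path through `AllGatherGrad`, here is a single-process emulation of the same arithmetic (a sketch only: `torch.stack` stands in for the real `all_gather`, and summing per-rank losses plays the role of the `all_reduce` in `backward`; it mirrors the test below):

```python
import torch

world_size = 3
# One "local" tensor per simulated rank.
tensors = [torch.ones(8, requires_grad=True) for _ in range(world_size)]

# Forward: after all_gather, every rank holds the same stacked tensor.
gathered = torch.stack(tensors, dim=0)  # shape (world_size, 8)

# Each simulated rank scales its gathered copy by its rank index and reduces.
loss = sum(rank * gathered.sum() for rank in range(world_size))
loss.backward()

# Every local tensor accumulates gradients from all ranks: 0 + 1 + 2 = 3.
assert torch.allclose(tensors[0].grad, torch.full((8,), 3.0))
```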
44 changes: 44 additions & 0 deletions tests/utilities/test_all_gather_grad.py
@@ -0,0 +1,44 @@
import os
import sys

import pytest
import torch

from pytorch_lightning.utilities import AllGatherGrad


def setup_ddp(rank, world_size):
    """ Setup ddp environment """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"

    if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _test_all_gather_ddp(rank, world_size):
    setup_ddp(rank, world_size)

    tensor1 = torch.ones(8, requires_grad=True)
    tensor2 = torch.ones((8, 16, 32), requires_grad=True)

    tensor1_gathered = AllGatherGrad.apply(tensor1)
    tensor2_gathered = AllGatherGrad.apply(tensor2)

    tensor1_gathered = tensor1_gathered * rank
    tensor2_gathered = tensor2_gathered * rank

    tensor1_gathered.sum().backward()
    tensor2_gathered.sum().backward()
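
    # Every process scales its gathered copy by its own rank, so after the
    # all_reduce in AllGatherGrad.backward each gradient element equals
    # 0 + 1 + ... + (world_size - 1); for world_size = 3 that is 3.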

    grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float())
    grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float())

    assert torch.allclose(grad1, tensor1.grad)
    assert torch.allclose(grad2, tensor2.grad)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_all_gather_ddp():
    world_size = 3
    torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
