Add ReduceScatterBucketer (#94)
* Add ReduceScatterBucketer

* Add test

* Move chunk_and_pad to fairscale.utils.parallel and add tests

* Iterate on tests

* CR comments

* more

* Remove most tests to speed up CI iteration cycle

* Fix for Python < 3.8

* more

* Revert "Remove most tests to speed up CI iteration cycle"

This reverts commit a1981ae496b021a767fbf411f95840c55fad8d17.

* CR

* lint
myleott authored Feb 22, 2021
1 parent 43d1f73 commit 77dc364
Showing 7 changed files with 351 additions and 76 deletions.
98 changes: 57 additions & 41 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -27,7 +27,8 @@
unpack_kwargs,
unpack_non_tensors,
)
from fairscale.utils.parallel import validate_process_group
from fairscale.utils.parallel import chunk_and_pad, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
@@ -107,6 +108,12 @@ class FullyShardedDataParallel(nn.Module):
move_grads_to_cpu (bool, Optional): move gradient shard to CPU after
reduction. This is useful when combined with CPU-based optimizers.
It defaults to the value of *``cpu_offload``*.
bucket_cap_mb (int, Optional): FSDP will bucket parameters so that
gradient reduction can potentially overlap with backward
computation. bucket_cap_mb controls the bucket size in MegaBytes
(MB). Buckets are sub-divided based on world_size, so the max shard
size is roughly ``bucket_cap_mb / world_size``. Values <= 0 disable
bucketing. Default: 25.
"""

def __init__(
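
A minimal usage sketch of the new bucket_cap_mb knob documented above, assuming torch.distributed is already initialized (e.g. via torchrun) and substituting a toy nn.Linear for a real model; the import path is inferred from the file layout in this diff:

import torch.nn as nn
from fairscale.nn.data_parallel.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

# Gradients are reduce-scattered through ~25 MB buckets, so each rank's shard
# of a bucket is roughly bucket_cap_mb / world_size megabytes.
model = FSDP(nn.Linear(1024, 1024).cuda(), mixed_precision=True, bucket_cap_mb=25)

# bucket_cap_mb <= 0 disables bucketing: each parameter is reduce-scattered on its own.
unbucketed = FSDP(nn.Linear(1024, 1024).cuda(), bucket_cap_mb=0)
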
@@ -120,6 +127,7 @@ def __init__(
cpu_offload: bool = False,
compute_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -132,6 +140,7 @@ def __init__(
self.cpu_offload = cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb

if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
@@ -405,7 +414,7 @@ def no_sync(self) -> Generator:
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
self.assert_idle()
self.assert_state(TrainingState.IDLE)
# This instance may wrap other FullyShardedDataParallel instances and we
# need to set all of them to accumulate gradients.
old_flags = []
@@ -423,6 +432,7 @@ def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None

def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
@@ -438,6 +448,7 @@ def _lazy_init(self) -> None:
if self._is_root is None:
self._set_is_root()
self._setup_streams()

if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
else:
@@ -555,12 +566,16 @@ def _setup_streams(self) -> None:
self._streams["all_gather"] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams["post_backward"] = torch.cuda.Stream()
# Helper for bucketing reduce-scatter ops. This is also shared with
# children instances to improve bucket utilization.
self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
# We share streams with all children instances, which allows them to
# overlap transfers across the forward pass without synchronizing with
# the default stream.
for n, m in self.named_modules():
if n != "" and isinstance(m, FullyShardedDataParallel):
m._streams = self._streams
m._reducer = self._reducer

def _wait_for_previous_optim_step(self) -> None:
"""
@@ -679,6 +694,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
alignment is created by :func:`_shard_parameters_`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
self.assert_state(TrainingState.BACKWARD)
if param.grad is None:
return
if param.grad.requires_grad:
@@ -717,24 +733,18 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.world_size)

callback_fn = functools.partial(self._post_reduction_hook, param)
if param._is_sharded:
# Reduce-scatter grad.
param.grad.data = self._reduce_scatter_grad(param)
elif self.world_size > 1:
# All-reduce non-sharded grad.
dist.all_reduce(param.grad.data, group=self.process_group)

# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
param.grad.data = param.grad.data.to(dtype=param.data.dtype)

# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
param.grad.data = param._cpu_grad
assert param._is_sharded
assert self._reducer is not None
grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
callback_fn(param.grad.data)

# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
@@ -743,10 +753,34 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
orig_grad_data.record_stream(self._streams["post_backward"])
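
The pattern introduced above (average the gradient, chunk_and_pad it into world_size equal pieces, then hand the pieces to the shared bucketer with a per-parameter callback) can be exercised on its own; a hedged sketch, assuming a NCCL process group is already initialized and a CUDA device is available:

import torch
import torch.distributed as dist
from fairscale.utils.parallel import chunk_and_pad
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer

group = dist.new_group()  # mirrors the default used in the FSDP constructor
bucketer = ReduceScatterBucketer(bucket_cap_mb=25)
grad = torch.randn(10, device="cuda")

def on_reduced(shard: torch.Tensor) -> None:
    # shard is this rank's slice of the gradient summed across workers.
    print(shard.numel())

bucketer.reduce_scatter_async(
    chunk_and_pad(grad, group.size()), group=group, callback_fn=on_reduced
)
bucketer.flush()  # small inputs are only reduced (and callbacks run) at flush time
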

def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
"""Hook to call on each param after the reduce-scatter."""
assert torch.cuda.current_stream() == self._streams["post_backward"]
assert param.grad is not None
self.assert_state(TrainingState.BACKWARD)
param.grad.data = reduced_grad
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
param.grad.data = param.grad.data.to(dtype=param.data.dtype)
# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
param.grad.data = param._cpu_grad
# Don't let this memory get reused until after the transfers.
reduced_grad.record_stream(torch.cuda.current_stream())

@torch.no_grad()
def _wait_for_post_backward(self) -> None:
"""Wait for post-backward work to finish. Only called on root instance."""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
assert self._reducer is not None
self._reducer.flush()
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self.move_grads_to_cpu:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
@@ -862,29 +896,11 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
p._fp16_shard.record_stream(current_stream)
free_storage_(p._fp16_shard)

@torch.no_grad()
def _reduce_scatter_grad(self, p: nn.Parameter) -> torch.Tensor:
"""Reduce-scatter a Parameter's gradient and return a single shard of
the summed gradient across workers."""
assert p.grad is not None and p._is_sharded
grad_chunks = list(torch.flatten(p.grad.data).chunk(self.world_size))

# torch.chunk may return fewer than world_size chunks, pad accordingly.
num_pad_for_partial_chunk = grad_chunks[0].numel() - grad_chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
grad_chunks[-1] = F.pad(grad_chunks[-1], [0, num_pad_for_partial_chunk])
if len(grad_chunks) < self.world_size:
grad_chunks.extend([torch.zeros_like(grad_chunks[0])] * (self.world_size - len(grad_chunks)))

output = torch.zeros_like(grad_chunks[0]) # filled with gradient summed across workers
dist.reduce_scatter(output, grad_chunks, group=self.process_group)
return output

def assert_idle(self) -> None:
"""Assert we are in the idle state."""
def assert_state(self, state: TrainingState) -> None:
"""Assert we are in the given state."""
assert (
self.training_state == TrainingState.IDLE
), f"wrong state to call no_sync. current state is {self.training_state}"
self.training_state == state
), f"expected to be in state {state} but current state is {self.training_state}"


@torch.no_grad()
34 changes: 13 additions & 21 deletions fairscale/utils/parallel.py
@@ -5,32 +5,24 @@

"""Useful functions for parallel training."""

from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn.functional as F


def compute_shard_size(numel: int, world_size: int) -> int:
"""Compute shard size like the behavior of torch.chunk()."""
assert numel > 0 and world_size > 0, "invalid inputs"
if numel % world_size == 0:
# easy case, including world_size == 1.
shard_size = numel // world_size
else:
if world_size == 2:
# two shards, shard size is the size of the bigger one.
shard_size = numel // world_size + 1
else:
# find the equal chunks until remainder is smaller than shard_size
for div in range(world_size - 1, 1, -1):
shard_size, rem = divmod(numel, div)
if shard_size >= rem:
break
# corner case: bunch of 1 elements and rest are 0s.
if shard_size == 0:
shard_size = 1
assert shard_size > 0, f"bug: {shard_size}"
return shard_size
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
chunks = list(torch.flatten(tensor).chunk(num_chunks))
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks
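
A quick illustration of the two padding paths in chunk_and_pad (not part of the commit):

import torch
from fairscale.utils.parallel import chunk_and_pad

t = torch.arange(10.0)

# 10 elements into 4 chunks: torch.chunk yields sizes [3, 3, 3, 1], and the last
# chunk is zero-padded to 3 so every rank gets an equally sized piece.
assert [c.numel() for c in chunk_and_pad(t, 4)] == [3, 3, 3, 3]

# 10 elements into 8 chunks: torch.chunk yields only 5 chunks (each of size 2),
# so three all-zero chunks are appended to reach num_chunks.
assert [c.numel() for c in chunk_and_pad(t, 8)] == [2] * 8
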


def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
151 changes: 151 additions & 0 deletions fairscale/utils/reduce_scatter_bucketer.py
@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ProcessGroup


class Bucket:
def __init__(self, data: Tensor, group: ProcessGroup):
self.data = data
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
self.output_shard = torch.zeros_like(data[0])

def flush(self) -> None:
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.data[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])


class ReduceScatterBucketer:
"""
Helper for bucketing multiple reduce-scatter operations on small tensors
into larger reduce-scatter ops to improve communication efficiency.
Usage::
bucketer = ReduceScatterBucketer()
bucketer.reduce_scatter_async(
small_tensors, callback_fn=lambda result: print("small")
)
bucketer.reduce_scatter_async(
big_tensors, callback_fn=lambda result: print("big")
)
bucketer.reduce_scatter_async(
more_small_tensors, callback_fn=lambda result: print("small2")
)
bucketer.flush() # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2
Args:
bucket_cap_mb (int, Optional): bucket size for communicating. Buckets
are sub-divided based on world_size. Values <= 0 disable bucketing.
"""

def __init__(self, bucket_cap_mb: int = 25):
self.bucket_cap_mb = bucket_cap_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

@torch.no_grad()
def reduce_scatter_async(
self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
can be bucketed together. The given callback (``callback_fn``) will be
called with the reduced result at some later time. Call ``flush()`` to
force all queued ops and callbacks to be executed.
Note that large inputs will be reduced immediately, and this function
may also flush the relevant bucket to make room for ``input_list``.
Args:
input_list (List[Tensor]): list of tensors to reduce-scatter. List
should contain ``group.size()`` tensors and each tensor should
have identical shape, dtype and device.
group (ProcessGroup): process group for reduction
callback_fn (Callable, Optional): callback function to call after
the reduction executes. Function will be called with a single
argument corresponding to the reduced result.
"""
world_size = group.size()

assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

first_input = input_list[0]
first_input_size = first_input.numel()

bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
if first_input_size > bucket_shard_size:
# input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0])
dist.reduce_scatter(output, input_list, group=group)
if callback_fn is not None:
callback_fn(output)
return

bucket = self._get_bucket(first_input, group)
if first_input_size > bucket.data.size(1) - bucket.offset:
# not enough space remaining in bucket, flush it now
bucket.flush()

# copy data from input_list into bucket
stacked_input = torch.stack(input_list).view(world_size, first_input_size)
offset = bucket.offset
bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
bucket.offset += first_input_size

# callback will be given the reduced result
if callback_fn is not None:
result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
bucket.callbacks.append(functools.partial(callback_fn, result_view))

@torch.no_grad()
def flush(self) -> None:
"""Reduce-scatter any partial buckets."""
for bucket in self.buckets.values():
bucket.flush()

@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_cap_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_cap_mb * MB / element_size
return int(bucket_size // num_shards)

def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
world_size = group.size()
shard_size = self._get_shard_size(tensor.element_size(), world_size)
data = tensor.new_zeros((world_size, shard_size))
self.buckets[key] = Bucket(data, group)
return self.buckets[key]
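
To make the sizing concrete, a back-of-the-envelope check of _get_shard_size and the bucket allocation above, using the default bucket_cap_mb=25:

# fp16 gradients (element_size = 2 bytes) on a group with world_size = 8:
#   bucket_size = 25 * 1024 * 1024 / 2  = 13_107_200 elements
#   shard_size  = 13_107_200 // 8       =  1_638_400 elements (~3.1 MB per rank)
# so bucket.data is allocated as an (8, 1_638_400) fp16 tensor, i.e. 25 MB of
# staging memory per (dtype, device, group) key, reused across flushes.
bucket_cap_mb, element_size, world_size = 25, 2, 8
shard_size = int(bucket_cap_mb * 1024 * 1024 / element_size // world_size)
assert shard_size == 1_638_400
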