Separate PR for FSDP using Reduce-Scatter with bucketing/coalescing
jeffhataws authored and Arjunbala committed Dec 10, 2023
1 parent a80c1e7 commit 66636c8
Showing 2 changed files with 144 additions and 64 deletions.
torch_xla/distributed/fsdp/utils.py (87 changes: 73 additions & 14 deletions)
@@ -59,21 +59,80 @@ def dummy_all_reduce(reduce_type, inputs, scale=1.0, groups=None):
  return [t.mul_(scale) for t in inputs]


-def dummy_reduce_scatter(reduce_type,
-                         input,
-                         scale,
-                         scatter_dim,
-                         shard_count,
-                         groups=None):
+class DummyReduceScatter:
  """A dummy op for debugging with the same output shape as reduce_scatter"""
-  assert shard_count == xm.xrt_world_size()
-  full_size = input.size(scatter_dim)
-  shard_size = full_size // xm.xrt_world_size()
-  begin = shard_size * xm.get_ordinal()
-  end = begin + shard_size
-  slices = [None] * input.dim()
-  slices[scatter_dim] = slice(begin, end)
-  return input[tuple(slices)] * scale
+
+  def __init__(self, shard_count):
+    assert shard_count == xm.xrt_world_size()
+    self.scale = 1.0
+
+  def __call__(self, input, callback):
+    full_size = input.size(0)
+    shard_size = full_size // xm.xrt_world_size()
+    begin = shard_size * xm.get_ordinal()
+    end = begin + shard_size
+    slices = [None] * input.dim()
+    slices[0] = slice(begin, end)
+    callback(input[tuple(slices)])
+
+  def flush(self):
+    pass


+class BucketizedReduceScatter:
+  """A reduce_scatter op that groups input tensors before reduce-scattering them."""
+
+  def __init__(self, bucket_size_mb, shard_count, groups, pin_layout) -> None:
+    self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
+    self.shard_count = shard_count
+    self.groups = groups
+    self.pin_layout = pin_layout
+    self.scale = 1.0
+
+    self.callbacks = []
+    self.bucket = []
+    self.bucket_watermark = 0
+
+  def __call__(self, input, callback):
+    input_byte_size = input.element_size() * input.numel()
+    self.bucket.append(input)
+    self.callbacks.append(callback)
+    self.bucket_watermark += input_byte_size
+    # If bucket_size_mb is the default 0, flush for every tensor rather than coalescing.
+    if self.bucket_watermark > self.bucket_size_bytes:
+      self.flush()
+
+  def flush(self):
+    if not self.bucket:
+      return
+    # TODO: debug coalesce error "" for GPU when pin_layout=True.
+    # For now, work around it by using the non-coalesced version of reduce-scatter
+    # when there is only 1 tensor input (bucket_size_mb=0).
+    if len(self.bucket) == 1:
+      result = xm.reduce_scatter(
+          xm.REDUCE_SUM,
+          self.bucket[0],
+          scale=self.scale,
+          scatter_dim=0,
+          shard_count=self.shard_count,
+          groups=self.groups,
+          pin_layout=self.pin_layout)
+      self.callbacks[0](result)
+    else:
+      results = xm.reduce_scatter(
+          xm.REDUCE_SUM,
+          self.bucket,
+          scale=self.scale,
+          scatter_dim=0,
+          shard_count=self.shard_count,
+          groups=self.groups,
+          pin_layout=self.pin_layout)
+      for cb, result in zip(self.callbacks, results):
+        cb(result)
+
+    self.bucket.clear()
+    self.callbacks.clear()
+    self.bucket_watermark = 0


class XLAPatchedLinear(torch.autograd.Function):
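For intuition, here is a minimal, self-contained sketch (plain PyTorch, no XLA device required) of the callback/flush contract that BucketizedReduceScatter implements above: tensors and callbacks are queued until the byte watermark is exceeded or flush() is called. The MockBucketizedReduceScatter class and its local slicing are hypothetical stand-ins, not torch_xla code.

# Hypothetical stand-in that mirrors the bucketing/callback pattern without XLA.
import torch

class MockBucketizedReduceScatter:

  def __init__(self, bucket_size_mb, shard_count):
    self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
    self.shard_count = shard_count
    self.bucket, self.callbacks, self.bucket_watermark = [], [], 0

  def __call__(self, input, callback):
    self.bucket.append(input)
    self.callbacks.append(callback)
    self.bucket_watermark += input.element_size() * input.numel()
    # Flush once the accumulated bytes exceed the bucket size (always, if size is 0).
    if self.bucket_watermark > self.bucket_size_bytes:
      self.flush()

  def flush(self):
    if not self.bucket:
      return
    for cb, t in zip(self.callbacks, self.bucket):
      # Pretend this rank's shard of a sum-reduction; the real op reduce-scatters
      # across ranks and hands each rank its own shard.
      cb(t[:t.numel() // self.shard_count])
    self.bucket.clear()
    self.callbacks.clear()
    self.bucket_watermark = 0

rs = MockBucketizedReduceScatter(bucket_size_mb=1, shard_count=4)
rs(torch.ones(8), lambda shard: print("shard delivered:", shard.shape))
rs.flush()  # small inputs never cross the watermark, so the final flush delivers them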
torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py (121 changes: 71 additions & 50 deletions)
@@ -35,7 +35,14 @@
import torch_xla.core.xla_model as xm

from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper
-from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear
+from .utils import (
+    BucketizedReduceScatter,
+    DummyReduceScatter,
+    dummy_all_gather,
+    dummy_all_reduce,
+    apply_xla_patch_to_nn_linear,
+)

from .wrap import recursive_wrap
from ._init_utils import _materialize_module

@@ -295,6 +302,7 @@ def __init__(
      sharding_world_size: Optional[int] = None,
      shard_param_on_dim_0: bool = False,
      pin_layout_in_collective_ops: bool = True,
+      reduce_scatter_bucket_size_mb: Optional[int] = 0,
      coalesce_all_gather_ops: bool = False,
      auto_wrap_policy: Optional[Callable] = None,
      auto_wrapper_callable: Optional[Callable] = None,
@@ -398,6 +406,20 @@ def __init__(
    # When `_shard_param_on_dim_0` is True, we shard and all-gather model parameter tensors
    # only along their dim 0 without flattening the parameter
    self._shard_param_on_dim_0 = shard_param_on_dim_0 and not flatten_parameters
+    # Allow specifying groups for the sharding collective ops, useful for mixing
+    # FSDP data parallelism with model parallelism (e.g. Megatron)
+    self.sharding_groups = sharding_groups
+    if sharding_groups is None:
+      self.rank = xm.get_ordinal()
+      self.world_size = xm.xrt_world_size()
+    else:
+      if sharding_rank is None or sharding_world_size is None:
+        raise ValueError(
+            "sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
+        )
+      self.rank = sharding_rank
+      self.world_size = sharding_world_size
+
    self.coalesce_all_gather_ops = coalesce_all_gather_ops
    # Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together
    # TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously
@@ -413,29 +435,18 @@ def __init__(
    self.all_reduce_op = functools.partial(
        xm.all_reduce, pin_layout=pin_layout_in_collective_ops)
    if _debug_dummy_reduce_scatter_op:
-      self.reduce_scatter_op = dummy_reduce_scatter
+      self.reduce_scatter_op = DummyReduceScatter(shard_count=self.world_size)
    else:
-      self.reduce_scatter_op = functools.partial(
-          xm.reduce_scatter, pin_layout=pin_layout_in_collective_ops)
+      self.reduce_scatter_op = BucketizedReduceScatter(
+          reduce_scatter_bucket_size_mb,
+          shard_count=self.world_size,
+          groups=self.sharding_groups,
+          pin_layout=pin_layout_in_collective_ops)
    if _debug_dummy_optimization_barrier_op:
      self.optimization_barrier_op = lambda *args: None
    else:
      self.optimization_barrier_op = xm.optimization_barrier_

-    # Allow specifying groups for the sharding collective ops, useful for mixing
-    # FSDP data parallelism with model parallelism (e.g. Megatron)
-    self.sharding_groups = sharding_groups
-    if sharding_groups is None:
-      self.rank = xm.get_ordinal()
-      self.world_size = xm.xrt_world_size()
-    else:
-      if sharding_rank is None or sharding_world_size is None:
-        raise ValueError(
-            "sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
-        )
-      self.rank = sharding_rank
-      self.world_size = sharding_world_size

    # Options for debugging
    # - set _debug_dummy_forward_pass=True to check for parameter-only memory consumption
    # - set _debug_msg="xxx" and _debug_print=True to distinguish different FSDP instance
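For reference, a minimal usage sketch of the new knob, assuming torch_xla is installed and an XLA device is available; MyModel and the chosen bucket size are hypothetical. With the default reduce_scatter_bucket_size_mb=0 the op flushes per gradient tensor, while a larger value coalesces gradient reduce-scatters into buckets.

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

class MyModel(nn.Module):  # hypothetical toy model

  def __init__(self):
    super().__init__()
    self.net = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

  def forward(self, x):
    return self.net(x)

device = xm.xla_device()
model = FSDP(
    MyModel().to(device),
    reduce_scatter_bucket_size_mb=64,  # coalesce gradient reduce-scatters up to ~64 MB
    pin_layout_in_collective_ops=True)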
@@ -554,6 +565,10 @@ def set_gradient_divide_factors(self, pre: float, post: float,
          module.set_gradient_divide_factors(pre, post, False)
    self.gradient_predivide_factor = pre
    self.gradient_postdivide_factor = post
+    if (pre, post) == (1, 1):
+      self.reduce_scatter_op.scale = 1.0 / self.world_size
+    else:
+      self.reduce_scatter_op.scale = 1.0

@property
def module(self) -> XlaFlattenParamsWrapper:
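The scale adjustment above folds gradient averaging into the reduce-scatter when the explicit divide factors are disabled. A rough arithmetic sketch, assuming the usual convention that the pre- and post-divide factors multiply to the world size (the values below are hypothetical):

# Plain-Python arithmetic sketch; assumes pre * post == world_size in the default setup.
world_size = 8
grads = [1.0] * world_size          # the same gradient value on every rank

# Default path: explicit pre/post division around a sum reduction (collective scale = 1.0).
pre, post = 4, 2                    # hypothetical split with pre * post == world_size
avg = sum(g / pre for g in grads) / post        # -> 1.0

# set_gradient_divide_factors(1, 1, ...): averaging folded into the reduce-scatter scale.
scale = 1.0 / world_size
avg_folded = sum(grads) * scale                 # -> 1.0, same result with fewer division ops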
@@ -1144,6 +1159,7 @@ def _register_post_backward_hooks(self) -> None:
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
self._post_backward_hooks_to_call = 0
for p in self.full_params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
Expand All @@ -1157,6 +1173,7 @@ def _register_post_backward_hooks(self) -> None:
handle = grad_acc.register_hook(
functools.partial(self._post_backward_hook, p))
p._shard_bwd_hook = (grad_acc, handle)
self._post_backward_hooks_to_call += 1

  @torch.no_grad()
  def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
@@ -1183,7 +1200,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
    # then subsequent hook callbacks will see POST state.
    self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
    self.training_state = TrainingState.BACKWARD_POST
+    self._post_backward_hooks_to_call -= 1
    if param.grad is None:
+      if self._post_backward_hooks_to_call == 0:
+        self.reduce_scatter_op.flush()
      return

    assert param.grad is not None, param.shape
@@ -1204,6 +1224,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
          apply_opt_barrier=self.optimization_barrier_in_backward)

    if not self._require_backward_grad_sync:
+      if self._post_backward_hooks_to_call == 0:
+        self.reduce_scatter_op.flush()
      return

    if self.gradient_predivide_factor > 1:
@@ -1219,38 +1241,37 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
      self.optimization_barrier_op([grad_flat])
    if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter:
      grad_flat = grad_flat.to(torch.float32)
-    reduced_grad = self.reduce_scatter_op(
-        xm.REDUCE_SUM,
-        grad_flat.detach(),
-        scale=1.0,
-        scatter_dim=0,
-        shard_count=self.world_size,
-        groups=self.sharding_groups)
-    if reduced_grad.dtype != torch.float32:
-      reduced_grad = reduced_grad.to(torch.float32)
-    if self.optimization_barrier_in_backward:
-      self.optimization_barrier_op([reduced_grad])
-    if self.gradient_postdivide_factor > 1:
-      # Average grad by world_size for consistency with PyTorch DDP.
-      reduced_grad.div_(self.gradient_postdivide_factor)
-
-    grad._has_full_param = True
-    grad_flat._has_full_param = True
-    self._free_full_params(
-        [grad, grad_flat],
-        dependency_tensors=[reduced_grad],
-        apply_opt_barrier=self.optimization_barrier_in_backward)
-    self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
-
-    # Accumulate into the gradient shard.
-    assert hasattr(param, "_sharded_param")
-    p_shard = param._sharded_param
-    if p_shard.grad is None:
-      p_shard.grad = reduced_grad
-    else:
-      assert p_shard.grad.shape == reduced_grad.shape
-      assert p_shard.grad.device == reduced_grad.device
-      p_shard.grad += reduced_grad
+
+    def reduce_scatter_done(reduced_grad):
+      if reduced_grad.dtype != torch.float32:
+        reduced_grad = reduced_grad.to(torch.float32)
+      if self.optimization_barrier_in_backward:
+        self.optimization_barrier_op([reduced_grad])
+      if self.gradient_postdivide_factor > 1:
+        # Average grad by world_size for consistency with PyTorch DDP.
+        reduced_grad.div_(self.gradient_postdivide_factor)
+
+      grad._has_full_param = True
+      grad_flat._has_full_param = True
+      self._free_full_params(
+          [grad, grad_flat],
+          dependency_tensors=[reduced_grad],
+          apply_opt_barrier=self.optimization_barrier_in_backward)
+      self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
+
+      # Accumulate into the gradient shard.
+      assert hasattr(param, "_sharded_param")
+      p_shard = param._sharded_param
+      if p_shard.grad is None:
+        p_shard.grad = reduced_grad
+      else:
+        assert p_shard.grad.shape == reduced_grad.shape
+        assert p_shard.grad.device == reduced_grad.device
+        p_shard.grad += reduced_grad
+
+    self.reduce_scatter_op(grad_flat.detach(), reduce_scatter_done)
+    if self._post_backward_hooks_to_call == 0:
+      self.reduce_scatter_op.flush()

def _queue_wait_for_post_backward(self) -> None:
"""
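As a recap of the control flow introduced in _post_backward_hook, here is a plain-Python distillation (not torch_xla code; _HookDriver and its arguments are hypothetical): every post-backward hook decrements a counter, and the hook that reaches zero flushes the bucketed reduce-scatter so no queued callback is left pending at the end of backward.

# Hypothetical sketch of the flush-on-last-hook pattern.
class _HookDriver:

  def __init__(self, num_params, reduce_scatter_op):
    self.hooks_to_call = num_params
    self.reduce_scatter_op = reduce_scatter_op

  def post_backward_hook(self, grad, callback):
    self.hooks_to_call -= 1
    if grad is not None:
      self.reduce_scatter_op(grad, callback)  # enqueue; may or may not flush yet
    if self.hooks_to_call == 0:
      self.reduce_scatter_op.flush()          # the last hook drains whatever is left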
