diff --git a/torch_xla/distributed/fsdp/utils.py b/torch_xla/distributed/fsdp/utils.py
index ee79bc4a966..099688c46c9 100644
--- a/torch_xla/distributed/fsdp/utils.py
+++ b/torch_xla/distributed/fsdp/utils.py
@@ -59,21 +59,80 @@ def dummy_all_reduce(reduce_type, inputs, scale=1.0, groups=None):
   return [t.mul_(scale) for t in inputs]
 
 
-def dummy_reduce_scatter(reduce_type,
-                         input,
-                         scale,
-                         scatter_dim,
-                         shard_count,
-                         groups=None):
+class DummyReduceScatter:
   """A dummy op for debugging with the same output shape as reduce_scatter"""
-  assert shard_count == xm.xrt_world_size()
-  full_size = input.size(scatter_dim)
-  shard_size = full_size // xm.xrt_world_size()
-  begin = shard_size * xm.get_ordinal()
-  end = begin + shard_size
-  slices = [None] * input.dim()
-  slices[scatter_dim] = slice(begin, end)
-  return input[tuple(slices)] * scale
+
+  def __init__(self, shard_count):
+    assert shard_count == xm.xrt_world_size()
+    self.scale = 1.0
+
+  def __call__(self, input, callback):
+    full_size = input.size(0)
+    shard_size = full_size // xm.xrt_world_size()
+    begin = shard_size * xm.get_ordinal()
+    end = begin + shard_size
+    slices = [None] * input.dim()
+    slices[0] = slice(begin, end)
+    callback(input[tuple(slices)])
+
+  def flush(self):
+    pass
+
+
+class BucketizedReduceScatter:
+  """A reduce_scatter op that groups input tensors before reduce-scattering them."""
+
+  def __init__(self, bucket_size_mb, shard_count, groups, pin_layout) -> None:
+    self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
+    self.shard_count = shard_count
+    self.groups = groups
+    self.pin_layout = pin_layout
+    self.scale = 1.0
+
+    self.callbacks = []
+    self.bucket = []
+    self.bucket_watermark = 0
+
+  def __call__(self, input, callback):
+    input_byte_size = input.element_size() * input.numel()
+    self.bucket.append(input)
+    self.callbacks.append(callback)
+    self.bucket_watermark += input_byte_size
+    # If bucket_size_mb is the default 0, flush for every tensor rather than coalescing
+    if self.bucket_watermark > self.bucket_size_bytes:
+      self.flush()
+
+  def flush(self):
+    if not self.bucket:
+      return
+    # TODO: debug coalesce error "" for GPU when pin_layout=True.
+    # For now, work around it by using the non-coalesce version of reduce-scatter
+    # when there's only 1 tensor input (bucket_size_mb=0).
+    if len(self.bucket) == 1:
+      result = xm.reduce_scatter(
+          xm.REDUCE_SUM,
+          self.bucket[0],
+          scale=self.scale,
+          scatter_dim=0,
+          shard_count=self.shard_count,
+          groups=self.groups,
+          pin_layout=self.pin_layout)
+      self.callbacks[0](result)
+    else:
+      results = xm.reduce_scatter(
+          xm.REDUCE_SUM,
+          self.bucket,
+          scale=self.scale,
+          scatter_dim=0,
+          shard_count=self.shard_count,
+          groups=self.groups,
+          pin_layout=self.pin_layout)
+      for cb, result in zip(self.callbacks, results):
+        cb(result)
+
+    self.bucket.clear()
+    self.callbacks.clear()
+    self.bucket_watermark = 0
 
 
 class XLAPatchedLinear(torch.autograd.Function):
diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index dae259b6fb7..c39d012154f 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -35,7 +35,14 @@ import torch_xla.core.xla_model as xm
 
 from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper
-from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear
+from .utils import (
+    BucketizedReduceScatter,
+    DummyReduceScatter,
+    dummy_all_gather,
+    dummy_all_reduce,
+    apply_xla_patch_to_nn_linear,
+)
+
 from .wrap import recursive_wrap
 from ._init_utils import _materialize_module
 
 
@@ -295,6 +302,7 @@ def __init__(
       sharding_world_size: Optional[int] = None,
       shard_param_on_dim_0: bool = False,
       pin_layout_in_collective_ops: bool = True,
+      reduce_scatter_bucket_size_mb: Optional[int] = 0,
      coalesce_all_gather_ops: bool = False,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
@@ -398,6 +406,20 @@ def __init__(
     # When `_shard_param_on_dim_0` is True, we shard and all-gather model parameter tensors
     # only along their dim 0 without flattening the parameter
     self._shard_param_on_dim_0 = shard_param_on_dim_0 and not flatten_parameters
+    # Allow specifying groups for the sharding collective ops, useful for mixing
+    # FSDP data parallelism with model parallelism (e.g. Megatron)
+    self.sharding_groups = sharding_groups
+    if sharding_groups is None:
+      self.rank = xm.get_ordinal()
+      self.world_size = xm.xrt_world_size()
+    else:
+      if sharding_rank is None or sharding_world_size is None:
+        raise ValueError(
+            "sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
+        )
+      self.rank = sharding_rank
+      self.world_size = sharding_world_size
+
     self.coalesce_all_gather_ops = coalesce_all_gather_ops
     # Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together
     # TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously
@@ -413,29 +435,18 @@ def __init__(
     self.all_reduce_op = functools.partial(
         xm.all_reduce, pin_layout=pin_layout_in_collective_ops)
     if _debug_dummy_reduce_scatter_op:
-      self.reduce_scatter_op = dummy_reduce_scatter
+      self.reduce_scatter_op = DummyReduceScatter(shard_count=self.world_size)
     else:
-      self.reduce_scatter_op = functools.partial(
-          xm.reduce_scatter, pin_layout=pin_layout_in_collective_ops)
+      self.reduce_scatter_op = BucketizedReduceScatter(
+          reduce_scatter_bucket_size_mb,
+          shard_count=self.world_size,
+          groups=self.sharding_groups,
+          pin_layout=pin_layout_in_collective_ops)
     if _debug_dummy_optimization_barrier_op:
       self.optimization_barrier_op = lambda *args: None
     else:
       self.optimization_barrier_op = xm.optimization_barrier_
 
-    # Allow specifying groups for the sharding collective ops, useful for mixing
-    # FSDP data parallelism with model parallelism (e.g. Megatron)
-    self.sharding_groups = sharding_groups
-    if sharding_groups is None:
-      self.rank = xm.get_ordinal()
-      self.world_size = xm.xrt_world_size()
-    else:
-      if sharding_rank is None or sharding_world_size is None:
-        raise ValueError(
-            "sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
-        )
-      self.rank = sharding_rank
-      self.world_size = sharding_world_size
-
     # Options for debugging
     # - set _debug_dummy_forward_pass=True to check for parameter-only memory consumption
     # - set _debug_msg="xxx" and _debug_print=True to distinguish different FSDP instance
@@ -554,6 +565,10 @@ def set_gradient_divide_factors(self, pre: float, post: float,
           module.set_gradient_divide_factors(pre, post, False)
     self.gradient_predivide_factor = pre
     self.gradient_postdivide_factor = post
+    if (pre, post) == (1, 1):
+      self.reduce_scatter_op.scale = 1.0 / self.world_size
+    else:
+      self.reduce_scatter_op.scale = 1.0
 
   @property
   def module(self) -> XlaFlattenParamsWrapper:
@@ -1144,6 +1159,7 @@ def _register_post_backward_hooks(self) -> None:
     """
     if not torch.is_grad_enabled():
       return  # don't register grad hooks if grad isn't enabled
+    self._post_backward_hooks_to_call = 0
     for p in self.full_params:
       if p.requires_grad:
         if hasattr(p, "_shard_bwd_hook"):
@@ -1157,6 +1173,7 @@
           handle = grad_acc.register_hook(
               functools.partial(self._post_backward_hook, p))
           p._shard_bwd_hook = (grad_acc, handle)
+          self._post_backward_hooks_to_call += 1
 
   @torch.no_grad()
   def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
@@ -1183,7 +1200,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
     # then subsequent hook callbacks will see POST state.
     self.assert_state(
         [TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
     self.training_state = TrainingState.BACKWARD_POST
+    self._post_backward_hooks_to_call -= 1
     if param.grad is None:
+      if self._post_backward_hooks_to_call == 0:
+        self.reduce_scatter_op.flush()
       return
     assert param.grad is not None, param.shape
@@ -1204,6 +1224,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
           apply_opt_barrier=self.optimization_barrier_in_backward)
 
     if not self._require_backward_grad_sync:
+      if self._post_backward_hooks_to_call == 0:
+        self.reduce_scatter_op.flush()
       return
 
     if self.gradient_predivide_factor > 1:
@@ -1219,38 +1241,37 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         self.optimization_barrier_op([grad_flat])
       if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter:
         grad_flat = grad_flat.to(torch.float32)
-      reduced_grad = self.reduce_scatter_op(
-          xm.REDUCE_SUM,
-          grad_flat.detach(),
-          scale=1.0,
-          scatter_dim=0,
-          shard_count=self.world_size,
-          groups=self.sharding_groups)
-      if reduced_grad.dtype != torch.float32:
-        reduced_grad = reduced_grad.to(torch.float32)
-      if self.optimization_barrier_in_backward:
-        self.optimization_barrier_op([reduced_grad])
-      if self.gradient_postdivide_factor > 1:
-        # Average grad by world_size for consistency with PyTorch DDP.
-        reduced_grad.div_(self.gradient_postdivide_factor)
-
-      grad._has_full_param = True
-      grad_flat._has_full_param = True
-      self._free_full_params(
-          [grad, grad_flat],
-          dependency_tensors=[reduced_grad],
-          apply_opt_barrier=self.optimization_barrier_in_backward)
-      self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
-
-      # Accumulate into the gradient shard.
-      assert hasattr(param, "_sharded_param")
-      p_shard = param._sharded_param
-      if p_shard.grad is None:
-        p_shard.grad = reduced_grad
-      else:
-        assert p_shard.grad.shape == reduced_grad.shape
-        assert p_shard.grad.device == reduced_grad.device
-        p_shard.grad += reduced_grad
+
+      def reduce_scatter_done(reduced_grad):
+        if reduced_grad.dtype != torch.float32:
+          reduced_grad = reduced_grad.to(torch.float32)
+        if self.optimization_barrier_in_backward:
+          self.optimization_barrier_op([reduced_grad])
+        if self.gradient_postdivide_factor > 1:
+          # Average grad by world_size for consistency with PyTorch DDP.
+          reduced_grad.div_(self.gradient_postdivide_factor)
+
+        grad._has_full_param = True
+        grad_flat._has_full_param = True
+        self._free_full_params(
+            [grad, grad_flat],
+            dependency_tensors=[reduced_grad],
+            apply_opt_barrier=self.optimization_barrier_in_backward)
+        self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
+
+        # Accumulate into the gradient shard.
+        assert hasattr(param, "_sharded_param")
+        p_shard = param._sharded_param
+        if p_shard.grad is None:
+          p_shard.grad = reduced_grad
+        else:
+          assert p_shard.grad.shape == reduced_grad.shape
+          assert p_shard.grad.device == reduced_grad.device
+          p_shard.grad += reduced_grad
+
+      self.reduce_scatter_op(grad_flat.detach(), reduce_scatter_done)
+      if self._post_backward_hooks_to_call == 0:
+        self.reduce_scatter_op.flush()
 
   def _queue_wait_for_post_backward(self) -> None:
     """
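
Usage sketch (not part of the patch): a minimal, hypothetical example of how the new reduce_scatter_bucket_size_mb option introduced above could be exercised. The Linear module, batch shape, and 64 MB bucket size are illustrative assumptions only; with the default value of 0, BucketizedReduceScatter flushes after every tensor, which preserves the previous per-parameter reduce-scatter behavior.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

    device = xm.xla_device()
    # Illustrative module; any nn.Module can be wrapped the same way.
    model = FSDP(
        torch.nn.Linear(128, 128).to(device),
        reduce_scatter_bucket_size_mb=64)  # coalesce gradient reduce-scatters into ~64 MB buckets
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 128, device=device)).sum()
    # Post-backward hooks feed per-parameter gradients into BucketizedReduceScatter;
    # the final flush() runs once the last registered hook has fired.
    loss.backward()
    optimizer.step()
    xm.mark_step()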