
Commit

[Resubmit] Move helpers to torch.dist.utils (pytorch#95025)
Pull Request resolved: pytorch#95025
Approved by: https://github.com/fegin
rohan-varma authored and pytorchmergebot committed Feb 17, 2023
1 parent 2aa8066 commit c43e886
Showing 7 changed files with 159 additions and 178 deletions.
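
This commit relocates several private FSDP helper functions (notably _apply_to_tensors and p_assert, the latter renamed to _p_assert) from torch.distributed.fsdp._utils to torch.distributed.utils. Purely as illustration, here is a minimal sketch of the import change that the diff below applies throughout the FSDP sources and tests:

# Before this commit, FSDP-internal callers imported the helpers from the
# private FSDP module (and the assert helper was named p_assert):
#   from torch.distributed.fsdp._utils import _apply_to_tensors, p_assert
# After this commit, the helpers live in torch.distributed.utils and the
# assert helper carries a leading underscore:
from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
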
3 changes: 1 addition & 2 deletions test/distributed/fsdp/test_utils.py
@@ -11,10 +11,9 @@
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp._utils import _apply_to_tensors
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _replace_by_prefix
from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
50 changes: 23 additions & 27 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -26,11 +26,7 @@
TrainingState,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp._utils import (
_apply_to_tensors,
_no_dispatch_record_stream,
p_assert,
)
from torch.distributed.fsdp._utils import _no_dispatch_record_stream
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.fsdp.flat_param import (
_HandlesKey,
@@ -39,7 +35,7 @@
HandleShardingStrategy,
HandleTrainingState,
)
from torch.distributed.utils import _to_kwargs
from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs

RESHARD_AFTER_FORWARD_STRATEGIES = {
HandleShardingStrategy.FULL_SHARD,
@@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs(
attr_name_to_values[attr_name] = set()
for fsdp_state in traversal_utils._get_fsdp_states(root_module):
for attr_name in HOMOGENEOUS_ATTR_NAMES:
p_assert(
_p_assert(
hasattr(fsdp_state, attr_name),
f"FSDP state missing attribute {attr_name}",
)
@@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs(
# Relax the assert for non-root FSDP instances in case the nested
# initialized module is wrapped again in FSDP later (e.g. after
# training to run inference)
p_assert(
_p_assert(
fsdp_state._is_root is None or not fsdp_state._is_root,
"Non-root FSDP instance's `_is_root` should not have been "
"set yet or should have been set to `False`",
@@ -344,7 +340,7 @@ def _reshard(
"""
if not handles:
return
p_assert(
_p_assert(
len(handles) == len(free_unsharded_flat_params),
"Expects both lists to have equal length but got "
f"{len(handles)} and {len(free_unsharded_flat_params)}",
@@ -518,7 +514,7 @@ def _root_pre_forward(
may not be the root. If not, then this method does not do anything.
"""
_lazy_init(state, module)
p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
_p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
if not state._is_root:
return args, kwargs
if state.forward_prefetch:
@@ -675,7 +671,7 @@ def _post_backward_hook(
# the same `FlatParameter`, the post-backward hook may run multiple
# times in one backward, in which case we permit the state to already
# be in `BACKWARD_POST`.
p_assert(
_p_assert(
handle._training_state
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
@@ -855,8 +851,8 @@ def _check_comm_hook(
comm_hook: Any,
comm_hook_state: Any,
) -> None:
p_assert(comm_hook is not None, "Communication hook should not be `None`")
p_assert(
_p_assert(comm_hook is not None, "Communication hook should not be `None`")
_p_assert(
comm_hook_state is not None, "Communication hook state should not be `None`"
)

@@ -865,13 +861,13 @@ def _check_grad_to_accumulate(
new_sharded_grad: torch.Tensor,
accumulated_grad: torch.Tensor,
) -> None:
p_assert(
_p_assert(
accumulated_grad.shape == new_sharded_grad.shape,
"Shape mismatch when accumulating gradients: "
f"existing gradient shape={accumulated_grad.shape} "
f"new gradient shape={new_sharded_grad.shape}",
)
p_assert(
_p_assert(
accumulated_grad.device == new_sharded_grad.device,
"Device mismatch when accumulating gradients: "
f"existing gradient device={accumulated_grad.device} "
@@ -895,7 +891,7 @@ def _post_backward_final_callback(
This runs at the end of the entire backward pass and should only be called
on the root FSDP instance.
"""
p_assert(
_p_assert(
state._is_root,
"The post-backward callback should only be called on the root FSDP instance",
)
@@ -952,7 +948,7 @@ def _catch_all_reshard(
if handles_to_reshard:
_reshard(state, handles_to_reshard, free_unsharded_flat_params)
except Exception as e:
p_assert(
_p_assert(
False,
f"Got exception in the catch-all reshard for {state}: {str(e)}",
raise_assertion_error=False,
@@ -969,7 +965,7 @@ def _finalize_params(
flat_param = handle.flat_param
if flat_param.requires_grad:
if hasattr(flat_param, "_post_backward_hook_state"):
p_assert(
_p_assert(
len(flat_param._post_backward_hook_state) == 2,
f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
)
@@ -982,7 +978,7 @@ def _finalize_params(
# sharded gradient from the last synchronized iteration
continue
handle.prepare_gradient_for_optim()
p_assert(
_p_assert(
hasattr(flat_param, "_post_backward_called"),
"Expects `_post_backward_called` to be set on the `FlatParameter`",
)
@@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch(
HandleTrainingState.BACKWARD_POST,
HandleTrainingState.FORWARD,
)
p_assert(
_p_assert(
training_state in valid_training_states,
f"Prefetching is only supported in {valid_training_states} but "
f"currently in {training_state}",
@@ -1067,9 +1063,9 @@ def _get_training_state(
handles_key: _HandlesKey,
) -> HandleTrainingState:
"""Returns the training state of the handles in ``handles_key``."""
p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
_p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
training_states = {handle._training_state for handle in handles_key}
p_assert(
_p_assert(
len(training_states) == 1,
f"Expects uniform training state but got {training_states}",
)
@@ -1233,7 +1229,7 @@ def _register_post_backward_hooks(
continue
# Get the `AccumulateGrad` object
temp_flat_param = flat_param.expand_as(flat_param)
p_assert(
_p_assert(
temp_flat_param.grad_fn is not None,
"The `grad_fn` is needed to access the `AccumulateGrad` and "
"register the post-backward hook",
@@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback(
backward pass. This should be called from the root FSDP instance at the
beginning of the pre-backward.
"""
p_assert(
_p_assert(
state._is_root,
"Only the root FSDP instance should register the post-backward callback",
)
@@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation(
is either ``None`` if buffer mixed precision is not enabled or the buffer
low precision dtype otherwise.
"""
p_assert(state._is_root, "Expects the root to cast buffers")
_p_assert(state._is_root, "Expects the root to cast buffers")
buffers: List[torch.Tensor] = []
buffer_dtypes: List[Optional[torch.dtype]] = []
if _is_composable(state):
@@ -1344,7 +1340,7 @@ def _get_buffer_dtypes(
"""
buffer_dtypes: List[torch.dtype] = []
for buffer_name in buffer_names:
p_assert(
_p_assert(
buffer_name in state._buffer_name_to_orig_dtype,
f"{buffer_name} is missing from pre-computed dict on rank "
f"{state.rank}, which only has keys "
@@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device(
to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
corresponding buffer is only moved to ``device``.
"""
p_assert(
_p_assert(
buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
f"Expects `buffers` and `buffer_dtypes` to have the same length if "
f"`buffer_dtypes` is specified but got {len(buffers)} and "
6 changes: 3 additions & 3 deletions torch/distributed/fsdp/_unshard_param_utils.py
Expand Up @@ -21,7 +21,7 @@
_unshard,
_unshard_grads,
)
from ._utils import p_assert
from torch.distributed.utils import _p_assert
from .flat_param import FlatParamHandle

FLAT_PARAM = "_flat_param"
@@ -336,15 +336,15 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
Deregisters the original parameters; registers the ``FlatParameter``.
"""
handles = _module_handles(state, module)
p_assert(
_p_assert(
len(handles) <= 1,
"Expects <=1 handle per FSDP instance; needs to be refactored "
"for >1 handle (e.g. non-recursive wrapping)",
)
if not handles:
return
handle = handles[0]
p_assert(
_p_assert(
handle._use_orig_params,
f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
f"handle: {handle._use_orig_params}",
99 changes: 1 addition & 98 deletions torch/distributed/fsdp/_utils.py
@@ -1,14 +1,7 @@
import dataclasses
import traceback
from collections import OrderedDict
from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union
from typing import cast

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined]
_is_namedtuple,
)
from torch.nn.utils.rnn import PackedSequence
from torch.utils._mode_utils import no_dispatch


@@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module):
mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment]


def _apply_to_tensors(
fn: Callable,
container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
"""Recursively apply to all tensor in different kinds of container types."""

def apply(
x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
) -> Any:
if torch.is_tensor(x):
return fn(x)
elif hasattr(x, "__dataclass_fields__"):
dc = dataclasses.replace(x)
for f in dataclasses.fields(dc):
name = f.name
setattr(dc, name, apply(getattr(dc, name)))
return dc
elif isinstance(x, OrderedDict):
od = x.__class__()
for key, value in x.items():
od[key] = apply(value)
return od
elif isinstance(x, PackedSequence):
apply(x.data)
return x
elif isinstance(x, dict):
return {key: apply(value) for key, value in x.items()}
elif _is_namedtuple(x):
res = (apply(el) for el in x)
return type(x)(*res)
elif isinstance(x, (list, tuple, set)):
return type(x)(apply(el) for el in x)
else:
return x

return apply(container)


@torch.no_grad()
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
"""
Allocate storage for ``tensor`` with the given size.
Returns:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
p_assert(
tensor_storage_size == 0,
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
)
tensor._typed_storage()._resize_(size.numel())
return not already_allocated


@torch.no_grad()
def _free_storage(tensor: torch.Tensor) -> bool:
"""
Frees the underlying storage of ``tensor``.
Returns:
bool: ``True`` if the method freed the storage and ``False`` if the
storage was already freed.
"""
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)
return not already_freed


def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
"""Returns if ``x`` and ``y`` share the same storage."""
# NOTE: CPU and GPU tensors are ensured to have different data pointers.
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()


def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
"""This is used as an alternate to ``assert`` when in the backward context
to print the error message ``s`` since otherwise, it is swallowed."""
if not cond:
print(s)
traceback.print_stack()
if raise_assertion_error:
raise AssertionError(s)


def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
with no_dispatch():
tensor.record_stream(cast(torch._C.Stream, stream))
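
For reference, a minimal usage sketch of the two relocated helpers whose definitions are removed from _utils.py above, assuming both remain importable (as private, underscore-prefixed functions) from torch.distributed.utils after this change:

import torch
from torch.distributed.utils import _apply_to_tensors, _p_assert

# _apply_to_tensors recursively walks dicts, lists, tuples, sets, namedtuples,
# dataclasses, and PackedSequence containers, applying fn to every tensor it finds.
batch = {"input": torch.randn(2, 3), "labels": [torch.tensor(0), torch.tensor(1)]}
detached = _apply_to_tensors(lambda t: t.detach(), batch)

# _p_assert prints the message and a stack trace before raising AssertionError,
# so the failure is visible even when it occurs inside autograd's backward pass.
_p_assert(isinstance(detached, dict), "Expected the container type to be preserved")
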
