diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index 8df1062bc3713..758561b4eded9 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -11,10 +11,9 @@ import torch import torch.nn as nn from torch import distributed as dist -from torch.distributed.fsdp._utils import _apply_to_tensors from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.distributed.utils import _replace_by_prefix +from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index b7a13689e4ff7..75a0d45c01602 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -26,11 +26,7 @@ TrainingState, ) from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES -from torch.distributed.fsdp._utils import ( - _apply_to_tensors, - _no_dispatch_record_stream, - p_assert, -) +from torch.distributed.fsdp._utils import _no_dispatch_record_stream from torch.distributed.fsdp.api import BackwardPrefetch from torch.distributed.fsdp.flat_param import ( _HandlesKey, @@ -39,7 +35,7 @@ HandleShardingStrategy, HandleTrainingState, ) -from torch.distributed.utils import _to_kwargs +from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs RESHARD_AFTER_FORWARD_STRATEGIES = { HandleShardingStrategy.FULL_SHARD, @@ -221,7 +217,7 @@ def _share_state_and_init_handle_attrs( attr_name_to_values[attr_name] = set() for fsdp_state in traversal_utils._get_fsdp_states(root_module): for attr_name in HOMOGENEOUS_ATTR_NAMES: - p_assert( + _p_assert( hasattr(fsdp_state, attr_name), f"FSDP state missing attribute {attr_name}", ) @@ -246,7 +242,7 @@ def _share_state_and_init_handle_attrs( # Relax the assert for non-root FSDP instances in case the nested # initialized module is wrapped again in FSDP later (e.g. after # training to run inference) - p_assert( + _p_assert( fsdp_state._is_root is None or not fsdp_state._is_root, "Non-root FSDP instance's `_is_root` should not have been " "set yet or should have been set to `False`", @@ -344,7 +340,7 @@ def _reshard( """ if not handles: return - p_assert( + _p_assert( len(handles) == len(free_unsharded_flat_params), "Expects both lists to have equal length but got " f"{len(handles)} and {len(free_unsharded_flat_params)}", @@ -518,7 +514,7 @@ def _root_pre_forward( may not be the root. If not, then this method does not do anything. """ _lazy_init(state, module) - p_assert(state._is_root is not None, "Expects a root FSDP to have been set") + _p_assert(state._is_root is not None, "Expects a root FSDP to have been set") if not state._is_root: return args, kwargs if state.forward_prefetch: @@ -675,7 +671,7 @@ def _post_backward_hook( # the same `FlatParameter`, the post-backward hook may run multiple # times in one backward, in which case we permit the state to already # be in `BACKWARD_POST`. 
- p_assert( + _p_assert( handle._training_state in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST), f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}", @@ -855,8 +851,8 @@ def _check_comm_hook( comm_hook: Any, comm_hook_state: Any, ) -> None: - p_assert(comm_hook is not None, "Communication hook should not be `None`") - p_assert( + _p_assert(comm_hook is not None, "Communication hook should not be `None`") + _p_assert( comm_hook_state is not None, "Communication hook state should not be `None`" ) @@ -865,13 +861,13 @@ def _check_grad_to_accumulate( new_sharded_grad: torch.Tensor, accumulated_grad: torch.Tensor, ) -> None: - p_assert( + _p_assert( accumulated_grad.shape == new_sharded_grad.shape, "Shape mismatch when accumulating gradients: " f"existing gradient shape={accumulated_grad.shape} " f"new gradient shape={new_sharded_grad.shape}", ) - p_assert( + _p_assert( accumulated_grad.device == new_sharded_grad.device, "Device mismatch when accumulating gradients: " f"existing gradient device={accumulated_grad.device} " @@ -895,7 +891,7 @@ def _post_backward_final_callback( This runs at the end of the entire backward pass and should only be called on the root FSDP instance. """ - p_assert( + _p_assert( state._is_root, "The post-backward callback should only be called on the root FSDP instance", ) @@ -952,7 +948,7 @@ def _catch_all_reshard( if handles_to_reshard: _reshard(state, handles_to_reshard, free_unsharded_flat_params) except Exception as e: - p_assert( + _p_assert( False, f"Got exception in the catch-all reshard for {state}: {str(e)}", raise_assertion_error=False, @@ -969,7 +965,7 @@ def _finalize_params( flat_param = handle.flat_param if flat_param.requires_grad: if hasattr(flat_param, "_post_backward_hook_state"): - p_assert( + _p_assert( len(flat_param._post_backward_hook_state) == 2, f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}", ) @@ -982,7 +978,7 @@ def _finalize_params( # sharded gradient from the last synchronized iteration continue handle.prepare_gradient_for_optim() - p_assert( + _p_assert( hasattr(flat_param, "_post_backward_called"), "Expects `_post_backward_called` to be set on the `FlatParameter`", ) @@ -1029,7 +1025,7 @@ def _get_handles_to_prefetch( HandleTrainingState.BACKWARD_POST, HandleTrainingState.FORWARD, ) - p_assert( + _p_assert( training_state in valid_training_states, f"Prefetching is only supported in {valid_training_states} but " f"currently in {training_state}", @@ -1067,9 +1063,9 @@ def _get_training_state( handles_key: _HandlesKey, ) -> HandleTrainingState: """Returns the training state of the handles in ``handles_key``.""" - p_assert(len(handles_key) > 0, "Expects a non-empty handles key") + _p_assert(len(handles_key) > 0, "Expects a non-empty handles key") training_states = {handle._training_state for handle in handles_key} - p_assert( + _p_assert( len(training_states) == 1, f"Expects uniform training state but got {training_states}", ) @@ -1233,7 +1229,7 @@ def _register_post_backward_hooks( continue # Get the `AccumulateGrad` object temp_flat_param = flat_param.expand_as(flat_param) - p_assert( + _p_assert( temp_flat_param.grad_fn is not None, "The `grad_fn` is needed to access the `AccumulateGrad` and " "register the post-backward hook", @@ -1255,7 +1251,7 @@ def _register_post_backward_final_callback( backward pass. This should be called from the root FSDP instance at the beginning of the pre-backward. 
""" - p_assert( + _p_assert( state._is_root, "Only the root FSDP instance should register the post-backward callback", ) @@ -1309,7 +1305,7 @@ def _get_buffers_and_dtypes_for_computation( is either ``None`` if buffer mixed precision is not enabled or the buffer low precision dtype otherwise. """ - p_assert(state._is_root, "Expects the root to cast buffers") + _p_assert(state._is_root, "Expects the root to cast buffers") buffers: List[torch.Tensor] = [] buffer_dtypes: List[Optional[torch.dtype]] = [] if _is_composable(state): @@ -1344,7 +1340,7 @@ def _get_buffer_dtypes( """ buffer_dtypes: List[torch.dtype] = [] for buffer_name in buffer_names: - p_assert( + _p_assert( buffer_name in state._buffer_name_to_orig_dtype, f"{buffer_name} is missing from pre-computed dict on rank " f"{state.rank}, which only has keys " @@ -1364,7 +1360,7 @@ def _cast_buffers_to_dtype_and_device( to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the corresponding buffer is only moved to ``device``. """ - p_assert( + _p_assert( buffer_dtypes is None or len(buffers) == len(buffer_dtypes), f"Expects `buffers` and `buffer_dtypes` to have the same length if " f"`buffer_dtypes` is specified but got {len(buffers)} and " diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index e1c4b7e870448..af75cea11ba70 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -21,7 +21,7 @@ _unshard, _unshard_grads, ) -from ._utils import p_assert +from torch.distributed.utils import _p_assert from .flat_param import FlatParamHandle FLAT_PARAM = "_flat_param" @@ -336,7 +336,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: Deregisters the original parameters; registers the ``FlatParameter``. """ handles = _module_handles(state, module) - p_assert( + _p_assert( len(handles) <= 1, "Expects <=1 handle per FSDP instance; needs to be refactored " "for >1 handle (e.g. 
non-recursive wrapping)", @@ -344,7 +344,7 @@ def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: if not handles: return handle = handles[0] - p_assert( + _p_assert( handle._use_orig_params, f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " f"handle: {handle._use_orig_params}", diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py index 5efb376e66458..45c8c455422b6 100644 --- a/torch/distributed/fsdp/_utils.py +++ b/torch/distributed/fsdp/_utils.py @@ -1,14 +1,7 @@ -import dataclasses -import traceback -from collections import OrderedDict -from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union +from typing import cast import torch from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel.scatter_gather import ( # type: ignore[attr-defined] - _is_namedtuple, -) -from torch.nn.utils.rnn import PackedSequence from torch.utils._mode_utils import no_dispatch @@ -22,102 +15,12 @@ def _override_batchnorm_mixed_precision(module): mod._wrap_overrides = {"mixed_precision": None} # type: ignore[assignment] -def _apply_to_tensors( - fn: Callable, - container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence], -) -> Any: - """Recursively apply to all tensor in different kinds of container types.""" - - def apply( - x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence] - ) -> Any: - if torch.is_tensor(x): - return fn(x) - elif hasattr(x, "__dataclass_fields__"): - dc = dataclasses.replace(x) - for f in dataclasses.fields(dc): - name = f.name - setattr(dc, name, apply(getattr(dc, name))) - return dc - elif isinstance(x, OrderedDict): - od = x.__class__() - for key, value in x.items(): - od[key] = apply(value) - return od - elif isinstance(x, PackedSequence): - apply(x.data) - return x - elif isinstance(x, dict): - return {key: apply(value) for key, value in x.items()} - elif _is_namedtuple(x): - res = (apply(el) for el in x) - return type(x)(*res) - elif isinstance(x, (list, tuple, set)): - return type(x)(apply(el) for el in x) - else: - return x - - return apply(container) - - -@torch.no_grad() -def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: - """ - Allocate storage for ``tensor`` with the given size. - - Returns: - bool: ``True`` if this method allocated storage and ``False`` if the - storage was already allocated. - """ - already_allocated = tensor._typed_storage()._size() == size.numel() - if not already_allocated: - tensor_storage_size = tensor._typed_storage()._size() - p_assert( - tensor_storage_size == 0, - f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", - ) - tensor._typed_storage()._resize_(size.numel()) - return not already_allocated - - -@torch.no_grad() -def _free_storage(tensor: torch.Tensor) -> bool: - """ - Frees the underlying storage of ``tensor``. - - Returns: - bool: ``True`` if the method freed the storage and ``False`` if the - storage was already freed. 
- """ - already_freed = tensor._typed_storage()._size() == 0 - if not already_freed: - p_assert( - tensor.storage_offset() == 0, - "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" - f"storage offset: {tensor.storage_offset()}\n" - f"storage size: {tensor._typed_storage()._size()}\n" - f"tensor shape: {tensor.shape}", - ) - tensor._typed_storage()._resize_(0) - return not already_freed - - def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool: """Returns if ``x`` and ``y`` share the same storage.""" # NOTE: CPU and GPU tensors are ensured to have different data pointers. return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr() -def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: - """This is used as an alternate to ``assert`` when in the backward context - to print the error message ``s`` since otherwise, it is swallowed.""" - if not cond: - print(s) - traceback.print_stack() - if raise_assertion_error: - raise AssertionError(s) - - def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None: with no_dispatch(): tensor.record_stream(cast(torch._C.Stream, stream)) diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 3cb4efd7a7fe3..1bfc2090a7cfc 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -27,15 +27,10 @@ _set_fsdp_flattened, HandleTrainingState, ) +from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform -from ._utils import ( - _alloc_storage, - _free_storage, - _no_dispatch_record_stream, - _same_storage, - p_assert, -) +from ._utils import _no_dispatch_record_stream, _same_storage __all__ = [ "FlatParameter", @@ -558,7 +553,7 @@ def shard(self): if not self.uses_sharded_strategy: self._init_shard_metadata(0, 0, flat_param.numel() - 1) else: - p_assert( + _p_assert( flat_param.storage_offset() == 0, "The `FlatParameter` is not the sole occupant of its storage", ) @@ -600,8 +595,8 @@ def _init_shard_metadata( """ self.flat_param._sharded_size = self.flat_param.size() # type: ignore[attr-defined] sharded_flat_param_numel = self.flat_param.numel() # includes `numel_padded` - p_assert(start >= 0 and start <= end, f"start: {start} end: {end}") - p_assert( + _p_assert(start >= 0 and start <= end, f"start: {start} end: {end}") + _p_assert( numel_padded <= sharded_flat_param_numel, f"numel_padded: {numel_padded} " f"sharded_flat_param_numel: {sharded_flat_param_numel}", @@ -792,7 +787,7 @@ def init_flat_param_attributes(self) -> None: self._orig_param_dtype = flat_param.dtype cpu_device = torch.device("cpu") if self._offload_params: - p_assert( + _p_assert( flat_param.device == cpu_device, f"Expects the `FlatParameter` to be on CPU when parameter CPU " f"offloading is enabled, not {flat_param.device}", @@ -957,7 +952,7 @@ def _get_padded_unsharded_flat_param(self) -> torch.Tensor: # tensor as the all-gather destination to preserve the invariant # that `_full_param_padded` is in the low precision unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] - p_assert( + _p_assert( unsharded_flat_param.dtype != self._fwd_bwd_param_dtype, f"Expects full precision but got {self._fwd_bwd_param_dtype}", ) @@ -974,13 +969,13 @@ def _all_gather_flat_param( ``padded_unsharded_flat_param``, and switches to using the all-gathered tensor. 
""" - p_assert( + _p_assert( hasattr(self, "process_group") and hasattr(self, "world_size"), "Expects a process group and world size to have been set via `shard()`", ) sharded_flat_param = self.flat_param.data expected_numel = sharded_flat_param.numel() * self.world_size - p_assert( + _p_assert( padded_unsharded_flat_param.numel() == expected_numel, f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", ) @@ -1111,7 +1106,7 @@ def prepare_gradient_for_backward(self): clearing any existing sharded gradient in ``.grad`` to enable computing a new unsharded gradient. """ - p_assert( + _p_assert( self._training_state in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", @@ -1123,7 +1118,7 @@ def prepare_gradient_for_backward(self): ): self._check_on_compute_device(self.flat_param) grad_offloaded = flat_param.grad.device != self.device - p_assert( + _p_assert( not grad_offloaded or self._offload_params, f"Expects the sharded gradient to be on {self.device} " f"but got {flat_param.grad.device}", @@ -1142,7 +1137,7 @@ def prepare_gradient_for_backward(self): flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] else: - p_assert( + _p_assert( hasattr(flat_param, "_cpu_grad"), "`_cpu_grad` should be defined if the gradient is on CPU", ) @@ -1162,7 +1157,7 @@ def prepare_gradient_for_backward(self): sharded_grad.data = sharded_grad.to(local_shard_dtype) else: padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] - p_assert( + _p_assert( flat_param.grad.size() == padded_unsharded_size, "Expects `.grad` to be the unsharded gradient in " f"`no_sync()` with size {padded_unsharded_size} " @@ -1203,7 +1198,7 @@ def cast_grad_to_param_dtype_if_needed(flat_param): flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined] cast_grad_to_param_dtype_if_needed(flat_param) else: - p_assert( + _p_assert( not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined] "All sharded parameters that received a gradient in the " @@ -1229,7 +1224,7 @@ def to_cpu(self): Postcondition: Same as the precondition. 
""" self._check_sharded_strategy() - p_assert( + _p_assert( self.flat_param.size() == self.flat_param._unpadded_unsharded_size, f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", ) @@ -1242,7 +1237,7 @@ def to_cpu(self): padded_storage_ptr = ( self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr() ) - p_assert( + _p_assert( unpadded_storage_ptr == padded_storage_ptr, "Expects the unpadded parameter to be a view into the padded parameter", ) @@ -1251,7 +1246,7 @@ def to_cpu(self): try: yield finally: - p_assert( + _p_assert( self.flat_param.size() == self.flat_param._unpadded_unsharded_size, f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", ) @@ -1314,7 +1309,7 @@ def _use_sharded_flat_param(self) -> None: flat_param = self.flat_param if self._offload_params: device = flat_param._local_shard.device # type: ignore[attr-defined] - p_assert( + _p_assert( device == torch.device("cpu"), f"Expects the local shard to be on CPU but got {device}", ) @@ -1357,7 +1352,7 @@ def _get_unflat_views( """ if tensor is None: tensor = flat_param - p_assert( + _p_assert( tensor.numel() == flat_param._unpadded_unsharded_size.numel(), f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got " f"{tensor.numel()} numel", @@ -1416,7 +1411,7 @@ def _use_unsharded_views(self, as_params: bool) -> None: # hook fires (e.g. for reentrant AC) assert self.flat_param._tensors is not None # mypy tensor = self.flat_param._tensors[i] - p_assert( + _p_assert( tensor is not None, "Expects `Tensor` to have been saved in forward", ) @@ -1439,14 +1434,14 @@ def _use_unsharded_views(self, as_params: bool) -> None: ) in enumerate(self.flat_param._shared_param_infos): if hasattr(module, param_name): delattr(module, param_name) - p_assert( + _p_assert( hasattr(prim_module, prim_param_name), f"Module {prim_module_name} is missing parameter {prim_param_name}", ) prim_param: Union[Tensor, nn.Parameter] = getattr( prim_module, prim_param_name ) - p_assert( + _p_assert( not as_params or isinstance(prim_param, nn.Parameter), f"as_params={as_params} type(prim_param)={type(prim_param)}", ) @@ -1485,7 +1480,7 @@ def _use_unsharded_grad_views(self) -> None: for i, (view, (param_name, module, _)) in enumerate( zip(views, self.flat_param._param_infos) ): - p_assert( + _p_assert( hasattr(module, param_name), f"{self.flat_param._fqns[i]} is missing", ) @@ -1511,7 +1506,7 @@ def _use_unsharded_grad_views(self) -> None: prim_module, _, ) in enumerate(self.flat_param._shared_param_infos): - p_assert( + _p_assert( hasattr(module, param_name), f"{module_name + '.' + param_name if module_name else param_name} is missing", ) # did not save FQN info in `_shared_param_infos` @@ -1793,7 +1788,7 @@ def _writeback_tensor( RuntimeError: If the ``src_tensor`` does not have the expected shape. 
""" - p_assert( + _p_assert( len(expected_shape) == 1, f"Expects a 1D expected shape but got {expected_shape}", ) @@ -1935,7 +1930,7 @@ def sharded_grad(self) -> Optional[Tensor]: else: # If in the forward, then there may be an accumulated gradient, # which will be in `.grad` - p_assert( + _p_assert( flat_param.grad is None or not self.uses_sharded_strategy or self._training_state == HandleTrainingState.FORWARD, @@ -1954,7 +1949,7 @@ def _reset_is_grad_none(self) -> None: """ if not self._use_orig_params: return - p_assert( + _p_assert( self._training_state == HandleTrainingState.BACKWARD_POST, "Expects to only be called in the post-backward after gradient computation", ) @@ -1971,16 +1966,16 @@ def _reset_is_grad_none(self) -> None: # CHECKS & INVARIANTS # ####################### def _check_sharded_strategy(self): - p_assert(self.uses_sharded_strategy, "Expects sharded strategy") + _p_assert(self.uses_sharded_strategy, "Expects sharded strategy") def _check_on_compute_device(self, tensor: Tensor): - p_assert( + _p_assert( tensor.device == self.device, f"Expects tensor to be on the compute device {self.device}", ) def _check_on_cpu(self, tensor: Tensor): - p_assert( + _p_assert( tensor.device == torch.device("cpu"), f"Expects tensor to be on CPU but got {tensor.device}", ) @@ -1988,7 +1983,7 @@ def _check_on_cpu(self, tensor: Tensor): @staticmethod def _check_storage_freed(tensor: Tensor): storage_size: int = tensor._typed_storage()._size() - p_assert( + _p_assert( storage_size == 0, f"Expects storage to be freed but got storage with size {storage_size}", ) @@ -1996,37 +1991,37 @@ def _check_storage_freed(tensor: Tensor): @staticmethod def _check_storage_allocated(tensor: Tensor): storage_size: int = tensor._typed_storage()._size() - p_assert(storage_size > 0, "Expects storage to be allocated") + _p_assert(storage_size > 0, "Expects storage to be allocated") def _check_low_precision_shard(self): - p_assert( + _p_assert( self._uses_param_mixed_precision, "Not using low precision for parameters", ) - p_assert( + _p_assert( getattr(self.flat_param, "_mp_shard", None) is not None, "Expects `_mp_shard` to exist", ) device = self.flat_param._mp_shard.device # type: ignore[attr-defined] - p_assert( + _p_assert( device == self.device, f"Expects the low precision shard to be on {self.device} but got {device}", ) def _check_unsharded(self, tensor: Tensor): msg_prefix = "Expects tensor to be unsharded " - p_assert(tensor is not None, msg_prefix + "but got `None`") + _p_assert(tensor is not None, msg_prefix + "but got `None`") unsharded_size = self.flat_param._unpadded_unsharded_size - p_assert( + _p_assert( tensor.size() == unsharded_size, msg_prefix + f"with size {unsharded_size} but got {tensor.size()}", ) def _check_sharded(self, tensor: Tensor): msg_prefix = "Expects tensor to be sharded " - p_assert(tensor is not None, msg_prefix + "but got `None`") + _p_assert(tensor is not None, msg_prefix + "but got `None`") sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] - p_assert( + _p_assert( tensor.size() == sharded_size, msg_prefix + f"with size {sharded_size} but got {tensor.size()}", ) diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index a2c95b21d2247..68d515f111248 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -77,6 +77,7 @@ StateDictSettings, StateDictType, ) +from torch.distributed.utils import _p_assert from 
._optim_utils import (
     _broadcast_pos_dim_tensor_states,
@@ -98,7 +99,6 @@
     _unshard_params,
     _unshard_params_recurse,
 )
-from ._utils import p_assert
 from .flat_param import FlatParameter
 from .wrap import _FSDPPolicy
@@ -740,7 +740,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
         )
         for handle in self._handles:
-            p_assert(
+            _p_assert(
                 handle.flat_param.device == self.compute_device,
                 "Expected `FlatParameter` to be on the compute device "
                 f"{self.compute_device} but got {handle.flat_param.device}",
@@ -830,7 +830,7 @@ def _deregister_orig_params_ctx(self):
         this refreshes the sharded views before exiting. This method should
         only be called when using the original parameters.
         """
-        p_assert(
+        _p_assert(
             self._use_orig_params,
             "`_deregister_orig_params_ctx()` should only be called when "
             "`_use_orig_params=True`",
diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py
index f827de143bf6c..5848c0ecab0ef 100644
--- a/torch/distributed/utils.py
+++ b/torch/distributed/utils.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable, Union, Set, OrderedDict
+import dataclasses
+import traceback
 
 import torch
 import torch.distributed as dist
@@ -94,6 +96,92 @@ def to_map(obj):
         to_map = None  # type: ignore[assignment]
     return res
 
+def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
+    """This is used as an alternate to ``assert`` when in the backward context
+    to print the error message ``s`` since otherwise, it is swallowed."""
+    if not cond:
+        print(s)
+        traceback.print_stack()
+        if raise_assertion_error:
+            raise AssertionError(s)
+
+def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
+    """
+    Allocate storage for ``tensor`` with the given size.
+
+    Returns:
+        bool: ``True`` if this method allocated storage and ``False`` if the
+        storage was already allocated.
+    """
+    with torch.no_grad():
+        already_allocated = tensor._typed_storage()._size() == size.numel()
+        if not already_allocated:
+            tensor_storage_size = tensor._typed_storage()._size()
+            _p_assert(
+                tensor_storage_size == 0,
+                f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
+            )
+            tensor._typed_storage()._resize_(size.numel())
+        return not already_allocated
+
+
+def _free_storage(tensor: torch.Tensor) -> bool:
+    """
+    Frees the underlying storage of ``tensor``.
+
+    Returns:
+        bool: ``True`` if the method freed the storage and ``False`` if the
+        storage was already freed.
+ """ + with torch.no_grad(): + already_freed = tensor._typed_storage()._size() == 0 + if not already_freed: + _p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" + f"storage offset: {tensor.storage_offset()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" + f"tensor shape: {tensor.shape}", + ) + tensor._typed_storage()._resize_(0) + return not already_freed + +def _apply_to_tensors( + fn: Callable, + container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence], +) -> Any: + """Recursively apply to all tensor in different kinds of container types.""" + + def apply( + x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence] + ) -> Any: + if torch.is_tensor(x): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + for f in dataclasses.fields(dc): + name = f.name + setattr(dc, name, apply(getattr(dc, name))) + return dc + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies): inputs = (