diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py
index e4269cebd3..5e59849d45 100644
--- a/composer/trainer/_patch_pytorch.py
+++ b/composer/trainer/_patch_pytorch.py
@@ -121,7 +121,6 @@ def patch_pytorch():
 
         _MeshEnv.create_child_mesh = create_child_mesh
         DeviceMesh.__getitem__ = device_mesh__getitem__
-        DeviceMesh.__init__ = device_mesh__init__
 
 
 def build_metadata(
@@ -300,7 +299,7 @@ def _shard_orig_param_state(
 
 if version.parse(torch.__version__) >= version.parse('2.3.0') and version.parse(
     torch.__version__,
-) < version.parse('2.3.1'):
+) < version.parse('2.3.2'):
 
     from torch.distributed._tensor import DTensor
 
     @no_type_check
@@ -785,148 +784,259 @@ def create_global_plan(
 
         return self.global_plan, self.metadata
 
-    from torch.utils._typing_utils import not_none
-    from torch.distributed.device_mesh import DeviceMesh
-
-    def create_child_mesh(
-        self,
-        device_mesh,
-        mesh_dim_names: tuple[str],
-    ):
-        """Monkeypatch create_child_mesh to nightly version."""
-        # swap the current dim to the last dim then reshape to flatten out other
-        # dims, so we can just extract the list of ranks which contains cur_rank.
-        mesh_dims = [
-            not_none(device_mesh.mesh_dim_names).index(mesh_dim_name)
-            for mesh_dim_name in mesh_dim_names
-        ]
-        cur_rank = device_mesh.get_rank()
-        mesh = device_mesh.mesh
-        all_mesh_dims = list(range(mesh.ndim))
-        for mesh_dim in mesh_dims:
-            # remove not pop b/c we want the value of the ind removed not it's position in the list
-            # because this list dynamically changes.
-            all_mesh_dims.remove(mesh_dim)
-
-        mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims]
-
-        pg_ranks_by_dim = device_mesh.mesh.permute(
-            *all_mesh_dims, *mesh_dims,
-        ).reshape(-1, *mesh_sizes)
-
-        for mesh_nd in pg_ranks_by_dim:
-            if cur_rank in mesh_nd:
-                sub_mesh = DeviceMesh(
-                    device_mesh.device_type,
-                    mesh_nd,
-                    mesh_dim_names=mesh_dim_names,
-                )
-                res_sub_mesh = sub_mesh
-
-        res_sub_mesh._dim_group_infos = [  # type: ignore
-            device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
-        ]
-
-        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
-        self.child_to_parent_mapping[res_sub_mesh] = device_mesh  # type: ignore
-        return res_sub_mesh  # type: ignore
-
-    from torch.distributed.device_mesh import _mesh_resources
-
-    def device_mesh__init__(
-        self,
-        device_type: str,
-        mesh,
-        *,
-        mesh_dim_names: Optional[tuple[str, ...]] = None,
-    ) -> None:
-        """Monkeypatch device mesh __init__ to nightly version."""
-        self.device_type = device_type
-        if isinstance(mesh, torch.Tensor) and mesh.device.type != 'cpu':
-            raise ValueError(f'`mesh` must be a CPU tensor, got {mesh}')
-        self.mesh = (
-            mesh.detach().cpu()
-            if isinstance(mesh, torch.Tensor)
-            else torch.tensor(mesh, dtype=torch.int)
-        )
-        self.mesh_dim_names = mesh_dim_names
-
-        # private field to pre-generate DeviceMesh's hash
-        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
-        self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
-        self._parent_mesh = _mesh_resources.get_parent_mesh(self)
+    # The DeviceMesh monkeypatch changes slightly in PyTorch 2.3.1
+    if version.parse(torch.__version__) < version.parse('2.3.1'):
+        from torch.utils._typing_utils import not_none
+        from torch.distributed.device_mesh import DeviceMesh
+
+        def create_child_mesh(
+            self,
+            device_mesh,
+            mesh_dim_names: tuple[str],
+        ):
+            """Monkeypatch create_child_mesh to nightly version."""
+            # swap the current dim to the last dim then reshape to flatten out other
+            # dims, so we can just extract the list of ranks which contains cur_rank.
+            mesh_dims = [
+                not_none(device_mesh.mesh_dim_names).index(mesh_dim_name)
+                for mesh_dim_name in mesh_dim_names
+            ]
+            cur_rank = device_mesh.get_rank()
+            mesh = device_mesh.mesh
+            all_mesh_dims = list(range(mesh.ndim))
+            for mesh_dim in mesh_dims:
+                # remove, not pop, b/c we want the value of the index removed, not its position in
+                # the list, because this list dynamically changes.
+                all_mesh_dims.remove(mesh_dim)
+
+            mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims]
+
+            pg_ranks_by_dim = device_mesh.mesh.permute(
+                *all_mesh_dims, *mesh_dims,
+            ).reshape(-1, *mesh_sizes)
+
+            for mesh_nd in pg_ranks_by_dim:
+                if cur_rank in mesh_nd:
+                    sub_mesh = DeviceMesh(
+                        device_mesh.device_type,
+                        mesh_nd,
+                        mesh_dim_names=mesh_dim_names,
+                    )
+                    res_sub_mesh = sub_mesh
+
+            res_sub_mesh._dim_group_infos = [  # type: ignore
+                device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
+            ]
-
-        # Skip process group initialization if xla device.
-        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
-        if device_type != 'xla':
-            # always try to create default (world) pg, even if it is not initialized
-            # already. The world pg is used for device mesh identity (rank) on each
-            # process (we need to know if the current global rank is in the mesh or not).
-            self._get_or_create_default_group()
-            if not self._parent_mesh:
-                self._init_process_groups()
+
+            # Assign the current DeviceMesh as the parent of the child DeviceMesh.
+            self.child_to_parent_mapping[res_sub_mesh] = device_mesh  # type: ignore
+            return res_sub_mesh  # type: ignore
+
+        from torch.distributed.device_mesh import _mesh_resources
+
+        def device_mesh__init__(
+            self,
+            device_type: str,
+            mesh,
+            *,
+            mesh_dim_names: Optional[tuple[str, ...]] = None,
+        ) -> None:
+            """Monkeypatch device mesh __init__ to nightly version."""
+            self.device_type = device_type
+            if isinstance(mesh, torch.Tensor) and mesh.device.type != 'cpu':
+                raise ValueError(f'`mesh` must be a CPU tensor, got {mesh}')
+            self.mesh = (
+                mesh.detach().cpu()
+                if isinstance(mesh, torch.Tensor)
+                else torch.tensor(mesh, dtype=torch.int)
+            )
+            self.mesh_dim_names = mesh_dim_names
+
+            # private field to pre-generate DeviceMesh's hash
+            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+            self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
+            self._parent_mesh = _mesh_resources.get_parent_mesh(self)
+
+            # Skip process group initialization if xla device.
+            # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
+            if device_type != 'xla':
+                # always try to create default (world) pg, even if it is not initialized
+                # already. The world pg is used for device mesh identity (rank) on each
+                # process (we need to know if the current global rank is in the mesh or not).
+                self._get_or_create_default_group()
+                if not self._parent_mesh:
+                    self._init_process_groups()
+
+        def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'DeviceMesh':
+            """Monkeypatch device_mesh __getitem__ to nightly version.
+
+            Slice the current DeviceMesh based on the mesh_dim_names given to create a child
+            DeviceMesh.
+
+            Args:
+                mesh_dim_names (Union[str, tuple[str]]): the name, or tuple of names, of the mesh
+                    dimension(s) of the parent DeviceMesh to create a child DeviceMesh for.
+
+            Returns:
+                A :class:`DeviceMesh` object
+
+            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
+            hosts with 4 GPUs each.
+            Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
+            Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
+            Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
+            Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
+            Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
+            Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
+
+            Example::
+                >>> # xdoctest: +SKIP("no rank")
+                >>> from torch.distributed.device_mesh import DeviceMesh
+                >>>
+                >>> # Initialize device mesh as (2, 4) to represent the topology
+                >>> # of cross-host(dim 0), and within-host (dim 1).
+                >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
+            """
+            if not self.mesh_dim_names:
+                raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names.')
+
+            mesh_dim_names = (
+                (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
+            )
-    def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'DeviceMesh':
-        """Monkeypatch device_mesh __getitem__ to nightly version.
+            error_msg = (
+                f'Invalid mesh_dim_name {mesh_dim_names} specified. '
+                f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.'
+            )
-
-        Slice the current DeviceMesh based on the mesh_dim_name given to create a child
-        DeviceMesh.
+            # When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh,
+            # we simply return self if the given slicing is valid.
+            if mesh_dim_names == self.mesh_dim_names:
+                return self
+            # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
+            # of the current DeviceMesh.
+            elif len(mesh_dim_names) < len(self.mesh_dim_names):
+                outermost_dim_name = mesh_dim_names[0]
+                if outermost_dim_name not in self.mesh_dim_names:
+                    raise ValueError(error_msg)
+                outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
+                for i, j in zip(
+                    mesh_dim_names,
+                    self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
+                ):
+                    if i != j:
+                        raise ValueError(error_msg)
+            else:
+                raise ValueError(error_msg)
-
-        Args:
-            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
-                to create a child DeviceMesh for.
+            submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
+            return submesh
+
+    else:
+        from torch.utils._typing_utils import not_none
+        from torch.distributed.device_mesh import DeviceMesh, _mesh_resources
+
+        def create_child_mesh(
+            self, parent_mesh: 'DeviceMesh', submesh_dim_names: tuple[str, ...],
+        ) -> 'DeviceMesh':
+            """Monkeypatch create_child_mesh to nightly version."""
+            # submesh_dims are the mesh dimensions of the submesh in the parent mesh.
+            submesh_dims = [
+                not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name)
+                for mesh_dim_name in submesh_dim_names
+            ]
+            submesh_dim_sizes = [
+                parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims
+            ]
-
-        Returns:
-            A :class:`DeviceMesh` object
-
-        The following program runs on each process/rank in an SPMD manner. In this example, we have 2
-        hosts with 4 GPUs each.
-        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
-        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
-        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
-        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
-        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
-        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
-
-        Example::
-            >>> # xdoctest: +SKIP("no rank")
-            >>> from torch.distributed.device_mesh import DeviceMesh
-            >>>
-            >>> # Initialize device mesh as (2, 4) to represent the topology
-            >>> # of cross-host(dim 0), and within-host (dim 1).
-            >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
-        """
-        if not self.mesh_dim_names:
-            raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names.')
+            mesh_dims_remained = list(range(parent_mesh.mesh.ndim))
+            for submesh_dim in submesh_dims:
+                mesh_dims_remained.remove(submesh_dim)
+
+            # pg_ranks_by_dim has the size [number of local ranks of the outermost submesh dimension, *submesh_dim_sizes].
+            # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
+            # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor that contains the current rank.
+            pg_ranks_by_dim = parent_mesh.mesh.permute(
+                *mesh_dims_remained, *submesh_dims,
+            ).reshape(-1, *submesh_dim_sizes)
+
+            cur_rank = parent_mesh.get_rank()
+            for mesh_nd in pg_ranks_by_dim:
+                submesh = DeviceMesh(
+                    parent_mesh.device_type,
+                    mesh_nd,
+                    mesh_dim_names=submesh_dim_names,
+                    _init_backend=False,
+                )
+                if cur_rank in mesh_nd:
+                    res_submesh = submesh
-
-        mesh_dim_names = (
-            (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
-        )
+            res_submesh._parent_mesh = parent_mesh  # type: ignore
+            res_submesh._dim_group_infos = [  # type: ignore
+                parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims  # type: ignore
+            ]
+            self.child_to_parent_mapping[res_submesh] = parent_mesh  # type: ignore
+
+            return res_submesh  # type: ignore
+
+        def device_mesh__getitem__(
+            self, mesh_dim_names: Union[str, tuple[str, ...]],
+        ) -> 'DeviceMesh':
+            """Monkeypatch device_mesh __getitem__ to nightly version.
+
+            Slice the current DeviceMesh based on the mesh_dim_names given to create a child
+            DeviceMesh.
+
+            Args:
+                mesh_dim_names (Union[str, tuple[str, ...]]): the name or the tuple of names of the
+                    mesh dimension(s) of the parent DeviceMesh to create the child DeviceMesh for.
+
+            Returns:
+                A :class:`DeviceMesh` object
+
+            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
+            hosts with 4 GPUs each.
+            Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
+            Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
+            Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
+            Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
+            Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
+            Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
+
+            Example::
+                >>> # xdoctest: +SKIP("no rank")
+                >>> from torch.distributed.device_mesh import DeviceMesh
+                >>>
+                >>> # Initialize device mesh as (2, 4) to represent the topology
+                >>> # of cross-host(dim 0), and within-host (dim 1).
+                >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
+            """
+            if not self.mesh_dim_names:
+                raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names!')
+
+            mesh_dim_names = (
+                (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
+            )
-        error_msg = (
+            error_msg = (
+                f'Invalid mesh_dim_name {mesh_dim_names} specified. '
+                f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.'
+            )
-            f'Invalid mesh_dim_name {mesh_dim_names} specified. '
-            f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.'
-        )
-
-        # When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh,
-        # we simply return self if the given slicing is valid.
-        if mesh_dim_names == self.mesh_dim_names:
-            return self
-        # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
-        # of the current DeviceMesh.
-        elif len(mesh_dim_names) < len(self.mesh_dim_names):
-            outermost_dim_name = mesh_dim_names[0]
-            if outermost_dim_name not in self.mesh_dim_names:
-                raise ValueError(error_msg)
-            outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
-            for i, j in zip(
-                mesh_dim_names,
-                self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
+            if mesh_dim_names == self.mesh_dim_names:
+                return self
+            elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all(
+                mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names
-            ):
-                if i != j:
-                    raise ValueError(error_msg)
-        else:
-            raise ValueError(error_msg)
-
-        submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
-        return submesh
+            ):
+                raise KeyError(error_msg)
+            # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
+            # of the current DeviceMesh.
+            else:
+                outermost_dim_name = mesh_dim_names[0]
+                outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
+                for i, j in zip(
+                    mesh_dim_names,
+                    self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
+                ):
+                    if i != j:
+                        raise KeyError(error_msg)
+
+            submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
+            return submesh
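
A minimal standalone sketch of the permute/reshape trick that both versions of create_child_mesh above rely on, run here on a plain CPU tensor with no process groups. The (2, 4) mesh, the ('dp', 'tp') dim names, and the slice of 'tp' are illustrative assumptions, not part of the patch:

import torch

# A (2, 4) parent mesh: 2 hosts ('dp') x 4 GPUs per host ('tp').
mesh = torch.arange(8).reshape(2, 4)
mesh_dim_names = ('dp', 'tp')
submesh_dim_names = ('tp',)

# Indices and sizes of the sliced dims, plus the dims left over.
submesh_dims = [mesh_dim_names.index(name) for name in submesh_dim_names]  # [1]
submesh_dim_sizes = [mesh.size(dim) for dim in submesh_dims]               # [4]
remaining_dims = [d for d in range(mesh.ndim) if d not in submesh_dims]    # [0]

# Move the sliced dims last, then flatten everything else: each row is the
# rank set of one candidate child mesh.
pg_ranks_by_dim = mesh.permute(*remaining_dims, *submesh_dims).reshape(-1, *submesh_dim_sizes)
print(pg_ranks_by_dim)  # tensor([[0, 1, 2, 3], [4, 5, 6, 7]])

cur_rank = 5  # pretend this process is global rank 5
for mesh_nd in pg_ranks_by_dim:
    if cur_rank in mesh_nd:
        print(mesh_nd)  # tensor([4, 5, 6, 7]): rank 5's 'tp' submesh

Slicing 'dp' instead yields the columns ([0, 4], [1, 5], [2, 6], [3, 7]), matching the docstring examples above.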
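
The validation in both versions of the patched __getitem__ reduces to a single rule: the requested names must form a contiguous, in-order subsequence of the parent's mesh_dim_names. A pure-Python restatement of that rule, with the hypothetical helper name is_valid_slice used only for illustration:

def is_valid_slice(parent_dims: tuple, sliced: tuple) -> bool:
    # Slicing the full set of dim names is always valid.
    if sliced == parent_dims:
        return True
    # Reject unknown names or over-long requests outright.
    if len(sliced) > len(parent_dims) or not all(name in parent_dims for name in sliced):
        return False
    # Otherwise the names must appear contiguously, in order, in the parent.
    start = parent_dims.index(sliced[0])
    return parent_dims[start:start + len(sliced)] == sliced

assert is_valid_slice(('pp', 'dp', 'tp'), ('dp', 'tp'))      # contiguous: accepted
assert not is_valid_slice(('pp', 'dp', 'tp'), ('pp', 'tp'))  # gap: rejected

The only behavioral difference between the two branches on failure is the exception type (ValueError before 2.3.1, KeyError from 2.3.1 on); the sketch collapses both into a boolean.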
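
Finally, a hypothetical end-to-end use of the slicing these monkeypatches keep working across torch 2.3.0 and 2.3.1. It assumes a job launched with torchrun across 8 GPU ranks; the (2, 2, 2) shape and the dim names are made up for illustration:

from torch.distributed.device_mesh import init_device_mesh

# Requires an initialized distributed job, e.g. torchrun --nproc_per_node=8.
mesh = init_device_mesh('cuda', (2, 2, 2), mesh_dim_names=('pp', 'dp', 'tp'))

tp_mesh = mesh['tp']           # 1D child mesh, as in the docstring examples above
dp_tp_mesh = mesh['dp', 'tp']  # 2D child mesh: the ND slicing the patched __getitem__ supports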