From bda5babf3bb477cfef824d76fcff49edda32a2c6 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 22 Nov 2023 00:28:29 +0000 Subject: [PATCH 01/66] fixes to get dtensor to work --- composer/trainer/mosaic_fsdp.py | 2 -- setup.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index fd786efe6f..6b94144ec0 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -50,5 +50,3 @@ def patch_pytorch(): # Monkey patch partial state dict handling _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook) - elif version.parse(torch.__version__) >= version.parse('2.1.1'): - raise NotImplementedError(f'FullyShardedDataParallel is not supported for torch >= 2.2.0') diff --git a/setup.py b/setup.py index f771049b47..bd3df62ccd 100644 --- a/setup.py +++ b/setup.py @@ -77,8 +77,8 @@ def package_files(prefix: str, directory: str, extension: str): 'tqdm>=4.62.3,<5', 'torchmetrics>=0.10.0,<1.1', 'torch_optimizer>=0.3.0,<0.4', - 'torchvision>=0.13.1,<0.17', - 'torch>=1.13.1,<2.1.1', + 'torchvision>0.16', + 'torch>=2.0.0', 'requests>=2.26.0,<3', 'numpy>=1.21.5,<1.27.0', 'psutil>=5.8.0,<6', From 60d02b4cf6de181ead64630538f134f097d76c41 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 15 Dec 2023 03:30:07 +0000 Subject: [PATCH 02/66] more fixes --- composer/trainer/dist_strategy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 4c7e0d6765..8c2d34423c 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -264,6 +264,13 @@ def sync_hook(*args): # `nn.Module.named_parameters`. # Setting it to `True` is mandatory when using `torch.compile()`. kwargs['use_orig_params'] = fsdp_config['use_orig_params'] + print(version.parse(torch.__version__)) + if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'): + from torch.distributed._tensor import init_device_mesh + kwargs['device_mesh'] = init_device_mesh( + 'cuda', + (dist.get_world_size(),), + ) # necessary variables for optimizers with multiple param groups in FSDP num_param_groups = None @@ -455,6 +462,7 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: if ret and auto_microbatching: module.register_forward_hook(sync_hook) module.register_full_backward_hook(sync_hook) + print(module, '\n', ret) return ret _auto_wrap_policy = CustomPolicy(lambda_fn) From 8897c5201bdc62d2984485974d566517ed2a3774 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 20 Dec 2023 01:35:06 +0000 Subject: [PATCH 03/66] Change state dict materialization for new version of torch --- composer/core/state.py | 77 ++++++++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index dbdba40170..2bbf746929 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -864,18 +864,62 @@ def get_model_state_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The state dict for the model. 
""" - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - model_state_dict = self.model.state_dict() + if version.parse(torch.__version__) > version.parse("2.1.2"): + model_state_dict, _ = self.get_model_and_optimizer_state_dict(model_only=True) else: - model_state_dict = self.model.state_dict() + if self.fsdp_enabled and self.fsdp_state_dict_type is not None: + with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): + model_state_dict = self.model.state_dict() + else: + model_state_dict = self.model.state_dict() - # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel - # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail - if self.is_model_ddp: - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') + # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel + # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail + if self.is_model_ddp: + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') return model_state_dict + def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: + optimizer = ensure_tuple(self.optimizers)[0] # Let's stop pretending. We don't support more than one optimizer. + if self.fsdp_enabled and self.fsdp_state_dict_type is not None: + optim_state_dict = { + type(optimizer).__qualname__: + fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type) + } + else: + optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} + return optim_state_dict + + def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if version.parse(torch.__version__) > version.parse("2.1.2"): + from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions + full_state_dict = True if self.fsdp_state_dict_type == 'full' else False + if self.fsdp_state_dict_type not in ['full', 'sharded']: + raise NotImplementedError( + textwrap.dedent( + f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for a torch version > 2.1.2." + f"You are using {version.parse(torch.__version__)}" + ) + ) + + optimizer = ensure_tuple(self.optimizers)[0] + model_state_dict, optim_state_dict = get_state_dict(model=self.model, + optimizers=([] if model_only else optimizer), + submodules=None, + options=StateDictOptions( + full_state_dict=full_state_dict, + cpu_offload=True, + ignore_frozen_params=True, + keep_submodule_prefixes=True, + strict=True, + ) + ) + else: + model_state_dict = self.get_model_state_dict() + optim_state_dict = self._legacy_get_optim_state_dict() + + return model_state_dict, optim_state_dict + def state_dict(self) -> Dict[str, Any]: """Collect the state dicts of our serializable attributes. @@ -883,24 +927,13 @@ def state_dict(self) -> Dict[str, Any]: Dict[str, Any]: The state dict. 
""" state_dict = {} - + state_dict['model'], state_dict['optimizers'] = self.get_model_and_optimizer_state_dict() for attribute_name in self.serialized_attributes: attribute_value = getattr(self, attribute_name) + if attribute_name in ['model', 'optimizers']: + continue if attribute_name == 'dataset_state': serialized_value = self._dataset_state_dict() - elif attribute_name == 'model': - serialized_value = self.get_model_state_dict() - elif attribute_name == 'optimizers': - optimizer = ensure_tuple(attribute_value)[ - 0] # Let's stop pretending. We don't support more than one optimizer. - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - optim_state_dict = { - type(optimizer).__qualname__: - fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type) - } - else: - optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} - serialized_value = optim_state_dict elif attribute_name == 'algorithms': # Store as list to preserve order in which algorithms were applied serialized_value = [(type(obj).__qualname__, obj.state_dict()) for obj in ensure_tuple(attribute_value)] From a88b43048350bb64111bb1e28eade48d02a0b05f Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 21 Dec 2023 01:09:47 +0000 Subject: [PATCH 04/66] get load working for new set_state_dict api --- composer/core/state.py | 35 ++++++++++++++++++++++++++++------- composer/utils/checkpoint.py | 30 +++++++++++++++++++----------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 2bbf746929..ec5fcf79b0 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -1238,6 +1238,31 @@ def _load_dataset_state(self, obj: Dict[str, Any]) -> None: # starts. This avoids "CUDA error: initialization error" -- its not clear why. # self.dataset_resumption['eval'][evaluator.label] = True + def load_model_and_optimizer_state(self, + state_dict: Dict[str, Any], + logger: Logger, + strict: bool, + exclude_algorithms: Optional[List[str]] = None, + algorithm_passes: Optional[List[AlgorithmPass]] = None, + load_model_only: bool = False): + # Note: In this case required algorithms not applied. 
+ if version.parse(torch.__version__) > version.parse("2.1.2"): + from torch.distributed.checkpoint.state_dict import set_state_dict, StateDictOptions + optimizer = ensure_tuple(self.optimizers)[0] + set_state_dict(self.model, + optimizers=([] if load_model_only else optimizer), + model_state_dict=state_dict['model'], + optim_state_dict=({} if load_model_only else state_dict['optimizers']), + options=StateDictOptions(strict=False)) + else: + self.load_model_state(state_dict, + logger, + strict=strict, + exclude_algorithms=exclude_algorithms, + algorithm_passes=algorithm_passes,) + if not load_model_only: + self.load_optim_state(state_dict) + def load_state_dict( self, state: Dict[str, Any], @@ -1261,28 +1286,26 @@ def load_state_dict( # Call load_model_state first since it applies required algorithms if 'model' in state: - self.load_model_state( + self.load_model_and_optimizer_state( state, logger, strict=strict, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, + load_model_only=('optimizers' in state) ) for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order serialized_value = state[attribute_name] # Skip removed attributes as well as algorithms and model, which was already loaded - if attribute_name not in self.serialized_attributes or attribute_name == 'model': + if attribute_name not in self.serialized_attributes or attribute_name in ['model', 'optimizers']: continue - # Integrations are extra information about other libraries (e.g. huggingface) and not attributes to be loaded here if attribute_name == 'integrations': continue - # Skip metadata, which is not an attribute on State if attribute_name == 'metadata': continue - log.debug(f'Loading {attribute_name} into state.') # Restructure algorithms serialized_value from list to dict @@ -1291,8 +1314,6 @@ def load_state_dict( if attribute_name == 'dataset_state': self._load_dataset_state(serialized_value) - elif attribute_name == 'optimizers': - self.load_optim_state(state) elif attribute_name == 'train_metrics': # Get current metrics object and populate each metric present # in serialization with serialized data via load_state_dict() diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index a35298303d..e010af0bd9 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -104,11 +104,15 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo def is_checkpoint_legacy_sharded(object_store: Optional[ObjectStore], source_path: str): metadata_path = str(Path(source_path) / Path('.metadata')) if object_store is None: + if not os.path.exists(source_path): + raise FileNotFoundError(f"Couldn't find the directory {source_path}") return not os.path.exists(metadata_path) else: try: with tempfile.TemporaryDirectory() as temp_dir: metadata_destination = os.path.join(str(temp_dir), '.metadata') + if len(object_store.list_objects(prefix=source_path)) == 0: + raise FileNotFoundError(f"Couldn't find the prefix {object_store.get_uri(object_name=source_path)}") object_store.download_object(object_name=metadata_path, filename=metadata_destination) return False except FileNotFoundError: @@ -398,28 +402,30 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): else: storage_reader = dist_cp.FileSystemReader(source_path) + # We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph. with torch.no_grad(): # 1. 
Load model and metadata first - model_state_dict = None if load_weights_only: - model_state_dict = {'state': {'model': state.get_model_state_dict()}} + state_dict = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() - cur_state_dict.pop('optimizers') - model_state_dict = {'state': cur_state_dict} + # For older versions of torch, we load optimizier separately. + if version.parse(torch.__version__) <= version.parse("2.1.2"): + cur_state_dict.pop('optimizers') + state_dict = {'state': cur_state_dict} if ignore_keys: # Filter provided list of key paths if not callable(ignore_keys): ignore_keys = glob_filter(ignore_keys) # Call function to modify state_dict - ignore_keys(model_state_dict) + ignore_keys(state_dict) - dist_cp.load_state_dict(model_state_dict, storage_reader) + dist_cp.load_state_dict(state_dict, storage_reader) state.load_state_dict( - model_state_dict['state'], + state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, @@ -427,7 +433,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): ) # 2. Optionally load optimizer - if not load_weights_only: + # if we are using later than 2.1.0 then optimizer will already be loaded + if version.parse(torch.__version__) <= version.parse("2.1.2") and not load_weights_only: optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader) @@ -722,12 +729,13 @@ def _restore_checkpoint( if load_path is None: raise RuntimeError(f'Failed to load DeepSpeed checkpoint') elif load_weights_only: - state.load_model_state( + state.load_model_and_optimizer_state( state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, + load_model_only=True ) if not load_weights_only: state.load_state_dict( @@ -767,12 +775,12 @@ def save_checkpoint( # Sharded checkpoints get their own little folder. if state.fsdp_sharded_state_dict_enabled: - # To load optimizer states with torch 2.0, the optimizer state must be at the top + # To load optimizer states with 2.0 <= torch <= 2.1.2 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. 
- if using_torch_2(): + if using_torch_2() and version.parse(torch.__version__) <= version.parse("2.1.2"): if not weights_only: state_dict['optimizers'] = state_dict['state'].pop('optimizers') From 888bec00652b85899e4fda0a110868a6c21e3e06 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 27 Dec 2023 23:50:12 +0000 Subject: [PATCH 05/66] use device_mesh --- composer/trainer/dist_strategy.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 985783f96f..187cec485a 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -271,11 +271,12 @@ def sync_hook(*args): kwargs['use_orig_params'] = fsdp_config['use_orig_params'] print(version.parse(torch.__version__)) if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'): - from torch.distributed._tensor import init_device_mesh - kwargs['device_mesh'] = init_device_mesh( - 'cuda', - (dist.get_world_size(),), - ) + if 'device_mesh' in fsdp_config: + from torch.distributed._tensor import init_device_mesh + kwargs['device_mesh'] = init_device_mesh( + 'cuda', + tuple(fsdp_config['device_mesh']), + ) # necessary variables for optimizers with multiple param groups in FSDP num_param_groups = None From 24ba9ecffa7203a0a5e42ab00bdbfaddb4e1fd48 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Thu, 28 Dec 2023 20:57:07 +0000 Subject: [PATCH 06/66] Add fsdp init monkeypatch for DTensor --- composer/trainer/mosaic_fsdp.py | 5 + composer/trainer/mosaic_fsdp_utils.py | 135 ++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index bf6ebaa228..e74380af74 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -57,6 +57,11 @@ def patch_pytorch(): elif version.parse(torch.__version__) < version.parse('2.2.0'): # Monkey path for torch < 2.2.0 ie torch == 2.1.1, 2.1.2 + # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn + from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0 + + FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore + # Allow 2D HSDP from torch.distributed.fsdp import _runtime_utils _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index da08772a63..0470dcb3b0 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -753,3 +753,138 @@ def _sharded_pre_load_state_dict_hook( state_dict[fqn_from_global_root] = param.to_local() _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + +if version.parse(torch.__version__) >= version.parse('2.1.2') and version.parse( + torch.__version__) < version.parse('2.2.1'): + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + _init_ignored_module_states, + _init_device_handle, + _annotate_modules_for_dynamo, + _init_process_group_state, + _auto_wrap, + _init_core_state, + _init_runtime_state, + _init_prefetching_state, + _init_buffer_state, + _init_extension, + _init_param_handle_from_module, + _check_orig_params_flattened, + _register_flat_param, + _init_state_dict_state, + _register_all_state_dict_hooks, + ) + from torch.distributed.fsdp._init_utils import ( + HYBRID_SHARDING_STRATEGIES, + ProcessGroupType, + ) + from torch.distributed.fsdp.wrap import CustomPolicy, 
ModuleWrapPolicy + from torch.distributed._tensor import DeviceMesh + def init_fn_t2p2p0( + self, + module: nn.Module, + process_group: ProcessGroupType = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[ + Union[Callable, ModuleWrapPolicy, CustomPolicy] + ] = None, + backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, + device_mesh: Optional[DeviceMesh] = None, + ): + torch._C._log_api_usage_once("torch.distributed.fsdp") + super(FullyShardedDataParallel, self).__init__() + _init_ignored_module_states(self, module, ignored_modules, ignored_states) + _init_device_handle(self, module, self._ignored_params, device_id) + + # Add module annotations for Dynamo support (see function for details) + _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params) + + # Initializes self.process_group, along with rank and world size. This will + # also set another attribute, _inter_node_pg, to control the process group + # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}. + # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up + # the same process group state as the root FSDP module. + self._device_mesh = device_mesh + _init_process_group_state( + self, + process_group, + sharding_strategy, + auto_wrap_policy, + device_mesh, + ) + if auto_wrap_policy is not None: + root_kwargs = { + "process_group": process_group, + "sharding_strategy": sharding_strategy, + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "mixed_precision": mixed_precision, + "param_init_fn": param_init_fn, + "device_id": device_id, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "limit_all_gathers": limit_all_gathers, + "use_orig_params": use_orig_params, + "ignored_states": self._ignored_params, + "device_mesh": device_mesh, + } + if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: + # Share root process groups with children to maintain + # the invariant that all FSDP modules will have the same + # process groups. 
+ root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) + + _auto_wrap( + module, + auto_wrap_policy, + self._ignored_modules, + self._ignored_params, + root_kwargs, + FullyShardedDataParallel, + ) + + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + _init_core_state( + self, + sharding_strategy, + mixed_precision, + cpu_offload, + limit_all_gathers, + use_orig_params, + backward_prefetch_limit, + forward_prefetch_limit, + ) + _init_runtime_state(self) + _init_prefetching_state(self, backward_prefetch, forward_prefetch) + _init_buffer_state(self, module) + # extension needs to be set before `_init_param_handle_from_module()` + _init_extension(self, device_mesh) + _init_param_handle_from_module( + self, + module, + device_id, + param_init_fn, + sync_module_states, + ) + self._fsdp_wrapped_module = module + if not use_orig_params: + _check_orig_params_flattened(self, self._ignored_params) + _register_flat_param(self, self) + + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + _init_state_dict_state(self) + _register_all_state_dict_hooks(self) + From d07b5aff8c53902da7042a32527aab3d40528a31 Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Thu, 28 Dec 2023 20:57:31 +0000 Subject: [PATCH 07/66] Add checkpoint profiling logs --- composer/utils/checkpoint.py | 7 +- composer/utils/checkpoint_debug.py | 169 +++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 composer/utils/checkpoint_debug.py diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 1d1e28d4d3..30d4e4e906 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -955,11 +955,14 @@ def save_checkpoint( import torch.distributed.checkpoint as dist_cp log.debug('Saving sharded checkpoints to %s...', save_filename) - dist_cp.save_state_dict( + from composer.utils import checkpoint_debug + log.warning('starting pytorch save state dict') + checkpoint_debug.save_state_dict( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, ) + log.warning('finished pytorch save state dict') # Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0 elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled: @@ -976,7 +979,9 @@ def save_checkpoint( else: log.debug(f'Only rank 0 is saving a checkpoint, so rank {dist.get_global_rank()} skips checkpointing.') + log.warning('starting dist barrier') dist.barrier() # ensure all ranks saved their files + log.warning('finished dist barrier') if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled: assert os.path.exists(save_filename), 'Expected file to have been saved.' 
diff --git a/composer/utils/checkpoint_debug.py b/composer/utils/checkpoint_debug.py new file mode 100644 index 0000000000..4e9d3e2a43 --- /dev/null +++ b/composer/utils/checkpoint_debug.py @@ -0,0 +1,169 @@ +from typing import Optional +import warnings + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner + + +from torch.distributed.checkpoint.storage import ( + StorageWriter, +) + +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.utils import _DistWrapper + +import logging +log = logging.getLogger(__name__) + +def save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + """This method is deprecated. Please switch to 'save'.""" + warnings.warn( + "'save_state_dict' is deprecated and will be removed in future versions. Please use 'save' instead." + ) + + # TODO: test returning `save` here instead. + return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner) + +def _save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + log.warning('starting pytorch save state dict') + + torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + distW.reduce_scatter = reduce_scatter.__get__(distW, _DistWrapper) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metatadata = None + + def local_step(): + log.warning('starting local step') + assert planner is not None + planner.set_up_planner(state_dict, distW.is_coordinator) + storage_writer.set_up_storage_writer(distW.is_coordinator) + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + log.warning('finished local step') + return local_plan + + def global_step(all_local_plans): + log.warning('starting global step') + nonlocal global_metatadata + + assert planner is not None + all_local_plans, global_metatadata = planner.create_global_plan( + all_local_plans + ) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + log.warning('finished global step') + return all_local_plans + + log.warning('starting reduce scatter') + central_plan = distW.reduce_scatter("plan", local_step, global_step) + + def write_data(): + log.warning('starting write data') + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + all_writes = storage_writer.write_data(final_local_plan, planner) + + all_writes.wait() + log.warning('finished write data') + return all_writes.value() + + def finish_checkpoint(all_results): + log.warning('starting finish checkpoint') + assert global_metatadata is not None + storage_writer.finish(metadata=global_metatadata, results=all_results) + log.warning('finished finish checkpoint') + return global_metatadata + + log.warning('starting all reduce') + return distW.all_reduce("write", write_data, finish_checkpoint) + +from typing import ( + List, + Callable, 
+ Optional, + Union, + TypeVar, + cast, +) +from torch.distributed.checkpoint.api import ( + CheckpointException, + _wrap_exception, + WRAPPED_EXCEPTION, +) +from torch.distributed.checkpoint.utils import _get_failure_dict +T = TypeVar("T") +R = TypeVar("R") + + +def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[List[T]], List[R]], +) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[WRAPPED_EXCEPTION, T] + try: + local_data = map_fun() + except BaseException as e: + local_data = _wrap_exception(e) + + log.warning('starting gather') + all_data = self.gather_object(local_data) + log.warning('finished gather') + all_results: Optional[List[Union[R, CheckpointException]]] = None + log.warning('starting rank 0 work') + if self.is_coordinator: + assert all_data is not None + node_failures = _get_failure_dict(all_data) + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? + all_results = cast( + List[Union[R, CheckpointException]], + reduce_fun(cast(List[T], all_data)), + ) + except BaseException as e: + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + all_results = [CheckpointException(step, node_failures)] * self.get_world_size() + log.warning('finished rank 0 work') + + log.warning('starting scatter') + result = self.scatter_object(all_results) + log.warning('finished scatter') + if isinstance(result, CheckpointException): + raise result + return result From 9326b0e82150dddd144deaeec8c17afed36a9536 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Thu, 28 Dec 2023 17:11:14 -0800 Subject: [PATCH 08/66] attempt --- composer/core/state.py | 17 ++++++++++++----- composer/utils/checkpoint.py | 6 ++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 67e4d5371d..d59a3fe902 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -792,6 +792,13 @@ def fsdp_state_dict_type(self): def fsdp_sharded_state_dict_enabled(self): return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type in ['sharded', 'local'] + @property + def fsdp_device_mesh(self): + if self.fsdp_enabled: + return self.model._device_map + else: + return None + @property def load_fsdp_monolith_rank0_only(self): return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[ @@ -897,12 +904,12 @@ def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str if self.fsdp_state_dict_type not in ['full', 'sharded']: raise NotImplementedError( textwrap.dedent( - f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for a torch version > 2.1.2." + f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for a torch version > 2.1.2." 
f"You are using {version.parse(torch.__version__)}" - ) + ) ) - optimizer = ensure_tuple(self.optimizers)[0] + optimizer = ensure_tuple(self.optimizers)[0] model_state_dict, optim_state_dict = get_state_dict(model=self.model, optimizers=([] if model_only else optimizer), submodules=None, @@ -1245,10 +1252,10 @@ def load_model_and_optimizer_state(self, exclude_algorithms: Optional[List[str]] = None, algorithm_passes: Optional[List[AlgorithmPass]] = None, load_model_only: bool = False): - # Note: In this case required algorithms not applied. + # Note: In this case required algorithms not applied. if version.parse(torch.__version__) > version.parse("2.1.2"): from torch.distributed.checkpoint.state_dict import set_state_dict, StateDictOptions - optimizer = ensure_tuple(self.optimizers)[0] + optimizer = ensure_tuple(self.optimizers)[0] set_state_dict(self.model, optimizers=([] if load_model_only else optimizer), model_state_dict=state_dict['model'], diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 30d4e4e906..a62d19a402 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -957,10 +957,16 @@ def save_checkpoint( log.debug('Saving sharded checkpoints to %s...', save_filename) from composer.utils import checkpoint_debug log.warning('starting pytorch save state dict') + device_mesh = state.fsdp_device_mesh + if device_mesh is not None: + process_group = device_mesh.get_group(1) + else: + process_group = None checkpoint_debug.save_state_dict( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, + process_group=process_group, ) log.warning('finished pytorch save state dict') From 221e58ed07683ee84546fd3531ea3645fab775ec Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 29 Dec 2023 02:09:32 +0000 Subject: [PATCH 09/66] working single node --- composer/core/state.py | 2 +- composer/utils/checkpoint.py | 29 +++-- composer/utils/checkpoint_debug.py | 169 ----------------------------- 3 files changed, 21 insertions(+), 179 deletions(-) delete mode 100644 composer/utils/checkpoint_debug.py diff --git a/composer/core/state.py b/composer/core/state.py index d59a3fe902..3635949b68 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -795,7 +795,7 @@ def fsdp_sharded_state_dict_enabled(self): @property def fsdp_device_mesh(self): if self.fsdp_enabled: - return self.model._device_map + return self.model.model._device_mesh else: return None diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index a62d19a402..1090ddca0d 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -934,8 +934,12 @@ def save_checkpoint( if dirname: os.makedirs(dirname, exist_ok=True) + # Only some ranks are meant to save checkpoint and produce a file + expect_file = False + # All ranks save for deepspeed if is_deepspeed: + expect_file = True log.debug('Saving deepspeed checkpoints to %s...', save_filename) if dist.get_global_rank() == 0: with open(save_filename, 'wb') as f: @@ -953,25 +957,32 @@ def save_checkpoint( _validate_save_planner(save_planner) import torch.distributed.checkpoint as dist_cp + from torch.distributed import get_process_group_ranks log.debug('Saving sharded checkpoints to %s...', save_filename) - from composer.utils import checkpoint_debug log.warning('starting pytorch save state dict') device_mesh = state.fsdp_device_mesh if device_mesh is not None: - process_group = device_mesh.get_group(1) + mesh_pg_1 = device_mesh.get_group(1) + mesh_pg_1_ranks 
= get_process_group_ranks(mesh_pg_1) + expect_file = (0 in mesh_pg_1_ranks) + log.debug(f'global_rank={dist.get_global_rank()}, {mesh_pg_1_ranks=}, {expect_file=}') else: process_group = None - checkpoint_debug.save_state_dict( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - process_group=process_group, - ) + expect_file = True + + if expect_file: + dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=mesh_pg_1, + ) log.warning('finished pytorch save state dict') # Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0 elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled: + expect_file = True log_msg = f'Saving sharded checkpoints to {save_filename}...' if state.fsdp_sharded_state_dict_enabled else f'Saving monolithic checkpoint to {save_filename}' with open(save_filename, 'wb') as f: log.debug(log_msg) @@ -989,7 +1000,7 @@ def save_checkpoint( dist.barrier() # ensure all ranks saved their files log.warning('finished dist barrier') - if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled: + if expect_file: assert os.path.exists(save_filename), 'Expected file to have been saved.' return save_filename else: diff --git a/composer/utils/checkpoint_debug.py b/composer/utils/checkpoint_debug.py deleted file mode 100644 index 4e9d3e2a43..0000000000 --- a/composer/utils/checkpoint_debug.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Optional -import warnings - -import torch -import torch.distributed as dist -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.checkpoint.planner import SavePlanner -from torch.distributed.checkpoint.default_planner import DefaultSavePlanner - - -from torch.distributed.checkpoint.storage import ( - StorageWriter, -) - -from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE -from torch.distributed.checkpoint.utils import _DistWrapper - -import logging -log = logging.getLogger(__name__) - -def save_state_dict( - state_dict: STATE_DICT_TYPE, - storage_writer: StorageWriter, - process_group: Optional[dist.ProcessGroup] = None, - coordinator_rank: int = 0, - no_dist: bool = False, - planner: Optional[SavePlanner] = None, -) -> Metadata: - """This method is deprecated. Please switch to 'save'.""" - warnings.warn( - "'save_state_dict' is deprecated and will be removed in future versions. Please use 'save' instead." - ) - - # TODO: test returning `save` here instead. 
- return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner) - -def _save_state_dict( - state_dict: STATE_DICT_TYPE, - storage_writer: StorageWriter, - process_group: Optional[dist.ProcessGroup] = None, - coordinator_rank: int = 0, - no_dist: bool = False, - planner: Optional[SavePlanner] = None, -) -> Metadata: - log.warning('starting pytorch save state dict') - - torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") - - distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - distW.reduce_scatter = reduce_scatter.__get__(distW, _DistWrapper) - if planner is None: - planner = DefaultSavePlanner() - assert planner is not None - - global_metatadata = None - - def local_step(): - log.warning('starting local step') - assert planner is not None - planner.set_up_planner(state_dict, distW.is_coordinator) - storage_writer.set_up_storage_writer(distW.is_coordinator) - local_plan = planner.create_local_plan() - local_plan = storage_writer.prepare_local_plan(local_plan) - log.warning('finished local step') - return local_plan - - def global_step(all_local_plans): - log.warning('starting global step') - nonlocal global_metatadata - - assert planner is not None - all_local_plans, global_metatadata = planner.create_global_plan( - all_local_plans - ) - all_local_plans = storage_writer.prepare_global_plan(all_local_plans) - log.warning('finished global step') - return all_local_plans - - log.warning('starting reduce scatter') - central_plan = distW.reduce_scatter("plan", local_step, global_step) - - def write_data(): - log.warning('starting write data') - assert planner is not None - final_local_plan = planner.finish_plan(central_plan) - all_writes = storage_writer.write_data(final_local_plan, planner) - - all_writes.wait() - log.warning('finished write data') - return all_writes.value() - - def finish_checkpoint(all_results): - log.warning('starting finish checkpoint') - assert global_metatadata is not None - storage_writer.finish(metadata=global_metatadata, results=all_results) - log.warning('finished finish checkpoint') - return global_metatadata - - log.warning('starting all reduce') - return distW.all_reduce("write", write_data, finish_checkpoint) - -from typing import ( - List, - Callable, - Optional, - Union, - TypeVar, - cast, -) -from torch.distributed.checkpoint.api import ( - CheckpointException, - _wrap_exception, - WRAPPED_EXCEPTION, -) -from torch.distributed.checkpoint.utils import _get_failure_dict -T = TypeVar("T") -R = TypeVar("R") - - -def reduce_scatter( - self, - step: str, - map_fun: Callable[[], T], - reduce_fun: Callable[[List[T]], List[R]], -) -> R: - """ - Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. - - This method operates in the following way: - Run ``map_fun`` on all ranks - Gather results on rank 0 - Call ``reduce_fun`` on all those values - Scatter to each rank part of the result. - """ - local_data: Union[WRAPPED_EXCEPTION, T] - try: - local_data = map_fun() - except BaseException as e: - local_data = _wrap_exception(e) - - log.warning('starting gather') - all_data = self.gather_object(local_data) - log.warning('finished gather') - all_results: Optional[List[Union[R, CheckpointException]]] = None - log.warning('starting rank 0 work') - if self.is_coordinator: - assert all_data is not None - node_failures = _get_failure_dict(all_data) - - if len(node_failures) == 0: - try: - # N.B. 
why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? - all_results = cast( - List[Union[R, CheckpointException]], - reduce_fun(cast(List[T], all_data)), - ) - except BaseException as e: - node_failures[self.rank] = _wrap_exception(e) - - if len(node_failures) > 0: - all_results = [CheckpointException(step, node_failures)] * self.get_world_size() - log.warning('finished rank 0 work') - - log.warning('starting scatter') - result = self.scatter_object(all_results) - log.warning('finished scatter') - if isinstance(result, CheckpointException): - raise result - return result From b60d0ab7d0489a42ba0e60d4f7d65caa84438cef Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 30 Dec 2023 22:09:55 +0000 Subject: [PATCH 10/66] fix optimizer --- composer/core/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/core/state.py b/composer/core/state.py index 3635949b68..2906656f50 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -1299,7 +1299,7 @@ def load_state_dict( strict=strict, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, - load_model_only=('optimizers' in state) + load_model_only=(not 'optimizers' in state) ) for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order From 28aa784723e0910259ecc993f9cfb612ad3b0a81 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 3 Jan 2024 09:00:00 +0000 Subject: [PATCH 11/66] allow 3d device mesh --- composer/trainer/mosaic_fsdp_utils.py | 116 +++++++++++++++++++++++++- composer/utils/checkpoint.py | 14 ++-- 2 files changed, 120 insertions(+), 10 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 0470dcb3b0..d5901db323 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -760,7 +760,6 @@ def _sharded_pre_load_state_dict_hook( _init_ignored_module_states, _init_device_handle, _annotate_modules_for_dynamo, - _init_process_group_state, _auto_wrap, _init_core_state, _init_runtime_state, @@ -776,9 +775,120 @@ def _sharded_pre_load_state_dict_hook( from torch.distributed.fsdp._init_utils import ( HYBRID_SHARDING_STRATEGIES, ProcessGroupType, + _init_intra_and_inter_node_groups, + _is_valid_hybrid_shard_pg_type, + _get_default_comm_hook_state, ) - from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy + from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy from torch.distributed._tensor import DeviceMesh + from torch.distributed.fsdp._common_utils import _FSDPState + from torch.distributed.algorithms._comm_hooks import default_hooks + + + def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: + #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) + #if parent_mesh is not None: + # raise RuntimeError( + # f"Found device_mesh {device_mesh} passed in has a parent device_mesh {parent_mesh}.", + # "Hybrid sharding + TP is not supported yet.", + # ) + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim in [2,3] + + + def _init_process_group_state_for_hybrid_shard_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, + ) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. 
+ state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError( + "Expected device_mesh to have ndim=2 " + f"but got {len(device_mesh.get_group())}" + ) + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( + default_group, state._device_handle.device_count() + ) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. + if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. + state.process_group, state._inter_node_pg = process_group + else: + raise ValueError( + "Expected process_group to be passed in as either None or " + f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" + ) + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state( + process_group=state._inter_node_pg, + ) + return state + + + def _init_process_group_state_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, + ) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError( + "Cannot pass both process_group and device_mesh at the " + "same time. Please just pass only one of them." + ) + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f"Manual wrapping with {sharding_strategy}", + "requires explicit specification of process group or device_mesh.", + ) + else: + state = _init_process_group_state_for_hybrid_shard_t2p2p0( + state, process_group, device_mesh + ) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = ( + process_group if process_group is not None else _get_default_group() + ) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor( + data_parallel_world_size + ) + ) + state._gradient_postdivide_factor = ( + data_parallel_world_size / state._gradient_predivide_factor + ) + return state + def init_fn_t2p2p0( self, module: nn.Module, @@ -816,7 +926,7 @@ def init_fn_t2p2p0( # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up # the same process group state as the root FSDP module. 
self._device_mesh = device_mesh - _init_process_group_state( + _init_process_group_state_t2p2p0( self, process_group, sharding_strategy, diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 1090ddca0d..c1ad1ee6f7 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -962,21 +962,21 @@ def save_checkpoint( log.debug('Saving sharded checkpoints to %s...', save_filename) log.warning('starting pytorch save state dict') device_mesh = state.fsdp_device_mesh - if device_mesh is not None: - mesh_pg_1 = device_mesh.get_group(1) - mesh_pg_1_ranks = get_process_group_ranks(mesh_pg_1) - expect_file = (0 in mesh_pg_1_ranks) - log.debug(f'global_rank={dist.get_global_rank()}, {mesh_pg_1_ranks=}, {expect_file=}') + process_group = None + if device_mesh is not None and device_mesh.ndim == 2: + expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) + process_group = device_mesh.get_group(1) + log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}') else: - process_group = None expect_file = True + process_group = None if expect_file: dist_cp.save( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, - process_group=mesh_pg_1, + process_group=process_group, ) log.warning('finished pytorch save state dict') From 730368bdbef20c6bce6cc0c01e9dd92b967ad291 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 3 Jan 2024 09:59:51 +0000 Subject: [PATCH 12/66] attempt to use different pg during 3d mesh save --- composer/utils/checkpoint.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index c1ad1ee6f7..afdd8158f4 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -963,13 +963,17 @@ def save_checkpoint( log.warning('starting pytorch save state dict') device_mesh = state.fsdp_device_mesh process_group = None - if device_mesh is not None and device_mesh.ndim == 2: + coordinator_rank = 0 + if device_mesh is not None and device_mesh.ndim in [2, 3]: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) process_group = device_mesh.get_group(1) - log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}') + if device_mesh.ndim == 3: + coordinator_rank = dist.get_global_rank() % device_mesh.shape[2] + log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}, {coordinator_rank=}') else: expect_file = True process_group = None + coordinator_rank = 0 if expect_file: dist_cp.save( @@ -977,6 +981,7 @@ def save_checkpoint( storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, process_group=process_group, + coordinator_rank=coordinator_rank, ) log.warning('finished pytorch save state dict') From ddb336f4095941d8847a7b7b9150dcd70ca87708 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 01:11:01 +0000 Subject: [PATCH 13/66] undo 3d mesh changes --- composer/trainer/mosaic_fsdp_utils.py | 116 +------------------------- composer/utils/checkpoint.py | 10 +-- 2 files changed, 5 insertions(+), 121 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index d5901db323..0470dcb3b0 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -760,6 +760,7 @@ def _sharded_pre_load_state_dict_hook( 
_init_ignored_module_states, _init_device_handle, _annotate_modules_for_dynamo, + _init_process_group_state, _auto_wrap, _init_core_state, _init_runtime_state, @@ -775,120 +776,9 @@ def _sharded_pre_load_state_dict_hook( from torch.distributed.fsdp._init_utils import ( HYBRID_SHARDING_STRATEGIES, ProcessGroupType, - _init_intra_and_inter_node_groups, - _is_valid_hybrid_shard_pg_type, - _get_default_comm_hook_state, ) - from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy + from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy from torch.distributed._tensor import DeviceMesh - from torch.distributed.fsdp._common_utils import _FSDPState - from torch.distributed.algorithms._comm_hooks import default_hooks - - - def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: - #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) - #if parent_mesh is not None: - # raise RuntimeError( - # f"Found device_mesh {device_mesh} passed in has a parent device_mesh {parent_mesh}.", - # "Hybrid sharding + TP is not supported yet.", - # ) - return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim in [2,3] - - - def _init_process_group_state_for_hybrid_shard_t2p2p0( - state: _FSDPState, - process_group: ProcessGroupType, - device_mesh: DeviceMesh, - ) -> _FSDPState: - if device_mesh: - if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh): - state._device_mesh = device_mesh - # We currently only allow _inter_node_pg to be the outermost dimension, and the - # process_group(intra_node) to be the innermost dimension. - state._inter_node_pg = device_mesh.get_group(mesh_dim=0) - state.process_group = device_mesh.get_group(mesh_dim=1) - else: - raise ValueError( - "Expected device_mesh to have ndim=2 " - f"but got {len(device_mesh.get_group())}" - ) - elif process_group is None: - default_group = _get_default_group() - intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( - default_group, state._device_handle.device_count() - ) - # we shard across intra-node - state.process_group = intra_node_group - # save _inter_node_pg to allreduce across. - state._inter_node_pg = inter_node_group - else: - # Check type and assign state.process_group and state._inter_node_pg. - if _is_valid_hybrid_shard_pg_type(process_group): - # Assuming that user passed in as intra node group and inter node group - # as documented. - state.process_group, state._inter_node_pg = process_group - else: - raise ValueError( - "Expected process_group to be passed in as either None or " - f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" - ) - # Create state for allreduce - state._inter_node_state = _get_default_comm_hook_state( - process_group=state._inter_node_pg, - ) - return state - - - def _init_process_group_state_t2p2p0( - state: _FSDPState, - process_group: ProcessGroupType, - sharding_strategy: ShardingStrategy, - policy: Optional[_Policy], - device_mesh: Optional[DeviceMesh] = None, - ) -> _FSDPState: - if process_group is not None and device_mesh is not None: - raise ValueError( - "Cannot pass both process_group and device_mesh at the " - "same time. Please just pass only one of them." 
- ) - is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES - if is_hybrid_strategy: - if process_group is None and policy is None and device_mesh is None: - # Raise an error here, since this is manual wrapping with no process group - # passed in, there is no way to ensure all wrapped FSDP instances use the same - # process groups. - raise ValueError( - f"Manual wrapping with {sharding_strategy}", - "requires explicit specification of process group or device_mesh.", - ) - else: - state = _init_process_group_state_for_hybrid_shard_t2p2p0( - state, process_group, device_mesh - ) - else: - if device_mesh: - state._device_mesh = device_mesh - state.process_group = device_mesh.get_group(mesh_dim=0) - else: - state.process_group = ( - process_group if process_group is not None else _get_default_group() - ) - - state.rank = state.process_group.rank() - state.world_size = state.process_group.size() - data_parallel_world_size = state.world_size - if is_hybrid_strategy: - data_parallel_world_size *= state._inter_node_pg.size() - state._gradient_predivide_factor = ( - default_hooks.DefaultState._get_gradient_predivide_factor( - data_parallel_world_size - ) - ) - state._gradient_postdivide_factor = ( - data_parallel_world_size / state._gradient_predivide_factor - ) - return state - def init_fn_t2p2p0( self, module: nn.Module, @@ -926,7 +816,7 @@ def init_fn_t2p2p0( # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up # the same process group state as the root FSDP module. self._device_mesh = device_mesh - _init_process_group_state_t2p2p0( + _init_process_group_state( self, process_group, sharding_strategy, diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index afdd8158f4..327b365045 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -962,18 +962,13 @@ def save_checkpoint( log.debug('Saving sharded checkpoints to %s...', save_filename) log.warning('starting pytorch save state dict') device_mesh = state.fsdp_device_mesh - process_group = None - coordinator_rank = 0 - if device_mesh is not None and device_mesh.ndim in [2, 3]: + if device_mesh is not None and device_mesh.ndim == 2: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) process_group = device_mesh.get_group(1) - if device_mesh.ndim == 3: - coordinator_rank = dist.get_global_rank() % device_mesh.shape[2] - log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}, {coordinator_rank=}') + log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}') else: expect_file = True process_group = None - coordinator_rank = 0 if expect_file: dist_cp.save( @@ -981,7 +976,6 @@ def save_checkpoint( storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, process_group=process_group, - coordinator_rank=coordinator_rank, ) log.warning('finished pytorch save state dict') From f184e318b28123e9aff94e1e8b0e203207c74f50 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 5 Jan 2024 02:59:30 +0000 Subject: [PATCH 14/66] load_state_dict -> load --- composer/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 327b365045..956de96d98 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -517,7 +517,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # Call function to modify state_dict 
ignore_keys(state_dict) - dist_cp.load_state_dict(state_dict, storage_reader) + dist_cp.load(state_dict, storage_reader) state.load_state_dict( state_dict['state'], From dd1c3c46459163d1fae12b0204a8bb16325f4f3e Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 16:43:33 +0000 Subject: [PATCH 15/66] allow parent mesh in FSDP init --- composer/trainer/mosaic_fsdp_utils.py | 117 +++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 3 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 0470dcb3b0..e99ec335ae 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -760,7 +760,6 @@ def _sharded_pre_load_state_dict_hook( _init_ignored_module_states, _init_device_handle, _annotate_modules_for_dynamo, - _init_process_group_state, _auto_wrap, _init_core_state, _init_runtime_state, @@ -776,9 +775,121 @@ def _sharded_pre_load_state_dict_hook( from torch.distributed.fsdp._init_utils import ( HYBRID_SHARDING_STRATEGIES, ProcessGroupType, + _init_intra_and_inter_node_groups, + _is_valid_hybrid_shard_pg_type, + _get_default_comm_hook_state, ) - from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy + from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy from torch.distributed._tensor import DeviceMesh + from torch.distributed.fsdp._common_utils import _FSDPState + from torch.distributed.algorithms._comm_hooks import default_hooks + from torch.distributed.distributed_c10d import _get_default_group + + + def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: + #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) + #if parent_mesh is not None: + # raise RuntimeError( + # f"Found device_mesh {device_mesh} passed in has a parent device_mesh {parent_mesh}.", + # "Hybrid sharding + TP is not supported yet.", + # ) + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 + + + def _init_process_group_state_for_hybrid_shard_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, + ) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. + state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError( + "Expected device_mesh to have ndim=2 " + f"but got {len(device_mesh.get_group())}" + ) + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( + default_group, state._device_handle.device_count() + ) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. + if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. 
+ state.process_group, state._inter_node_pg = process_group + else: + raise ValueError( + "Expected process_group to be passed in as either None or " + f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" + ) + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state( + process_group=state._inter_node_pg, + ) + return state + + + def _init_process_group_state_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, + ) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError( + "Cannot pass both process_group and device_mesh at the " + "same time. Please just pass only one of them." + ) + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f"Manual wrapping with {sharding_strategy}", + "requires explicit specification of process group or device_mesh.", + ) + else: + state = _init_process_group_state_for_hybrid_shard_t2p2p0( + state, process_group, device_mesh + ) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = ( + process_group if process_group is not None else _get_default_group() + ) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor( + data_parallel_world_size + ) + ) + state._gradient_postdivide_factor = ( + data_parallel_world_size / state._gradient_predivide_factor + ) + return state + def init_fn_t2p2p0( self, module: nn.Module, @@ -816,7 +927,7 @@ def init_fn_t2p2p0( # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up # the same process group state as the root FSDP module. 
self._device_mesh = device_mesh - _init_process_group_state( + _init_process_group_state_t2p2p0( self, process_group, sharding_strategy, From d5f9951faa087f7536115950b268ff4eac0c4a63 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 17:46:53 +0000 Subject: [PATCH 16/66] allow override of force_sync_module_states --- composer/trainer/dist_strategy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 424e13f546..2b4540050d 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -232,7 +232,9 @@ def prepare_fsdp_module( set_fsdp_default(fsdp_config) # Check sync_module_states is True for mixed initialization or HSDP - if fsdp_config['sync_module_states'] == False: + print (fsdp_config) + exit(0) + if fsdp_config['sync_module_states'] == False and not fsdp_config.get('force_sync_module_states', False): rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0 all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8)) dist.all_reduce(all_ranks_meta, reduce_operation='MIN') @@ -533,7 +535,6 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: if ret and auto_microbatching: module.register_forward_hook(sync_hook) module.register_full_backward_hook(sync_hook) - print(module, '\n', ret) return ret _auto_wrap_policy = CustomPolicy(lambda_fn) From 0aed0bf65dba2d8ebc9f502a48e22cb4b1852aea Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 17:49:23 +0000 Subject: [PATCH 17/66] remove unnecessary exit --- composer/trainer/dist_strategy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 2b4540050d..c8f1a7b8d0 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -232,8 +232,6 @@ def prepare_fsdp_module( set_fsdp_default(fsdp_config) # Check sync_module_states is True for mixed initialization or HSDP - print (fsdp_config) - exit(0) if fsdp_config['sync_module_states'] == False and not fsdp_config.get('force_sync_module_states', False): rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0 all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8)) From dd936ea1bfa407e34a2cd893e05fc5e187da761d Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 18:55:53 +0000 Subject: [PATCH 18/66] ignore _validate_and_get_shard_state() --- composer/trainer/mosaic_fsdp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 2dd1935bc4..1d4a9021d5 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -1151,7 +1151,7 @@ def _share_state_and_init_handle_attrs_t2p2( handle = root_state._handle if handle: handle.init_flat_param_attributes() - _validate_and_get_hybrid_shard_state(root_module) + #_validate_and_get_hybrid_shard_state(root_module) attr_name_to_values: Dict[str, Set[Any]] = {} for attr_name in HOMOGENEOUS_ATTR_NAMES: attr_name_to_values[attr_name] = set() From f864a2699a67fbb28d30df1c9f98719b32f921a4 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 21:15:54 +0000 Subject: [PATCH 19/66] save/load hsdp-moe working --- composer/trainer/mosaic_fsdp_utils.py | 114 +++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 3 deletions(-) diff 
--git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 1d4a9021d5..f8e102f7fb 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -766,7 +766,7 @@ def _sharded_pre_load_state_dict_hook( _init_runtime_state, _init_prefetching_state, _init_buffer_state, - _init_extension, + #_init_extension, _init_param_handle_from_module, _check_orig_params_flattened, _register_flat_param, @@ -781,11 +781,119 @@ def _sharded_pre_load_state_dict_hook( _get_default_comm_hook_state, ) from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy - from torch.distributed._tensor import DeviceMesh + from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.fsdp._common_utils import _FSDPState from torch.distributed.algorithms._comm_hooks import default_hooks from torch.distributed.distributed_c10d import _get_default_group + from torch.distributed.tensor.parallel.fsdp import DTensorExtensions + from torch.distributed.device_mesh import _mesh_resources + import copy + + def all_gather_dtensor_t2p2p0( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + print('\n\n\n\n\here!') + """All gather a DTensor in its FSDP dimension and return the local tensor.""" + assert parent_mesh == tensor.device_mesh + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement] + # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] + for i in range(0, len(placements)-1): + placements[i] = Replicate() + print (len(placements), placements) + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + return tensor.to_local() + + + def chunk_dtensor_t2p2p0( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. + + The local rank will gets its corresponding chunk as the local tensor to create a DTensor. + """ + parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) + if parent_mesh is None: + raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.") + # if parent_mesh.ndim != 2: + # raise RuntimeError( + # f"Found parent device_mesh of ndim={parent_mesh.ndim},", + # "but only 2D meshes are currently supported.", + # ) + + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.clone().detach() + + # When a layer is not involved in TP, then the tensor will not be a DTensor. + # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. + # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): + + # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), Replicate()). 
+ replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)] + shard_placements = [Replicate() for _ in range(parent_mesh.ndim)] + shard_placements[0] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, parent_mesh, replicate_placements + ).redistribute( + device_mesh=parent_mesh, + placements=shard_placements, + ) + + else: + tp_placements = tensor.placements + tp_placement = tp_placements[0] + + tensor = tensor.to_local() + + if parent_mesh.ndim <= 2: + # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), tp_placement). + replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)] + replicate_placements[-1] = tp_placement # type: ignore[call-overload] + shard_placements = [DShard(0) for _ in range(parent_mesh.ndim)] # type: ignore[misc] + shard_placements[-1] = tp_placement # type: ignore[call-overload] + + + elif parent_mesh.ndim == 3: + replicate_placements = [Replicate(), Replicate(), tp_placement] + shard_placements = [Replicate(), DShard(0), tp_placement] # type: ignore[misc] + + return DTensor.from_local( + tensor, parent_mesh, replicate_placements + ).redistribute( + device_mesh=parent_mesh, + placements=shard_placements, + ) + + DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0 + DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0 + + def _init_extension_t2p2p0(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: + # TODO: we need to add additional check once we support FSDP + PiPPy. + # This check is currently sufficient, since we only support FSDP + TP. + if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None: + state._fsdp_extension = DTensorExtensions() + else: + # We need to explicilty set _fsdp_extension to None. + # Otherwise, we will run into an infinite recursion when getting the attribute. 
+ state._fsdp_extension = None + return state def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) @@ -982,7 +1090,7 @@ def init_fn_t2p2p0( _init_prefetching_state(self, backward_prefetch, forward_prefetch) _init_buffer_state(self, module) # extension needs to be set before `_init_param_handle_from_module()` - _init_extension(self, device_mesh) + _init_extension_t2p2p0(self, device_mesh) _init_param_handle_from_module( self, module, From 77a1417a93dbba8608aa105ed009f450abf4c120 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Fri, 5 Jan 2024 21:37:17 +0000 Subject: [PATCH 20/66] remove prints --- composer/trainer/mosaic_fsdp_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index f8e102f7fb..b0d4c466e9 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -795,7 +795,6 @@ def all_gather_dtensor_t2p2p0( tensor: DTensor, parent_mesh: Optional[DeviceMesh], ) -> torch.Tensor: - print('\n\n\n\n\here!') """All gather a DTensor in its FSDP dimension and return the local tensor.""" assert parent_mesh == tensor.device_mesh @@ -804,7 +803,6 @@ def all_gather_dtensor_t2p2p0( # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] for i in range(0, len(placements)-1): placements[i] = Replicate() - print (len(placements), placements) tensor = tensor.redistribute( device_mesh=tensor.device_mesh, placements=placements, From 355d53570b4a7416c3542776bf69353324f46aac Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 16:35:39 -0500 Subject: [PATCH 21/66] v1 --- composer/core/state.py | 83 +++++++++++++++------------ composer/trainer/mosaic_fsdp.py | 7 +-- composer/trainer/mosaic_fsdp_utils.py | 2 +- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 2906656f50..b31698be86 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -871,19 +871,19 @@ def get_model_state_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The state dict for the model. 
""" - if version.parse(torch.__version__) > version.parse("2.1.2"): - model_state_dict, _ = self.get_model_and_optimizer_state_dict(model_only=True) - else: - if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - model_state_dict = self.model.state_dict() - else: + return self.get_model_and_optimizer_state_dict(model_only=True)[0] + + def _legacy_get_model_state_dict(self) -> Dict[str, Any]: + if self.fsdp_enabled and self.fsdp_state_dict_type is not None: + with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): model_state_dict = self.model.state_dict() + else: + model_state_dict = self.model.state_dict() - # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel - # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail - if self.is_model_ddp: - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') + # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel + # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail + if self.is_model_ddp: + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') return model_state_dict def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: @@ -900,29 +900,30 @@ def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]: if version.parse(torch.__version__) > version.parse("2.1.2"): from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions - full_state_dict = True if self.fsdp_state_dict_type == 'full' else False - if self.fsdp_state_dict_type not in ['full', 'sharded']: + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( textwrap.dedent( - f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for a torch version > 2.1.2." - f"You are using {version.parse(torch.__version__)}" + f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for " + f'torch version {{version.parse(torch.__version__)}} > 2.1.2. Please set ' + 'fsdp_state_dict_type to None, "full", or "sharded".' ) ) optimizer = ensure_tuple(self.optimizers)[0] - model_state_dict, optim_state_dict = get_state_dict(model=self.model, - optimizers=([] if model_only else optimizer), - submodules=None, - options=StateDictOptions( - full_state_dict=full_state_dict, - cpu_offload=True, - ignore_frozen_params=True, - keep_submodule_prefixes=True, - strict=True, - ) - ) + model_state_dict, optim_state_dict = get_state_dict( + model=self.model, + optimizers=([] if model_only else optimizer), + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type != 'sharded', + cpu_offload=True, + ignore_frozen_params=True, + keep_submodule_prefixes=True, + strict=True, + ), + ) else: - model_state_dict = self.get_model_state_dict() + model_state_dict = self._legacy_get_model_state_dict() optim_state_dict = self._legacy_get_optim_state_dict() return model_state_dict, optim_state_dict @@ -1245,28 +1246,34 @@ def _load_dataset_state(self, obj: Dict[str, Any]) -> None: # starts. This avoids "CUDA error: initialization error" -- its not clear why. 
# self.dataset_resumption['eval'][evaluator.label] = True - def load_model_and_optimizer_state(self, - state_dict: Dict[str, Any], - logger: Logger, - strict: bool, - exclude_algorithms: Optional[List[str]] = None, - algorithm_passes: Optional[List[AlgorithmPass]] = None, - load_model_only: bool = False): + def load_model_and_optimizer_state( + self, + state_dict: Dict[str, Any], + logger: Logger, + strict: bool, + exclude_algorithms: Optional[List[str]] = None, + algorithm_passes: Optional[List[AlgorithmPass]] = None, + load_model_only: bool = False, + ): # Note: In this case required algorithms not applied. if version.parse(torch.__version__) > version.parse("2.1.2"): from torch.distributed.checkpoint.state_dict import set_state_dict, StateDictOptions optimizer = ensure_tuple(self.optimizers)[0] - set_state_dict(self.model, + set_state_dict( + self.model, optimizers=([] if load_model_only else optimizer), model_state_dict=state_dict['model'], optim_state_dict=({} if load_model_only else state_dict['optimizers']), - options=StateDictOptions(strict=False)) + options=StateDictOptions(strict=False), + ) else: - self.load_model_state(state_dict, + self.load_model_state( + state_dict, logger, strict=strict, exclude_algorithms=exclude_algorithms, - algorithm_passes=algorithm_passes,) + algorithm_passes=algorithm_passes, + ) if not load_model_only: self.load_optim_state(state_dict) diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index d4def64698..fc18cd7033 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -57,7 +57,6 @@ def patch_pytorch(): elif version.parse(torch.__version__) < version.parse('2.1.3'): # Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2 - # Allow 2D HSDP from torch.distributed.fsdp import _runtime_utils _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None @@ -71,9 +70,9 @@ def patch_pytorch(): # Better overlap communication and computation from torch.distributed.fsdp import _runtime_utils - - from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2, init_fn_t2p2p0 + from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2 _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2 - # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn + # Monkeypatch dtensor support + from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0 FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index b0d4c466e9..4257e6fe37 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -1257,7 +1257,7 @@ def _share_state_and_init_handle_attrs_t2p2( handle = root_state._handle if handle: handle.init_flat_param_attributes() - #_validate_and_get_hybrid_shard_state(root_module) + # _validate_and_get_hybrid_shard_state(root_module) attr_name_to_values: Dict[str, Set[Any]] = {} for attr_name in HOMOGENEOUS_ATTR_NAMES: attr_name_to_values[attr_name] = set() From 1033fdf9d84b9d0041f80755c067ae89927289bf Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 17:10:58 -0500 Subject: [PATCH 22/66] v2 --- composer/core/state.py | 98 +++++++++++++++++------------------- composer/utils/checkpoint.py | 2 +- 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index b31698be86..d202438f7a 100644 
--- a/composer/core/state.py +++ b/composer/core/state.py @@ -922,6 +922,7 @@ def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str strict=True, ), ) + optim_state_dict = {type(optimizer).__qualname__: optim_state_dict} else: model_state_dict = self._legacy_get_model_state_dict() optim_state_dict = self._legacy_get_optim_state_dict() @@ -1099,49 +1100,34 @@ def _apply_required_algorithms( 'have undergone surgery, the following algorithms may be excluded using ' f'`load_exclude_algorithms`, e.g. `load_exclude_algorithms=[{missing_algo_names}]`.')) from e - def load_model_state( + def _legacy_load_model_state( self, state_dict: Dict[str, Any], - logger: Logger, strict: bool, - exclude_algorithms: Optional[List[str]] = None, - algorithm_passes: Optional[List[AlgorithmPass]] = None, ): """Loads the model's state from a ``state_dict``. Args: state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`. - logger (Logger): The logger. strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should perfectly match the keys in the model instance. - exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``) - algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms - to sort them into the correct order. (default: ``None``) """ - if 'algorithms' in state_dict: - self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes) - - if state_dict.get('is_model_ddp', False) and not self.is_model_ddp: - # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state - # with the `module.` prefix - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') - # For FSDP monolith checkpoints, the model does not exist on ranks > 0 - model_on_rank = state_dict['model'] is not None + if state_dict['model'] is None: + return missing_keys, unexpected_keys = [], [] try: - # Load model if it exists. 
For FSDP monolith checkpoints, the model does not exist on ranks > 0 - if model_on_rank: - if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only: - log.debug( - f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}' - ) - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) - else: - log.debug(f'Loading model state dict with strict={strict}') + # Load model if it exists + if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only: + log.debug( + f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}' + ) + with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) + else: + log.debug(f'Loading model state dict with strict={strict}') + missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) except RuntimeError as e: if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e): raise RuntimeError( @@ -1151,9 +1137,9 @@ def load_model_state( else: raise e - if model_on_rank and len(missing_keys) > 0: + if len(missing_keys) > 0: log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if model_on_rank and len(unexpected_keys) > 0: + if len(unexpected_keys) > 0: if self.fsdp_config is not None and self.fsdp_config[ 'use_orig_params'] and self.fsdp_state_dict_type == 'local': log.warning( @@ -1163,16 +1149,7 @@ def load_model_state( 'was still loaded correctly.') log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading - if self.load_fsdp_monolith_rank0_only: - assert self.fsdp_config is not None - log.info('Wrapping model with FSDP after loading model_state.') - from composer.trainer.dist_strategy import prepare_fsdp_module - prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device, - self.auto_microbatching) - log.debug('Finished wrapping model with FSDP.') - - def load_optim_state(self, state_dict: Dict[str, Any]): + def _legacy_load_optim_state(self, state_dict: Dict[str, Any]): """Load the optimizer state. Args: @@ -1255,27 +1232,42 @@ def load_model_and_optimizer_state( algorithm_passes: Optional[List[AlgorithmPass]] = None, load_model_only: bool = False, ): - # Note: In this case required algorithms not applied. 
+ if 'algorithms' in state_dict: + self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes) + + if state_dict.get('is_model_ddp', False) and not self.is_model_ddp: + # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state + # with the `module.` prefix + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') + + # Load model and optimizer state if version.parse(torch.__version__) > version.parse("2.1.2"): from torch.distributed.checkpoint.state_dict import set_state_dict, StateDictOptions optimizer = ensure_tuple(self.optimizers)[0] + model_state_dict = state_dict.get('model', {}) + optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) + if load_model_only: + optimizer, optim_state_dict = [], {} set_state_dict( self.model, - optimizers=([] if load_model_only else optimizer), - model_state_dict=state_dict['model'], - optim_state_dict=({} if load_model_only else state_dict['optimizers']), - options=StateDictOptions(strict=False), + optimizers=optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optim_state_dict, + options=StateDictOptions(strict=strict), ) else: - self.load_model_state( - state_dict, - logger, - strict=strict, - exclude_algorithms=exclude_algorithms, - algorithm_passes=algorithm_passes, - ) + self._legacy_load_model_state(state_dict, strict) if not load_model_only: - self.load_optim_state(state_dict) + self._legacy_load_optim_state(state_dict) + + # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading + if self.load_fsdp_monolith_rank0_only: + assert self.fsdp_config is not None + log.info('Wrapping model with FSDP after loading model_state.') + from composer.trainer.dist_strategy import prepare_fsdp_module + prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device, + self.auto_microbatching) + log.debug('Finished wrapping model with FSDP.') def load_state_dict( self, @@ -1306,7 +1298,7 @@ def load_state_dict( strict=strict, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, - load_model_only=(not 'optimizers' in state) + load_model_only=(not 'optimizers' in state), ) for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index db927b4f93..5eda9f46f9 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -868,7 +868,7 @@ def _restore_checkpoint( strict=strict_model_weights, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, - load_model_only=True + load_model_only=True, ) if not load_weights_only: state.load_state_dict( From 06b60a3037be679d9530f0227d7b4e33025f27c1 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 17:16:13 -0500 Subject: [PATCH 23/66] lint --- composer/core/state.py | 33 +++--- composer/trainer/mosaic_fsdp.py | 1 + composer/trainer/mosaic_fsdp_utils.py | 159 ++++++++++---------------- composer/utils/checkpoint.py | 23 ++-- 4 files changed, 88 insertions(+), 128 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index d202438f7a..94d95c78e0 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -10,7 +10,7 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast +from typing 
import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -872,7 +872,7 @@ def get_model_state_dict(self) -> Dict[str, Any]: Dict[str, Any]: The state dict for the model. """ return self.get_model_and_optimizer_state_dict(model_only=True)[0] - + def _legacy_get_model_state_dict(self) -> Dict[str, Any]: if self.fsdp_enabled and self.fsdp_state_dict_type is not None: with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): @@ -898,16 +898,13 @@ def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: return optim_state_dict def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]: - if version.parse(torch.__version__) > version.parse("2.1.2"): - from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions + if version.parse(torch.__version__) > version.parse('2.1.2'): + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( - textwrap.dedent( - f"fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for " - f'torch version {{version.parse(torch.__version__)}} > 2.1.2. Please set ' - 'fsdp_state_dict_type to None, "full", or "sharded".' - ) - ) + textwrap.dedent(f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + f'torch version {{version.parse(torch.__version__)}} > 2.1.2. Please set ' + 'fsdp_state_dict_type to None, "full", or "sharded".')) optimizer = ensure_tuple(self.optimizers)[0] model_state_dict, optim_state_dict = get_state_dict( @@ -915,12 +912,12 @@ def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str optimizers=([] if model_only else optimizer), submodules=None, options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type != 'sharded', - cpu_offload=True, - ignore_frozen_params=True, - keep_submodule_prefixes=True, - strict=True, - ), + full_state_dict=self.fsdp_state_dict_type != 'sharded', + cpu_offload=True, + ignore_frozen_params=True, + keep_submodule_prefixes=True, + strict=True, + ), ) optim_state_dict = {type(optimizer).__qualname__: optim_state_dict} else: @@ -1241,8 +1238,8 @@ def load_model_and_optimizer_state( torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') # Load model and optimizer state - if version.parse(torch.__version__) > version.parse("2.1.2"): - from torch.distributed.checkpoint.state_dict import set_state_dict, StateDictOptions + if version.parse(torch.__version__) > version.parse('2.1.2'): + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict optimizer = ensure_tuple(self.optimizers)[0] model_state_dict = state_dict.get('model', {}) optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index fc18cd7033..c19827cd7c 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -70,6 +70,7 @@ def patch_pytorch(): # Better overlap communication and computation from torch.distributed.fsdp import _runtime_utils + from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2 _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2 diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 
4257e6fe37..bd4729560e 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -755,40 +755,31 @@ def _sharded_pre_load_state_dict_hook( _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + if version.parse(torch.__version__) >= version.parse('2.1.2') and version.parse( torch.__version__) < version.parse('2.2.1'): - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - _init_ignored_module_states, - _init_device_handle, - _annotate_modules_for_dynamo, - _auto_wrap, - _init_core_state, - _init_runtime_state, - _init_prefetching_state, - _init_buffer_state, - #_init_extension, - _init_param_handle_from_module, - _check_orig_params_flattened, - _register_flat_param, - _init_state_dict_state, - _register_all_state_dict_hooks, - ) - from torch.distributed.fsdp._init_utils import ( - HYBRID_SHARDING_STRATEGIES, - ProcessGroupType, - _init_intra_and_inter_node_groups, - _is_valid_hybrid_shard_pg_type, - _get_default_comm_hook_state, - ) - from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy - from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard - from torch.distributed.fsdp._common_utils import _FSDPState + import copy + + from torch.distributed._tensor import DeviceMesh, DTensor, Replicate + from torch.distributed._tensor import Shard as DShard from torch.distributed.algorithms._comm_hooks import default_hooks + from torch.distributed.device_mesh import _mesh_resources from torch.distributed.distributed_c10d import _get_default_group + from torch.distributed.fsdp._common_utils import _FSDPState + from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType, + _get_default_comm_hook_state, _init_intra_and_inter_node_groups, + _is_valid_hybrid_shard_pg_type) + from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, # _init_extension, + _auto_wrap, _check_orig_params_flattened, + _init_buffer_state, _init_core_state, + _init_device_handle, _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, _init_runtime_state, + _init_state_dict_state, + _register_all_state_dict_hooks, + _register_flat_param) + from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions - from torch.distributed.device_mesh import _mesh_resources - import copy - def all_gather_dtensor_t2p2p0( self, @@ -801,15 +792,14 @@ def all_gather_dtensor_t2p2p0( placements = list(copy.deepcopy(tensor.placements)) # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement] # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] - for i in range(0, len(placements)-1): - placements[i] = Replicate() + for i in range(0, len(placements) - 1): + placements[i] = Replicate() tensor = tensor.redistribute( - device_mesh=tensor.device_mesh, - placements=placements, + device_mesh=tensor.device_mesh, + placements=placements, ) return tensor.to_local() - def chunk_dtensor_t2p2p0( self, tensor: torch.Tensor, @@ -823,7 +813,7 @@ def chunk_dtensor_t2p2p0( """ parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh is None: - raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.") + raise RuntimeError('No parent device_mesh is found for FSDP device_mesh.') # if parent_mesh.ndim != 2: # raise RuntimeError( # f"Found parent device_mesh of 
ndim={parent_mesh.ndim},", @@ -845,9 +835,7 @@ def chunk_dtensor_t2p2p0( shard_placements = [Replicate() for _ in range(parent_mesh.ndim)] shard_placements[0] = DShard(0) # type: ignore[call-overload] - return DTensor.from_local( - tensor, parent_mesh, replicate_placements - ).redistribute( + return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute( device_mesh=parent_mesh, placements=shard_placements, ) @@ -867,18 +855,15 @@ def chunk_dtensor_t2p2p0( shard_placements = [DShard(0) for _ in range(parent_mesh.ndim)] # type: ignore[misc] shard_placements[-1] = tp_placement # type: ignore[call-overload] - elif parent_mesh.ndim == 3: replicate_placements = [Replicate(), Replicate(), tp_placement] shard_placements = [Replicate(), DShard(0), tp_placement] # type: ignore[misc] - return DTensor.from_local( - tensor, parent_mesh, replicate_placements - ).redistribute( + return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute( device_mesh=parent_mesh, placements=shard_placements, ) - + DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0 DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0 @@ -902,7 +887,6 @@ def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: # ) return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 - def _init_process_group_state_for_hybrid_shard_t2p2p0( state: _FSDPState, process_group: ProcessGroupType, @@ -916,15 +900,12 @@ def _init_process_group_state_for_hybrid_shard_t2p2p0( state._inter_node_pg = device_mesh.get_group(mesh_dim=0) state.process_group = device_mesh.get_group(mesh_dim=1) else: - raise ValueError( - "Expected device_mesh to have ndim=2 " - f"but got {len(device_mesh.get_group())}" - ) + raise ValueError('Expected device_mesh to have ndim=2 ' + f'but got {len(device_mesh.get_group())}') elif process_group is None: default_group = _get_default_group() - intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( - default_group, state._device_handle.device_count() - ) + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(default_group, + state._device_handle.device_count()) # we shard across intra-node state.process_group = intra_node_group # save _inter_node_pg to allreduce across. @@ -936,17 +917,12 @@ def _init_process_group_state_for_hybrid_shard_t2p2p0( # as documented. state.process_group, state._inter_node_pg = process_group else: - raise ValueError( - "Expected process_group to be passed in as either None or " - f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" - ) + raise ValueError('Expected process_group to be passed in as either None or ' + f'Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}') # Create state for allreduce - state._inter_node_state = _get_default_comm_hook_state( - process_group=state._inter_node_pg, - ) + state._inter_node_state = _get_default_comm_hook_state(process_group=state._inter_node_pg,) return state - def _init_process_group_state_t2p2p0( state: _FSDPState, process_group: ProcessGroupType, @@ -955,10 +931,8 @@ def _init_process_group_state_t2p2p0( device_mesh: Optional[DeviceMesh] = None, ) -> _FSDPState: if process_group is not None and device_mesh is not None: - raise ValueError( - "Cannot pass both process_group and device_mesh at the " - "same time. Please just pass only one of them." - ) + raise ValueError('Cannot pass both process_group and device_mesh at the ' + 'same time. 
Please just pass only one of them.') is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES if is_hybrid_strategy: if process_group is None and policy is None and device_mesh is None: @@ -966,21 +940,17 @@ def _init_process_group_state_t2p2p0( # passed in, there is no way to ensure all wrapped FSDP instances use the same # process groups. raise ValueError( - f"Manual wrapping with {sharding_strategy}", - "requires explicit specification of process group or device_mesh.", + f'Manual wrapping with {sharding_strategy}', + 'requires explicit specification of process group or device_mesh.', ) else: - state = _init_process_group_state_for_hybrid_shard_t2p2p0( - state, process_group, device_mesh - ) + state = _init_process_group_state_for_hybrid_shard_t2p2p0(state, process_group, device_mesh) else: if device_mesh: state._device_mesh = device_mesh state.process_group = device_mesh.get_group(mesh_dim=0) else: - state.process_group = ( - process_group if process_group is not None else _get_default_group() - ) + state.process_group = (process_group if process_group is not None else _get_default_group()) state.rank = state.process_group.rank() state.world_size = state.process_group.size() @@ -988,13 +958,8 @@ def _init_process_group_state_t2p2p0( if is_hybrid_strategy: data_parallel_world_size *= state._inter_node_pg.size() state._gradient_predivide_factor = ( - default_hooks.DefaultState._get_gradient_predivide_factor( - data_parallel_world_size - ) - ) - state._gradient_postdivide_factor = ( - data_parallel_world_size / state._gradient_predivide_factor - ) + default_hooks.DefaultState._get_gradient_predivide_factor(data_parallel_world_size)) + state._gradient_postdivide_factor = (data_parallel_world_size / state._gradient_predivide_factor) return state def init_fn_t2p2p0( @@ -1003,9 +968,7 @@ def init_fn_t2p2p0( process_group: ProcessGroupType = None, sharding_strategy: Optional[ShardingStrategy] = None, cpu_offload: Optional[CPUOffload] = None, - auto_wrap_policy: Optional[ - Union[Callable, ModuleWrapPolicy, CustomPolicy] - ] = None, + auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy, CustomPolicy]] = None, backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, mixed_precision: Optional[MixedPrecision] = None, ignored_modules: Optional[Iterable[torch.nn.Module]] = None, @@ -1015,12 +978,10 @@ def init_fn_t2p2p0( forward_prefetch: bool = False, limit_all_gathers: bool = True, use_orig_params: bool = False, - ignored_states: Union[ - Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] - ] = None, + ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None, device_mesh: Optional[DeviceMesh] = None, ): - torch._C._log_api_usage_once("torch.distributed.fsdp") + torch._C._log_api_usage_once('torch.distributed.fsdp') super(FullyShardedDataParallel, self).__init__() _init_ignored_module_states(self, module, ignored_modules, ignored_states) _init_device_handle(self, module, self._ignored_params, device_id) @@ -1043,25 +1004,25 @@ def init_fn_t2p2p0( ) if auto_wrap_policy is not None: root_kwargs = { - "process_group": process_group, - "sharding_strategy": sharding_strategy, - "cpu_offload": cpu_offload, - "backward_prefetch": backward_prefetch, - "mixed_precision": mixed_precision, - "param_init_fn": param_init_fn, - "device_id": device_id, - "sync_module_states": sync_module_states, - "forward_prefetch": forward_prefetch, - "limit_all_gathers": limit_all_gathers, - "use_orig_params": 
use_orig_params, - "ignored_states": self._ignored_params, - "device_mesh": device_mesh, + 'process_group': process_group, + 'sharding_strategy': sharding_strategy, + 'cpu_offload': cpu_offload, + 'backward_prefetch': backward_prefetch, + 'mixed_precision': mixed_precision, + 'param_init_fn': param_init_fn, + 'device_id': device_id, + 'sync_module_states': sync_module_states, + 'forward_prefetch': forward_prefetch, + 'limit_all_gathers': limit_all_gathers, + 'use_orig_params': use_orig_params, + 'ignored_states': self._ignored_params, + 'device_mesh': device_mesh, } if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: # Share root process groups with children to maintain # the invariant that all FSDP modules will have the same # process groups. - root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) + root_kwargs['process_group'] = (self.process_group, self._inter_node_pg) _auto_wrap( module, diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 5eda9f46f9..a28b8f9fef 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -497,7 +497,6 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): else: storage_reader = FileSystemReaderWithValidation(source_path) - # We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph. with torch.no_grad(): # 1. Load model and metadata first @@ -505,8 +504,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): state_dict = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() - # For older versions of torch, we load optimizier separately. - if version.parse(torch.__version__) <= version.parse("2.1.2"): + # For older versions of torch, we load optimizer separately. + if version.parse(torch.__version__) <= version.parse('2.1.2'): cur_state_dict.pop('optimizers') state_dict = {'state': cur_state_dict} @@ -529,11 +528,11 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # 2. Optionally load optimizer # if we are using later than 2.1.0 then optimizer will already be loaded - if version.parse(torch.__version__) <= version.parse("2.1.2") and not load_weights_only: + if version.parse(torch.__version__) <= version.parse('2.1.2') and not load_weights_only: optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader) - state.load_optim_state(optim_state) + state._legacy_load_optim_state(optim_state) # 3. Optionally load RNG rng_state_dicts = reproducibility.get_rng_state() @@ -913,7 +912,7 @@ def save_checkpoint( # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. 
- if using_torch_2() and version.parse(torch.__version__) <= version.parse("2.1.2"): + if using_torch_2() and version.parse(torch.__version__) <= version.parse('2.1.2'): if not weights_only: state_dict['optimizers'] = state_dict['state'].pop('optimizers') @@ -965,17 +964,19 @@ def save_checkpoint( if device_mesh is not None and device_mesh.ndim == 2: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) process_group = device_mesh.get_group(1) - log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}') + log.debug( + f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}' + ) else: expect_file = True process_group = None if expect_file: dist_cp.save( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - process_group=process_group, + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, ) log.warning('finished pytorch save state dict') From e1cd8cf22054413219220d886cf4387679ae62dc Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 17:18:19 -0500 Subject: [PATCH 24/66] add more tests --- .github/workflows/pr-cpu.yaml | 5 +++++ .github/workflows/pr-gpu.yaml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 989b4ded43..ac3feaff25 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -27,6 +27,11 @@ jobs: markers: 'not daily and not remote and not gpu and not vision and not doctest' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' + - name: 'cpu-3.10-2.2' + container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + markers: 'not daily and not remote and not gpu and not vision and not doctest' + pytest_command: 'coverage run -m pytest' + composer_package_name: 'mosaicml' - name: 'cpu-vision' container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04 markers: 'not daily and not remote and not gpu and vision and not doctest' diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2c818b7229..2483619b16 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -17,6 +17,11 @@ jobs: markers: 'not daily and not remote and gpu and (doctest or not doctest)' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' + - name: 'gpu-3.10-2.2' + container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + markers: 'not daily and not remote and gpu and (doctest or not doctest)' + pytest_command: 'coverage run -m pytest' + composer_package_name: 'mosaicml' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: From de6a568bd70cb60ba35693d5c281cf9f7fd1d0dc Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 17:20:56 -0500 Subject: [PATCH 25/66] switch to PRs --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2483619b16..2e68a9de33 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -1,6 +1,6 @@ name: PR GPU tests on: - pull_request_target: + pull_request: workflow_dispatch: # Cancel old runs when a new commit is pushed to the same branch if not on main or dev concurrency: From 0b721f3df8ce8eb2a309f60379cc355e3095788a Mon Sep 17 
00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 20:21:48 -0500 Subject: [PATCH 26/66] ignore warning --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 342c9b3d7e..0741804fec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,6 +151,9 @@ filterwarnings = [ '''ignore:torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead:UserWarning''', # Ignore torch sharded tensor deprecated warnings '''ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning''', + # Ignore torch pytree deprecated warnings + '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*.UserWarning''' + ] # Coverage From 893bcb709974f363d29b7e041d7faef391103833 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 20:31:12 -0500 Subject: [PATCH 27/66] fix lint --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0741804fec..7dd1f1842e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,7 +152,7 @@ filterwarnings = [ # Ignore torch sharded tensor deprecated warnings '''ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning''', # Ignore torch pytree deprecated warnings - '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*.UserWarning''' + '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''' ] From d66da7bc61db00764fa9cbbf8d8ea39d54d06ad5 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 20:58:11 -0500 Subject: [PATCH 28/66] version error --- composer/trainer/mosaic_fsdp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index bd4729560e..d17954b54a 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -756,7 +756,7 @@ def _sharded_pre_load_state_dict_hook( _enter_unshard_params_ctx(module, fsdp_state, writeback=True) -if version.parse(torch.__version__) >= version.parse('2.1.2') and version.parse( +if version.parse(torch.__version__) > version.parse('2.1.2') and version.parse( torch.__version__) < version.parse('2.2.1'): import copy From 02faba71c5a63a95876ac732a2b8dc79e6114531 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 21:03:11 -0500 Subject: [PATCH 29/66] fix version --- composer/trainer/mosaic_fsdp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index d17954b54a..196281262a 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -756,7 +756,7 @@ def _sharded_pre_load_state_dict_hook( _enter_unshard_params_ctx(module, fsdp_state, writeback=True) -if version.parse(torch.__version__) > version.parse('2.1.2') and version.parse( +if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse( torch.__version__) < version.parse('2.2.1'): import copy From 7f8e08c7ab3e060b4addb32054c2c4a9d885ad1f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:06:26 -0500 Subject: [PATCH 30/66] fix state dict --- composer/core/state.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/composer/core/state.py b/composer/core/state.py index 94d95c78e0..75a85bd5c6 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -933,13 +933,17 @@ def state_dict(self) -> 
Dict[str, Any]: Dict[str, Any]: The state dict. """ state_dict = {} - state_dict['model'], state_dict['optimizers'] = self.get_model_and_optimizer_state_dict() + model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict() for attribute_name in self.serialized_attributes: attribute_value = getattr(self, attribute_name) if attribute_name in ['model', 'optimizers']: continue if attribute_name == 'dataset_state': serialized_value = self._dataset_state_dict() + elif attribute_name == 'model': + serialized_value = model_state_dict + elif attribute_name == 'optimizers': + serialized_value = optim_state_dict elif attribute_name == 'algorithms': # Store as list to preserve order in which algorithms were applied serialized_value = [(type(obj).__qualname__, obj.state_dict()) for obj in ensure_tuple(attribute_value)] From 3c557bf1aeb0d16671319c5451818501af76d556 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:14:12 -0500 Subject: [PATCH 31/66] update versions --- composer/core/state.py | 6 +++--- composer/utils/checkpoint.py | 33 ++++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 75a85bd5c6..6dd5124312 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -898,12 +898,12 @@ def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: return optim_state_dict def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]: - if version.parse(torch.__version__) > version.parse('2.1.2'): + if version.parse(torch.__version__) > version.parse('2.1.3'): from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: raise NotImplementedError( textwrap.dedent(f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' - f'torch version {{version.parse(torch.__version__)}} > 2.1.2. Please set ' + f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' 'fsdp_state_dict_type to None, "full", or "sharded".')) optimizer = ensure_tuple(self.optimizers)[0] @@ -1242,7 +1242,7 @@ def load_model_and_optimizer_state( torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') # Load model and optimizer state - if version.parse(torch.__version__) > version.parse('2.1.2'): + if version.parse(torch.__version__) > version.parse('2.1.3'): from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict optimizer = ensure_tuple(self.optimizers)[0] model_state_dict = state_dict.get('model', {}) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index a28b8f9fef..cf2ad01d7c 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -505,7 +505,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): else: cur_state_dict = state.state_dict() # For older versions of torch, we load optimizer separately. 
- if version.parse(torch.__version__) <= version.parse('2.1.2'): + if version.parse(torch.__version__) < version.parse('2.1.3'): cur_state_dict.pop('optimizers') state_dict = {'state': cur_state_dict} @@ -516,7 +516,10 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # Call function to modify state_dict ignore_keys(state_dict) - dist_cp.load(state_dict, storage_reader) + if version.parse(torch.__version__) > version.parse('2.1.3'): + dist_cp.load(state_dict, storage_reader) + else: + dist_cp.load_state_dict(state_dict, storage_reader) state.load_state_dict( state_dict['state'], @@ -528,7 +531,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # 2. Optionally load optimizer # if we are using later than 2.1.0 then optimizer will already be loaded - if version.parse(torch.__version__) <= version.parse('2.1.2') and not load_weights_only: + if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only: optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader) @@ -907,12 +910,12 @@ def save_checkpoint( # Sharded checkpoints get their own little folder. if state.fsdp_sharded_state_dict_enabled: - # To load optimizer states with 2.0 <= torch <= 2.1.2 , the optimizer state must be at the top + # To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. - if using_torch_2() and version.parse(torch.__version__) <= version.parse('2.1.2'): + if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'): if not weights_only: state_dict['optimizers'] = state_dict['state'].pop('optimizers') @@ -972,12 +975,20 @@ def save_checkpoint( process_group = None if expect_file: - dist_cp.save( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - process_group=process_group, - ) + if version.parse(torch.__version__) > version.parse('2.1.3'): + dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) + else: + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) log.warning('finished pytorch save state dict') # Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0 From 92e38cacd0715ed6b12ee63e06a0a3ea150f4a2f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:16:51 -0500 Subject: [PATCH 32/66] lint --- composer/trainer/mosaic_fsdp_utils.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 196281262a..1cf969952c 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -769,15 +769,20 @@ def _sharded_pre_load_state_dict_hook( from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType, _get_default_comm_hook_state, _init_intra_and_inter_node_groups, _is_valid_hybrid_shard_pg_type) - from torch.distributed.fsdp.fully_sharded_data_parallel import 
(_annotate_modules_for_dynamo, # _init_extension, - _auto_wrap, _check_orig_params_flattened, - _init_buffer_state, _init_core_state, - _init_device_handle, _init_ignored_module_states, - _init_param_handle_from_module, - _init_prefetching_state, _init_runtime_state, - _init_state_dict_state, - _register_all_state_dict_hooks, - _register_flat_param) + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + _annotate_modules_for_dynamo, + _auto_wrap, + _check_orig_params_flattened, + _init_buffer_state, + _init_core_state, + _init_device_handle, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, + _init_runtime_state, + _init_state_dict_state, + _register_all_state_dict_hooks, + _register_flat_param) from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions @@ -806,8 +811,7 @@ def chunk_dtensor_t2p2p0( rank: int, device_mesh: DeviceMesh, ) -> DTensor: - """ - Shard a tensor to chunks along the first dimension. + """Shard a tensor to chunks along the first dimension. The local rank will gets its corresponding chunk as the local tensor to create a DTensor. """ @@ -981,6 +985,7 @@ def init_fn_t2p2p0( ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None, device_mesh: Optional[DeviceMesh] = None, ): + """Docstring for lint.""" torch._C._log_api_usage_once('torch.distributed.fsdp') super(FullyShardedDataParallel, self).__init__() _init_ignored_module_states(self, module, ignored_modules, ignored_states) From 3ee188fbe343e980670bd907ed286e55f40c546a Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:18:12 -0500 Subject: [PATCH 33/66] lint --- composer/trainer/mosaic_fsdp_utils.py | 23 +++++++++-------------- composer/utils/checkpoint.py | 4 ++-- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 1cf969952c..0f75ff25fe 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -769,20 +769,15 @@ def _sharded_pre_load_state_dict_hook( from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType, _get_default_comm_hook_state, _init_intra_and_inter_node_groups, _is_valid_hybrid_shard_pg_type) - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - _annotate_modules_for_dynamo, - _auto_wrap, - _check_orig_params_flattened, - _init_buffer_state, - _init_core_state, - _init_device_handle, - _init_ignored_module_states, - _init_param_handle_from_module, - _init_prefetching_state, - _init_runtime_state, - _init_state_dict_state, - _register_all_state_dict_hooks, - _register_flat_param) + from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap, + _check_orig_params_flattened, _init_buffer_state, + _init_core_state, _init_device_handle, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, _init_runtime_state, + _init_state_dict_state, + _register_all_state_dict_hooks, + _register_flat_param) from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index cf2ad01d7c..e0620bc8bb 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -517,7 +517,7 @@ 
def read_data(self, plan: LoadPlan, planner: LoadPlanner): ignore_keys(state_dict) if version.parse(torch.__version__) > version.parse('2.1.3'): - dist_cp.load(state_dict, storage_reader) + dist_cp.load(state_dict, storage_reader) # type: ignore else: dist_cp.load_state_dict(state_dict, storage_reader) @@ -976,7 +976,7 @@ def save_checkpoint( if expect_file: if version.parse(torch.__version__) > version.parse('2.1.3'): - dist_cp.save( + dist_cp.save( # type: ignore state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(dirname), planner=save_planner, From 91cd547f8fed889c6a51c2622557c7dd361b7acd Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:22:15 -0500 Subject: [PATCH 34/66] disable lint for mosaic fsdp utils --- composer/trainer/mosaic_fsdp_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 0f75ff25fe..9b7bd53219 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -4,6 +4,9 @@ # Released under BSD 3-Clause License, # Copyright (c) Facebook, Inc. and its affiliates. +# yapf: disable +# isort: skip_file + """Utilities for monkey patching FSDP.""" import functools From b902b1b88c05334dcaca9ba4b38cbc7342c80d4c Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Sun, 7 Jan 2024 22:50:14 -0500 Subject: [PATCH 35/66] remove bad line --- composer/core/state.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 6dd5124312..338444ad85 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -936,8 +936,6 @@ def state_dict(self) -> Dict[str, Any]: model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict() for attribute_name in self.serialized_attributes: attribute_value = getattr(self, attribute_name) - if attribute_name in ['model', 'optimizers']: - continue if attribute_name == 'dataset_state': serialized_value = self._dataset_state_dict() elif attribute_name == 'model': From 74f459de9fcb1f87f46365847597ba40b3b45f73 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 12:10:40 -0500 Subject: [PATCH 36/66] move around for legacy --- composer/core/state.py | 9 ++++++--- composer/trainer/dist_strategy.py | 1 - 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 338444ad85..7fdf8ca003 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -1240,7 +1240,8 @@ def load_model_and_optimizer_state( torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') # Load model and optimizer state - if version.parse(torch.__version__) > version.parse('2.1.3'): + use_state_dict_fns = version.parse(torch.__version__) > version.parse('2.1.3') + if use_state_dict_fns: from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict optimizer = ensure_tuple(self.optimizers)[0] model_state_dict = state_dict.get('model', {}) @@ -1256,8 +1257,6 @@ def load_model_and_optimizer_state( ) else: self._legacy_load_model_state(state_dict, strict) - if not load_model_only: - self._legacy_load_optim_state(state_dict) # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading if self.load_fsdp_monolith_rank0_only: @@ -1268,6 +1267,10 @@ def load_model_and_optimizer_state( self.auto_microbatching) log.debug('Finished wrapping model with FSDP.') + # Legacy optimizer state load must happen after FSDP monolith + 
if not use_state_dict_fns and not load_model_only: + self._legacy_load_optim_state(state_dict) + def load_state_dict( self, state: Dict[str, Any], diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index c8f1a7b8d0..10d95b43a8 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -273,7 +273,6 @@ def sync_hook(*args): # `nn.Module.named_parameters`. # Setting it to `True` is mandatory when using `torch.compile()`. kwargs['use_orig_params'] = fsdp_config['use_orig_params'] - print(version.parse(torch.__version__)) if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'): if 'device_mesh' in fsdp_config: from torch.distributed._tensor import init_device_mesh From 0e0cefc0e2493f37fd805780905cd873cfa54ecb Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 13:47:36 -0500 Subject: [PATCH 37/66] device mesh --- composer/core/state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/composer/core/state.py b/composer/core/state.py index 7fdf8ca003..92aaf3685e 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -795,6 +795,8 @@ def fsdp_sharded_state_dict_enabled(self): @property def fsdp_device_mesh(self): if self.fsdp_enabled: + if not hasattr(self.model, 'model'): + return None return self.model.model._device_mesh else: return None From 37941f5aca6a18946fb125559a24437093ace6fb Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 14:19:42 -0500 Subject: [PATCH 38/66] ignore warning --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7dd1f1842e..d426dbee84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,8 +152,9 @@ filterwarnings = [ # Ignore torch sharded tensor deprecated warnings '''ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning''', # Ignore torch pytree deprecated warnings - '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''' - + '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''', + # Ignore autograd kernel warning inside DeepSpeed + '''ignore:c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s).*:UserWarning''' ] # Coverage From db9f54df9bd6499d00ade1862f2618ccc044ea13 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 14:42:46 -0500 Subject: [PATCH 39/66] fix import --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d426dbee84..b00b8b2651 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,7 +154,7 @@ filterwarnings = [ # Ignore torch pytree deprecated warnings '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''', # Ignore autograd kernel warning inside DeepSpeed - '''ignore:c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s).*:UserWarning''' + '''ignore:.*an autograd kernel was not registered to the Autograd key(s).*:UserWarning''' ] # Coverage From 52e7b32f76fdd04a9220fc301775cab387448e56 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 15:30:55 -0500 Subject: [PATCH 40/66] always init --- composer/trainer/trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c8c6d325e0..0655be3004 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -960,10 +960,7 @@ def __init__( assert not 
isinstance(device_train_microbatch_size, str) # Distributed - if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1: - # Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1 - # And torch.distributed is always required for multi-rank training - dist.initialize_dist(device, dist_timeout) + dist.initialize_dist(device, dist_timeout) # Reproducibility rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device) From 81dfad0c31882ee78ae65c106b391c87f4bce1e3 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 15:34:39 -0500 Subject: [PATCH 41/66] fix error --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b00b8b2651..f4155e23ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,7 +154,7 @@ filterwarnings = [ # Ignore torch pytree deprecated warnings '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''', # Ignore autograd kernel warning inside DeepSpeed - '''ignore:.*an autograd kernel was not registered to the Autograd key(s).*:UserWarning''' + '''ignore:.*an autograd kernel was not registered to the Autograd key.*:UserWarning''' ] # Coverage From 956ce5570e04aa38d71b23b1178ce6affbba8445 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 16:26:50 -0500 Subject: [PATCH 42/66] fix load planner --- composer/utils/checkpoint.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index e0620bc8bb..614f32ef15 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -517,9 +517,17 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): ignore_keys(state_dict) if version.parse(torch.__version__) > version.parse('2.1.3'): - dist_cp.load(state_dict, storage_reader) # type: ignore + dist_cp.load( # type: ignore + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) else: - dist_cp.load_state_dict(state_dict, storage_reader) + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) state.load_state_dict( state_dict['state'], From b40b8a5772bd0906b618111c2510c6c18f6ad35b Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 17:14:17 -0500 Subject: [PATCH 43/66] remove --- tests/callbacks/test_memory_monitor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/callbacks/test_memory_monitor.py b/tests/callbacks/test_memory_monitor.py index f40a04eeb3..21fec22aad 100644 --- a/tests/callbacks/test_memory_monitor.py +++ b/tests/callbacks/test_memory_monitor.py @@ -10,9 +10,7 @@ from tests.common import RandomClassificationDataset, SimpleModel, device -@device('cpu', 'gpu') def test_memory_monitor_warnings_on_cpu_models(device: str): - # Error if the user sets device=cpu even when cuda is available del device # unused. 
always using cpu with pytest.warns(UserWarning, match='The memory monitor only works on CUDA devices'): Trainer( From 1ec377470bed0902176e3b806996a7ed46b6a455 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 17:32:13 -0500 Subject: [PATCH 44/66] fix lint --- tests/callbacks/test_memory_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_memory_monitor.py b/tests/callbacks/test_memory_monitor.py index 21fec22aad..fb83c6f4fb 100644 --- a/tests/callbacks/test_memory_monitor.py +++ b/tests/callbacks/test_memory_monitor.py @@ -7,10 +7,10 @@ from composer.callbacks import MemoryMonitor from composer.loggers import InMemoryLogger from composer.trainer import Trainer -from tests.common import RandomClassificationDataset, SimpleModel, device +from tests.common import RandomClassificationDataset, SimpleModel -def test_memory_monitor_warnings_on_cpu_models(device: str): +def test_memory_monitor_warnings_on_cpu_models(): del device # unused. always using cpu with pytest.warns(UserWarning, match='The memory monitor only works on CUDA devices'): Trainer( From f59bdaab9bd094761964166519480884c545c6ac Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 18:01:23 -0500 Subject: [PATCH 45/66] lint --- tests/callbacks/test_memory_monitor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/callbacks/test_memory_monitor.py b/tests/callbacks/test_memory_monitor.py index fb83c6f4fb..f2badc638c 100644 --- a/tests/callbacks/test_memory_monitor.py +++ b/tests/callbacks/test_memory_monitor.py @@ -11,7 +11,6 @@ def test_memory_monitor_warnings_on_cpu_models(): - del device # unused. always using cpu with pytest.warns(UserWarning, match='The memory monitor only works on CUDA devices'): Trainer( model=SimpleModel(), From ffaa7f48c4a29089ff9cb9d36f3aeae6fa966320 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 18:02:19 -0500 Subject: [PATCH 46/66] delay state dict --- composer/core/state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/composer/core/state.py b/composer/core/state.py index 92aaf3685e..45526c84ca 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -935,7 +935,9 @@ def state_dict(self) -> Dict[str, Any]: Dict[str, Any]: The state dict. 
""" state_dict = {} - model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict() + model_state_dict, optim_state_dict = None, None + if 'model' in self.serialized_attributes or 'optimizers' in self.serialized_attributes: + model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict() for attribute_name in self.serialized_attributes: attribute_value = getattr(self, attribute_name) if attribute_name == 'dataset_state': From 51f6e3bc3e11c6fc05d6f6fdd01aeb48fa69c7e3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 01:50:36 +0000 Subject: [PATCH 47/66] test checkpoint --- tests/trainer/test_checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 8e2e83d30c..abcd622d1d 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -689,7 +689,11 @@ def test_strict_errors(self, missing_key: bool, unexpected_key: bool): last_checkpoint = os.path.join('first', 'ep2.pt') if missing_key or unexpected_key: - error_context = pytest.raises(RuntimeError, match='Failed to load checkpoint due to') + message = r'Error\(s\) in loading state_dict' + if version.parse(torch.__version__) < version.parse('2.1.3'): + # Composer implements strict for older torch versions + message = 'Failed to load checkpoint due to' + error_context = pytest.raises(RuntimeError, match=message) else: error_context = contextlib.nullcontext() From 13b86ce7bdc79530f699dc36f749b20083c47805 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 02:11:20 +0000 Subject: [PATCH 48/66] checkpoint --- tests/trainer/test_checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index abcd622d1d..74f719a1da 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -976,8 +976,10 @@ def test_autoload_algorithm_old_checkpoint(self): old_init, old_repr = NoOpModel.__init__, NoOpModel.__repr__ NoOpModel.__init__ = lambda self, x: None # type: ignore NoOpModel.__repr__ = lambda self: 'NoOpModel(3)' - with pytest.warns(UserWarning, match='required_on_load algorithm.*'), pytest.raises( - ValueError, match='loaded state dict contains a parameter group.*'): + error_context = pytest.raises(KeyError, match='module.0.weight') + if version.parse(torch.__version__) < version.parse('2.1.3'): + error_context = pytest.raises(ValueError, match='loaded state dict contains a parameter group.*') + with pytest.warns(UserWarning, match='required_on_load algorithm.*'), error_context: trainer_3 = self.get_trainer(load_path=os.path.join('first', 'ep1.pt'),) trainer_3.fit(duration='1ba') # Restore algorithm From 1bb1e4a836ece220eaa01aed77f724c0c2735099 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 03:05:41 +0000 Subject: [PATCH 49/66] fix cpu tests --- tests/algorithms/test_required_on_load.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index 3844a57084..e9eaae9484 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -5,6 +5,7 @@ import copy import os import pathlib +from packaging import version from typing import Type import pytest @@ -163,14 +164,18 @@ def test_autoload(algo_name: str, load_weights_only: bool, already_added: bool, context = 
pytest.warns(UserWarning, match='Automatically adding required_on_load algorithm*') # Excluding some algorithms leads to errors when loading elif exclude: - if algo_name in ['Factorize', 'SqueezeExcite']: - context = pytest.raises( - ValueError, - match= - "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", - ) - elif algo_name == 'Alibi': - context = pytest.raises(RuntimeError) + if version.parse(torch.__version__) > version.parse('2.1.3'): + if algo_name in ['BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite']: + context = pytest.raises(KeyError) # Optimizer loading is strict + else: + if algo_name in ['Factorize', 'SqueezeExcite']: + context = pytest.raises( + ValueError, + match= + "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", + ) + elif algo_name == 'Alibi': + context = pytest.raises(RuntimeError) with context: trainer2 = Trainer( From 989982675657ccdf2413f984ef21f8d8c77b950f Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 03:25:53 +0000 Subject: [PATCH 50/66] fix rotate tests --- tests/trainer/test_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 74f719a1da..7c684c406d 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1311,6 +1311,7 @@ def test_rotate_checkpoints( dataset=train_dataset, sampler=dist.get_sampler(train_dataset), ), + precision='fp32', save_folder=str(save_folder), save_filename='checkpoint_{rank}_{batch}.pt', save_interval='1ba', From 7efbd32b8f17cd98a885b49f5f6222c34d82cae6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 03:27:29 +0000 Subject: [PATCH 51/66] fix precision --- tests/algorithms/test_algorithm_resumption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_algorithm_resumption.py b/tests/algorithms/test_algorithm_resumption.py index d1fb4e2c40..f61d288622 100644 --- a/tests/algorithms/test_algorithm_resumption.py +++ b/tests/algorithms/test_algorithm_resumption.py @@ -57,7 +57,7 @@ def test_algorithm_resumption( 'save_filename': 'ep{epoch}-rank{rank}', 'save_interval': '1ep', 'train_subset_num_batches': 2, - 'precision': 'amp_fp16', + 'precision': 'amp_bf16', } train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True) # train model once, saving checkpoints every epoch From 5d185d16fa781ad5067e5a45ed696c72c85b72d7 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 8 Jan 2024 22:36:52 -0500 Subject: [PATCH 52/66] lint --- tests/algorithms/test_required_on_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index e9eaae9484..961619a023 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -5,11 +5,11 @@ import copy import os import pathlib -from packaging import version from typing import Type import pytest import torch +from packaging import version from composer import Trainer, algorithms from composer.callbacks import CheckpointSaver From 575fe3511f41d7cca3a5989cedf4f7373152940f Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 05:00:22 +0000 Subject: [PATCH 53/66] fix alibi --- .../alibi/attention_surgery_functions/utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git 
a/composer/algorithms/alibi/attention_surgery_functions/utils.py b/composer/algorithms/alibi/attention_surgery_functions/utils.py index d510e1e371..6988806356 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/utils.py +++ b/composer/algorithms/alibi/attention_surgery_functions/utils.py @@ -125,13 +125,8 @@ def zero_and_freeze_expand_position_embeddings( if not isinstance(old_weight, torch.nn.Parameter): raise TypeError(f'Module {module._get_name()}, position embedding {position_embedding_attribute}, ' f"'weight' attribute must be of type torch.nn.Module") - new_weight = torch.nn.Parameter( - torch.zeros((max_sequence_length, old_weight.shape[1]), - dtype=old_weight.dtype, - layout=old_weight.layout, - device=old_weight.device)) - new_weight.requires_grad = False - setattr(pos_embedding_module, 'weight', new_weight) + old_weight.requires_grad = False + old_weight.zero_() log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen') From f3ce8f85d823c01bd5b03f21fdf1bb1febdc91c6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 05:05:18 +0000 Subject: [PATCH 54/66] cleanup --- composer/core/state.py | 5 +---- composer/utils/checkpoint.py | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 45526c84ca..4f1de29585 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -916,9 +916,6 @@ def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str options=StateDictOptions( full_state_dict=self.fsdp_state_dict_type != 'sharded', cpu_offload=True, - ignore_frozen_params=True, - keep_submodule_prefixes=True, - strict=True, ), ) optim_state_dict = {type(optimizer).__qualname__: optim_state_dict} @@ -1257,7 +1254,7 @@ def load_model_and_optimizer_state( optimizers=optimizer, model_state_dict=model_state_dict, optim_state_dict=optim_state_dict, - options=StateDictOptions(strict=strict), + options=StateDictOptions(strict=strict, cpu_offload=True), ) else: self._legacy_load_model_state(state_dict, strict) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 614f32ef15..d46c69ecad 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -969,8 +969,7 @@ def save_checkpoint( import torch.distributed.checkpoint as dist_cp from torch.distributed import get_process_group_ranks - log.debug('Saving sharded checkpoints to %s...', save_filename) - log.warning('starting pytorch save state dict') + log.debug(f'Saving sharded checkpoints to {save_filename}...') device_mesh = state.fsdp_device_mesh if device_mesh is not None and device_mesh.ndim == 2: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) @@ -997,7 +996,7 @@ def save_checkpoint( planner=save_planner, process_group=process_group, ) - log.warning('finished pytorch save state dict') + log.debug('Finished pytorch save state dict') # Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0 elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled: @@ -1015,9 +1014,7 @@ def save_checkpoint( else: log.debug(f'Only rank 0 is saving a checkpoint, so rank {dist.get_global_rank()} skips checkpointing.') - log.warning('starting dist barrier') dist.barrier() # ensure all ranks saved their files - log.warning('finished dist barrier') if expect_file: assert os.path.exists(save_filename), 'Expected file to have been saved.' 
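Note on the save path touched in the preceding commits: the dispatch on the torch version exists because torch.distributed.checkpoint renamed its entry points — save_state_dict/load_state_dict are the pre-2.2 functions, while save/load are the newer entry points these patches switch to for torch > 2.1.3. A minimal sketch of that dispatch in isolation (the directory name and the omitted planner/process_group arguments are placeholders, not the values Composer actually passes; an initialized process group is assumed):

    import torch
    import torch.distributed.checkpoint as dist_cp
    from packaging import version


    def save_sharded_checkpoint(state_dict: dict, dirname: str) -> None:
        # Mirrors the version gate these patches add to composer/utils/checkpoint.py.
        storage_writer = dist_cp.FileSystemWriter(dirname)
        if version.parse(torch.__version__) > version.parse('2.1.3'):
            # torch >= 2.2 exposes dist_cp.save; planner and process_group kwargs are optional.
            dist_cp.save(state_dict=state_dict, storage_writer=storage_writer)
        else:
            # Older torch only has the legacy dist_cp.save_state_dict API.
            dist_cp.save_state_dict(state_dict=state_dict, storage_writer=storage_writer)
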
From 6e9fa0cd3cd0bb848be35b5c7a84f616de3f3370 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 00:09:24 -0500 Subject: [PATCH 55/66] cleanup --- composer/utils/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index d46c69ecad..f29d467ef9 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -970,16 +970,16 @@ def save_checkpoint( from torch.distributed import get_process_group_ranks log.debug(f'Saving sharded checkpoints to {save_filename}...') + process_group = None device_mesh = state.fsdp_device_mesh if device_mesh is not None and device_mesh.ndim == 2: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) - process_group = device_mesh.get_group(1) + process_group = device_mesh.get_group(1) # Only save on first replica log.debug( f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}' ) else: expect_file = True - process_group = None if expect_file: if version.parse(torch.__version__) > version.parse('2.1.3'): From 743ef106e5b71d37e161e61778584068dfde8127 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 10:55:10 -0500 Subject: [PATCH 56/66] remove force sync --- composer/trainer/dist_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 10d95b43a8..0ff803657f 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -232,7 +232,7 @@ def prepare_fsdp_module( set_fsdp_default(fsdp_config) # Check sync_module_states is True for mixed initialization or HSDP - if fsdp_config['sync_module_states'] == False and not fsdp_config.get('force_sync_module_states', False): + if fsdp_config['sync_module_states'] == False: rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0 all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8)) dist.all_reduce(all_ranks_meta, reduce_operation='MIN') From 4042d766b1ffa7e097a2722c7a8d3a7bdf8afa32 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 10:59:02 -0500 Subject: [PATCH 57/66] fix type --- composer/trainer/dist_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 0ff803657f..ca4d973dc9 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -278,7 +278,7 @@ def sync_hook(*args): from torch.distributed._tensor import init_device_mesh kwargs['device_mesh'] = init_device_mesh( 'cuda', - tuple(fsdp_config['device_mesh']), + tuple([int(x) for x in fsdp_config['device_mesh']]), ) # necessary variables for optimizers with multiple param groups in FSDP From 02a0b20d338f349943d4afdf657452ac88689c56 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 17:48:07 +0000 Subject: [PATCH 58/66] merge --- .../alibi/attention_surgery_functions/_bert.py | 10 ++++++---- .../alibi/attention_surgery_functions/_gpt2.py | 6 ++++-- .../alibi/attention_surgery_functions/utils.py | 9 +++++++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/composer/algorithms/alibi/attention_surgery_functions/_bert.py b/composer/algorithms/alibi/attention_surgery_functions/_bert.py index 915e940cad..71b31796bb 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/_bert.py +++ 
b/composer/algorithms/alibi/attention_surgery_functions/_bert.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import copy from types import MethodType from typing import Optional, Tuple @@ -20,13 +21,14 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq """ assert isinstance(module, (BertEmbeddings, RobertaEmbeddings)) del module_index # unused - zero_and_freeze_expand_position_embeddings(module, + new_module = copy.deepcopy(module) + zero_and_freeze_expand_position_embeddings(new_module, max_sequence_length, position_embedding_attribute='position_embeddings') - module_device = next(module.parameters()).device - module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device)) - return module + module_device = next(new_module.parameters()).device + new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device)) + return new_module @policy_registry.register(BertSelfAttention, RobertaSelfAttention) diff --git a/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py b/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py index ed92ab757d..0fdfbd4945 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py +++ b/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import copy from types import MethodType from typing import Tuple @@ -17,8 +18,9 @@ def gpt2_embedding_converter(module: torch.nn.Module, module_index: int, max_seq assert isinstance(module, GPT2Model) del module_index # unused - zero_and_freeze_expand_position_embeddings(module, max_sequence_length, position_embedding_attribute='wpe') - return module + new_module = copy.deepcopy(module) + zero_and_freeze_expand_position_embeddings(new_module, max_sequence_length, position_embedding_attribute='wpe') + return new_module @policy_registry.register(GPT2Attention) diff --git a/composer/algorithms/alibi/attention_surgery_functions/utils.py b/composer/algorithms/alibi/attention_surgery_functions/utils.py index 6988806356..d510e1e371 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/utils.py +++ b/composer/algorithms/alibi/attention_surgery_functions/utils.py @@ -125,8 +125,13 @@ def zero_and_freeze_expand_position_embeddings( if not isinstance(old_weight, torch.nn.Parameter): raise TypeError(f'Module {module._get_name()}, position embedding {position_embedding_attribute}, ' f"'weight' attribute must be of type torch.nn.Module") - old_weight.requires_grad = False - old_weight.zero_() + new_weight = torch.nn.Parameter( + torch.zeros((max_sequence_length, old_weight.shape[1]), + dtype=old_weight.dtype, + layout=old_weight.layout, + device=old_weight.device)) + new_weight.requires_grad = False + setattr(pos_embedding_module, 'weight', new_weight) log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen') From b5622b3800dae921603bf261293c66ac641ca571 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 12:50:33 -0500 Subject: [PATCH 59/66] lint --- composer/algorithms/alibi/attention_surgery_functions/_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/algorithms/alibi/attention_surgery_functions/_bert.py b/composer/algorithms/alibi/attention_surgery_functions/_bert.py index 71b31796bb..c2a7bb3bd5 100644 --- 
a/composer/algorithms/alibi/attention_surgery_functions/_bert.py +++ b/composer/algorithms/alibi/attention_surgery_functions/_bert.py @@ -1,8 +1,8 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -import math import copy +import math from types import MethodType from typing import Optional, Tuple From adc611ac3c2a2a82f968d96af8c3d1b8591ce14a Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 18:06:34 +0000 Subject: [PATCH 60/66] fix gpt --- .../algorithms/alibi/attention_surgery_functions/_gpt2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py b/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py index 0fdfbd4945..ed92ab757d 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py +++ b/composer/algorithms/alibi/attention_surgery_functions/_gpt2.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -import copy from types import MethodType from typing import Tuple @@ -18,9 +17,8 @@ def gpt2_embedding_converter(module: torch.nn.Module, module_index: int, max_seq assert isinstance(module, GPT2Model) del module_index # unused - new_module = copy.deepcopy(module) - zero_and_freeze_expand_position_embeddings(new_module, max_sequence_length, position_embedding_attribute='wpe') - return new_module + zero_and_freeze_expand_position_embeddings(module, max_sequence_length, position_embedding_attribute='wpe') + return module @policy_registry.register(GPT2Attention) From b5691dd85c11bc772e2fd48f23dc67d1e9479513 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 13:30:01 -0500 Subject: [PATCH 61/66] comment --- composer/trainer/dist_strategy.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index ca4d973dc9..66e5f7a509 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -243,10 +243,11 @@ def prepare_fsdp_module( 'gpu and some ranks are on meta. Either keep all ranks on the same ' "device or set fsdp_config['sync_module_states'] = True. Otherwise, " 'some weights may be randomly initialized when loading a checkpoint.') - if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'): - raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires ' - 'fsdp_config["sync_module_states"] = True or different replicas will ' - 'have different weights.') + # Comment out while we debug deadlock + # if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'): + # raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires ' + # 'fsdp_config["sync_module_states"] = True or different replicas will ' + # 'have different weights.') # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This # may happen when close to memory limit or with uneven memory usage across ranks. 
Since we From 04279ca73c7e218e83677df3d8b46acc980c6b74 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 13:32:06 -0500 Subject: [PATCH 62/66] fix test --- tests/algorithms/test_required_on_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index 961619a023..3df4998a7d 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -165,7 +165,7 @@ def test_autoload(algo_name: str, load_weights_only: bool, already_added: bool, # Excluding some algorithms leads to errors when loading elif exclude: if version.parse(torch.__version__) > version.parse('2.1.3'): - if algo_name in ['BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite']: + if algo_name in ['Alibi', 'BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite']: context = pytest.raises(KeyError) # Optimizer loading is strict else: if algo_name in ['Factorize', 'SqueezeExcite']: From 50fd5e2d2c975b04a8c359069a5b55aee21debe2 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 14:29:17 -0500 Subject: [PATCH 63/66] lint --- tests/algorithms/test_required_on_load.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index 3df4998a7d..ddb05a0c3c 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -165,7 +165,9 @@ def test_autoload(algo_name: str, load_weights_only: bool, already_added: bool, # Excluding some algorithms leads to errors when loading elif exclude: if version.parse(torch.__version__) > version.parse('2.1.3'): - if algo_name in ['Alibi', 'BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite']: + if algo_name in [ + 'Alibi', 'BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite' + ]: context = pytest.raises(KeyError) # Optimizer loading is strict else: if algo_name in ['Factorize', 'SqueezeExcite']: From fa2f112891dcf35babd95969733bb2cf7dad4c49 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 9 Jan 2024 19:50:19 +0000 Subject: [PATCH 64/66] minor optimizations --- composer/core/state.py | 8 ++++---- composer/utils/checkpoint.py | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 4f1de29585..b057aeb315 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -1244,11 +1244,11 @@ def load_model_and_optimizer_state( use_state_dict_fns = version.parse(torch.__version__) > version.parse('2.1.3') if use_state_dict_fns: from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict - optimizer = ensure_tuple(self.optimizers)[0] model_state_dict = state_dict.get('model', {}) - optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) - if load_model_only: - optimizer, optim_state_dict = [], {} + optimizer, optim_state_dict = [], {} + if not load_model_only: + optimizer = optimizer = ensure_tuple(self.optimizers)[0] + optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) set_state_dict( self.model, optimizers=optimizer, diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index f29d467ef9..f538ff33c9 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -974,10 +974,11 @@ def save_checkpoint( device_mesh = 
state.fsdp_device_mesh if device_mesh is not None and device_mesh.ndim == 2: expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) - process_group = device_mesh.get_group(1) # Only save on first replica - log.debug( - f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}' - ) + if expect_file: + process_group = device_mesh.get_group(1) # Only save on first replica + log.debug( + f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}' + ) else: expect_file = True From 81701441b0244bb85a1d83dd4f826cdb10830d74 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 18:04:01 -0500 Subject: [PATCH 65/66] Update composer/core/state.py Co-authored-by: Evan Racah --- composer/core/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/core/state.py b/composer/core/state.py index b057aeb315..be92a799ce 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -1247,7 +1247,7 @@ def load_model_and_optimizer_state( model_state_dict = state_dict.get('model', {}) optimizer, optim_state_dict = [], {} if not load_model_only: - optimizer = optimizer = ensure_tuple(self.optimizers)[0] + optimizer = ensure_tuple(self.optimizers)[0] optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) set_state_dict( self.model, From ad0c8f638cfe4e273f79d786374bc2a9cbcd1307 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 9 Jan 2024 18:10:12 -0500 Subject: [PATCH 66/66] revert tests --- .github/workflows/pr-cpu.yaml | 10 +++++----- .github/workflows/pr-gpu.yaml | 12 ++++++------ composer/utils/checkpoint.py | 4 ---- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index ac3feaff25..55fbefcfe6 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -27,11 +27,11 @@ jobs: markers: 'not daily and not remote and not gpu and not vision and not doctest' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' - - name: 'cpu-3.10-2.2' - container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 - markers: 'not daily and not remote and not gpu and not vision and not doctest' - pytest_command: 'coverage run -m pytest' - composer_package_name: 'mosaicml' + # - name: 'cpu-3.10-2.2' + # container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + # markers: 'not daily and not remote and not gpu and not vision and not doctest' + # pytest_command: 'coverage run -m pytest' + # composer_package_name: 'mosaicml' - name: 'cpu-vision' container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04 markers: 'not daily and not remote and not gpu and vision and not doctest' diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2e68a9de33..acd7b4266a 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -1,6 +1,6 @@ name: PR GPU tests on: - pull_request: + pull_request_target: workflow_dispatch: # Cancel old runs when a new commit is pushed to the same branch if not on main or dev concurrency: @@ -17,11 +17,11 @@ jobs: markers: 'not daily and not remote and gpu and (doctest or not doctest)' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' - - name: 'gpu-3.10-2.2' - container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 - markers: 'not daily and not remote and gpu and (doctest or not 
doctest)' - pytest_command: 'coverage run -m pytest' - composer_package_name: 'mosaicml' + # - name: 'gpu-3.10-2.2' + # container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + # markers: 'not daily and not remote and gpu and (doctest or not doctest)' + # pytest_command: 'coverage run -m pytest' + # composer_package_name: 'mosaicml' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index f29d467ef9..1eb27c3998 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -165,15 +165,11 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo def is_checkpoint_legacy_sharded(object_store: Optional[ObjectStore], source_path: str): metadata_path = str(Path(source_path) / Path('.metadata')) if object_store is None: - if not os.path.exists(source_path): - raise FileNotFoundError(f"Couldn't find the directory {source_path}") return not os.path.exists(metadata_path) else: try: with tempfile.TemporaryDirectory() as temp_dir: metadata_destination = os.path.join(str(temp_dir), '.metadata') - if len(object_store.list_objects(prefix=source_path)) == 0: - raise FileNotFoundError(f"Couldn't find the prefix {object_store.get_uri(object_name=source_path)}") object_store.download_object(object_name=metadata_path, filename=metadata_destination) return False except FileNotFoundError:
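
Taken together, the series moves Composer's model/optimizer (de)serialization onto the torch.distributed.checkpoint.state_dict API for torch > 2.1.3. A minimal sketch of that API shape outside of Composer (the bare model and single optimizer below are stand-ins for Composer's wrapped state.model and state.optimizers, and the option values are illustrative only):

    import torch
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict


    def materialize_state(model: torch.nn.Module, optimizer: torch.optim.Optimizer, use_full_state_dict: bool):
        # Returns (model_state_dict, optimizer_state_dict), gathered and CPU-offloaded per the options.
        options = StateDictOptions(full_state_dict=use_full_state_dict, cpu_offload=True)
        return get_state_dict(model, optimizers=optimizer, options=options)


    def restore_state(model, optimizer, model_state_dict, optim_state_dict, strict: bool = True):
        # Loads both state dicts back onto the (possibly FSDP-sharded) model and optimizer.
        set_state_dict(
            model,
            optimizers=optimizer,
            model_state_dict=model_state_dict,
            optim_state_dict=optim_state_dict,
            options=StateDictOptions(strict=strict, cpu_offload=True),
        )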