From 704c07ee5b1a1723711a6af2059eee1c044af85f Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 23 Jan 2024 17:44:07 -0800
Subject: [PATCH] Remove monkeypatch and new state dict APIs for torch 2.2
 (#2899)

* fix mosaicfsdp

* bump to 2.3

* remove init
---
 composer/core/state.py          |  8 ++++----
 composer/trainer/mosaic_fsdp.py | 10 ----------
 composer/trainer/trainer.py     |  5 ++++-
 composer/utils/checkpoint.py    | 14 +++++++-------
 4 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index 4967aa1dba..59b5babfe7 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -874,7 +874,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The state dict for the model.
         """
-        if version.parse(torch.__version__) > version.parse('2.1.3'):
+        if version.parse(torch.__version__) > version.parse('2.2.9'):
             from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
             if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
                 raise NotImplementedError(
@@ -909,7 +909,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: The state dict for the optimizer.
         """
-        if version.parse(torch.__version__) > version.parse('2.1.3'):
+        if version.parse(torch.__version__) > version.parse('2.2.9'):
             from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
             if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
                 raise NotImplementedError(
@@ -1216,7 +1216,7 @@ def load_model_state(
         model_on_rank = state_dict['model'] is not None
 
         if model_on_rank:
-            if version.parse(torch.__version__) > version.parse('2.1.3'):
+            if version.parse(torch.__version__) > version.parse('2.2.9'):
                 from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
                 set_model_state_dict(
                     model=self.model,
@@ -1277,7 +1277,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
             strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer state dict should
                 perfectly match the keys in the optimizer instance.
         """
-        if version.parse(torch.__version__) > version.parse('2.1.3'):
+        if version.parse(torch.__version__) > version.parse('2.2.9'):
             from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
             optimizer = self.optimizers[0]
             set_optimizer_state_dict(
diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 1a8ab77bbf..1b346e92e4 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -69,16 +69,6 @@ def patch_pytorch():
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
-        # Monkeypatch dtensor support
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
-        FullyShardedDataParallel.__init__ = init_fn_t2p2p0  # type: ignore
-
-        # Monkeypath state_dict
-        from torch.distributed.checkpoint import state_dict  # type: ignore
-
-        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
-        state_dict._verify_options = _verify_options_t2p2p0
-
     elif version.parse(torch.__version__) < version.parse('2.3.1'):
         # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
         # Note: this is the same patch as 2.2.0, we are just making a new if branch
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index e8c587288a..80a519d758 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -986,7 +986,10 @@ def __init__(
         assert not isinstance(device_train_microbatch_size, str)
 
         # Distributed
-        dist.initialize_dist(device, dist_timeout)
+        if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1:
+            # Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1
+            # And torch.distributed is always required for multi-rank training
+            dist.initialize_dist(device, dist_timeout)
 
         # Reproducibility
         rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device)
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index c47184eada..ddb2f3236a 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -494,7 +494,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
             else:
                 cur_state_dict = state.state_dict()
                 # For older versions of torch, we load optimizer separately.
-                if version.parse(torch.__version__) < version.parse('2.1.3'):
+                if version.parse(torch.__version__) < version.parse('2.2.9'):
                     cur_state_dict.pop('optimizers')
                 state_dict: Dict[str, Any] = {
                     'state': cur_state_dict,
@@ -523,7 +523,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
                 else:
                     expect_file = True
 
-            if version.parse(torch.__version__) > version.parse('2.1.3'):
+            if version.parse(torch.__version__) > version.parse('2.2.9'):
                 dist_cp.load(  # type: ignore
                     state_dict=state_dict,
                     storage_reader=storage_reader,
@@ -547,8 +547,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
                 )
 
             # 2. Optionally load optimizer
-            # if we are using later than 2.1.0 then optimizer will already be loaded
-            if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only:
+            # if we are using later than 2.2.9 then optimizer will already be loaded
+            if version.parse(torch.__version__) < version.parse('2.2.9') and not load_weights_only:
                 optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'],
                                                                 optimizer_key='optimizers',
                                                                 storage_reader=storage_reader)
@@ -956,12 +956,12 @@ def _save_checkpoint(
     state_dict['state'] = state_dict.get('state', {})
 
     if state.fsdp_sharded_state_dict_enabled:
-        # To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top
+        # To load optimizer states with 2.0 <= torch < 2.2.9 , the optimizer state must be at the top
         # level of the state dict because the load_sharded_optimizer_state_dict function
         # requires a top level state dict key for the optimizer.
         # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271
         # for more info.
-        if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'):
+        if using_torch_2() and version.parse(torch.__version__) < version.parse('2.2.9'):
             if not weights_only:
                 state_dict['optimizers'] = state_dict['state'].pop('optimizers')
     log.debug('State dict created.')
@@ -1007,7 +1007,7 @@ def _save_checkpoint(
             expect_file = True
 
         if expect_file:
-            if version.parse(torch.__version__) > version.parse('2.1.3'):
+            if version.parse(torch.__version__) > version.parse('2.2.9'):
                 dist_cp.save(  # type: ignore
                     state_dict=state_dict,
                     storage_writer=dist_cp.FileSystemWriter(dirname),
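
Note on the code path gated above: for torch builds newer than 2.2.9 (i.e. 2.3+), Composer's State goes through the
torch.distributed.checkpoint.state_dict helpers imported in the state.py hunks. A minimal sketch of that API, assuming
an FSDP-wrapped model and a single optimizer (the function names below are illustrative, not Composer's):

    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_model_state_dict,
    )

    def snapshot(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
        # Produce an unsharded state dict offloaded to CPU, analogous to
        # fsdp_state_dict_type == 'full'; sharded checkpoints would pass
        # full_state_dict=False instead.
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        model_sd = get_model_state_dict(model, options=options)
        optim_sd = get_optimizer_state_dict(model, optimizer, options=options)
        return model_sd, optim_sd

    def restore(model: torch.nn.Module, model_sd) -> None:
        # Loading goes through the matching setter rather than
        # model.load_state_dict, so FSDP sharding is handled by torch.
        set_model_state_dict(model, model_sd, options=StateDictOptions(full_state_dict=True))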
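
The '2.2.9' sentinel that replaces '2.1.3' throughout is presumably shorthand for "torch 2.3 or newer": comparing
against 2.2.9 rather than 2.3.0 keeps 2.3 nightly/dev builds on the new path, because packaging orders pre-releases
before the final release. A quick illustration (the version string is just an example):

    from packaging import version

    nightly = version.parse('2.3.0.dev20240120')
    # A 2.3 nightly sorts before the 2.3.0 release, so a `> 2.3.0` gate
    # would exclude it, while `> 2.2.9` admits every 2.3 build.
    assert nightly < version.parse('2.3.0')
    assert nightly > version.parse('2.2.9')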
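
The trainer.py hunk makes process-group setup conditional: a single-process run that uses neither FSDP nor DeepSpeed no
longer initializes torch.distributed at all. A rough sketch of the equivalent guard in plain torch (the helper name and
flag are hypothetical; the patch itself calls Composer's dist wrapper):

    import os
    import torch.distributed as torch_dist

    def maybe_init_process_group(uses_fsdp_or_deepspeed: bool, backend: str = 'gloo') -> None:
        # FSDP and DeepSpeed need an initialized process group even at world
        # size 1, and multi-rank jobs always need one; plain single-process
        # training can skip it entirely.
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        if (uses_fsdp_or_deepspeed or world_size > 1) and not torch_dist.is_initialized():
            torch_dist.init_process_group(backend=backend)  # env:// rendezvous from the launcher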