From eb8cfc586ba27686918dd5032c4623b6330e9bc8 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 26 Apr 2024 08:23:10 -0700
Subject: [PATCH 1/2] Delay import of device mesh until version check

---
 scripts/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index 295a65713..39bbe1c34 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -11,7 +11,6 @@
 import torch.multiprocessing as mp
 import wandb
 from packaging import version
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
 
@@ -138,7 +137,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         param_init_fn = None
 
     # Set up device mesh for hybrid sharding in order to specify which nodes are assoicated to a given model replica
-    device_mesh: Optional[DeviceMesh] = None
+    device_mesh = None
     if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
         if version.parse(torch.__version__) < version.parse("2.2.0"):
             # Device mesh was not added to PyTorch until v2.2.0
@@ -146,6 +145,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
                 "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
             )
 
+        from torch.distributed.device_mesh import init_device_mesh
+
         num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
             get_world_size() // get_local_world_size()
         )

From 652b9cbc5730a348374bde4a5ef6e2c4c306b0ec Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 26 Apr 2024 08:27:18 -0700
Subject: [PATCH 2/2] Pass device_mesh via a dictionary

---
 scripts/train.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/train.py b/scripts/train.py
index 39bbe1c34..23471ca94 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -138,6 +138,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
 
     # Set up device mesh for hybrid sharding in order to specify which nodes are assoicated to a given model replica
     device_mesh = None
+    hybrid_sharding_fsdp_kwargs = {}
     if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
         if version.parse(torch.__version__) < version.parse("2.2.0"):
             # Device mesh was not added to PyTorch until v2.2.0
@@ -159,10 +160,10 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")
 
         device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
+        hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh
 
     fsdp_model = FSDP(
         olmo_model,
-        device_mesh=device_mesh,
         sharding_strategy=cfg.fsdp.sharding_strategy,
         mixed_precision=cfg.fsdp_precision,
         auto_wrap_policy=wrap_policy,
@@ -170,6 +171,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         limit_all_gathers=True,
         device_id=get_local_rank(),
         param_init_fn=param_init_fn,
+        **hybrid_sharding_fsdp_kwargs,
     )
     # when param_init_fn is None, FSDP will call reset_parameters() automatically
     if param_init_fn is not None:
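
Taken together, the two patches apply one pattern: defer the torch.distributed.device_mesh import until after the torch version check, and collect the optional device_mesh argument in a dict that is splatted into the FSDP(...) call, so that no device_mesh keyword is passed at all on older torch. The sketch below is a minimal, standalone illustration of that pattern, not the scripts/train.py code itself; the helper name build_hybrid_sharding_kwargs and its arguments are hypothetical, and the init_device_mesh call still requires an initialized process group to actually run.

    # Sketch of the version-gated, dict-based optional-kwarg pattern used above.
    import torch
    from packaging import version

    def build_hybrid_sharding_kwargs(num_replicas: int, world_size: int) -> dict:
        # Kwargs that will be forwarded to FSDP(...) via ** unpacking.
        hybrid_sharding_fsdp_kwargs = {}
        if version.parse(torch.__version__) >= version.parse("2.2.0"):
            # Import only on new-enough torch; the module does not exist before 2.2.0.
            from torch.distributed.device_mesh import init_device_mesh

            # Needs torch.distributed to be initialized; shown for illustration only.
            device_mesh = init_device_mesh("cuda", (num_replicas, world_size // num_replicas))
            hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh
        return hybrid_sharding_fsdp_kwargs

    # The caller forwards the dict with **, so FSDP never receives device_mesh=None:
    #   fsdp_model = FSDP(model, ..., **build_hybrid_sharding_kwargs(replicas, world))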