Delay device mesh import #561

Merged 4 commits on Apr 26, 2024
9 changes: 6 additions & 3 deletions scripts/train.py
@@ -11,7 +11,6 @@
 import torch.multiprocessing as mp
 import wandb
 from packaging import version
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy

@@ -138,14 +137,17 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     param_init_fn = None

     # Set up device mesh for hybrid sharding in order to specify which nodes are associated with a given model replica
-    device_mesh: Optional[DeviceMesh] = None
+    device_mesh = None
+    hybrid_sharding_fsdp_kwargs = {}
     if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
         if version.parse(torch.__version__) < version.parse("2.2.0"):
             # Device mesh was not added to PyTorch until v2.2.0
             raise OLMoConfigurationError(
                 "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
             )

+        from torch.distributed.device_mesh import init_device_mesh
+
         num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
             get_world_size() // get_local_world_size()
         )
@@ -158,17 +160,18 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")

         device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
+        hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh

     fsdp_model = FSDP(
         olmo_model,
-        device_mesh=device_mesh,
         sharding_strategy=cfg.fsdp.sharding_strategy,
         mixed_precision=cfg.fsdp_precision,
         auto_wrap_policy=wrap_policy,
         use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
         limit_all_gathers=True,
         device_id=get_local_rank(),
         param_init_fn=param_init_fn,
+        **hybrid_sharding_fsdp_kwargs,
     )
     # when param_init_fn is None, FSDP will call reset_parameters() automatically
     if param_init_fn is not None:
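For readers skimming the diff: the change moves the device_mesh import behind the torch-version check and passes the mesh to FSDP through an optional kwargs dict, so importing scripts/train.py no longer requires torch >= 2.2.0 unless hybrid sharding is actually enabled. Below is a minimal standalone sketch of that pattern; build_hybrid_sharding_kwargs, its arguments, and the RuntimeError are illustrative assumptions for this sketch, not the OLMo API.

# Sketch of the delayed-import pattern adopted in this PR. The helper below is
# hypothetical; the real logic lives inline in scripts/train.py and uses
# OLMoConfigurationError and cfg.fsdp instead of plain arguments.
from typing import Any, Dict

import torch
from packaging import version


def build_hybrid_sharding_kwargs(hybrid_sharding: bool, num_model_replicas: int, world_size: int) -> Dict[str, Any]:
    """Return extra FSDP kwargs, resolving torch.distributed.device_mesh only when needed."""
    kwargs: Dict[str, Any] = {}
    if not hybrid_sharding:
        # No hybrid sharding: device_mesh is never imported, so torch < 2.2.0 still works.
        return kwargs

    if version.parse(torch.__version__) < version.parse("2.2.0"):
        # Device mesh was not added to PyTorch until v2.2.0.
        raise RuntimeError("hybrid sharding requires torch >= 2.2.0")

    # Function-local import: the symbol is only looked up on new-enough torch.
    from torch.distributed.device_mesh import init_device_mesh

    # One mesh dimension for model replicas, one for the shards within each replica
    # (this call requires an initialized process group, as in the real training script).
    device_mesh = init_device_mesh("cuda", (num_model_replicas, world_size // num_model_replicas))
    kwargs["device_mesh"] = device_mesh
    return kwargs


# Usage: splat the dict into the constructor, so FSDP never receives an explicit
# device_mesh=None argument when hybrid sharding is off:
#   fsdp_model = FSDP(model, **build_hybrid_sharding_kwargs(False, 1, 1), ...)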