Skip to content

Commit

Permalink
Dataclasses for ParallelismConfig (mosaicml#3346)
Browse files Browse the repository at this point in the history
* v1 paralleism

* fix

* add doc strings

* lint

* fix tests

* clean u ptest

* fix error

* check if dict instances are configs

* fix tests

* fix lint

* fix tests

* fix test

---------

Co-authored-by: Saaketh Narayan <saaketh@mosaicml.com>
Co-authored-by: Your Name <you@example.com>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent ca472cc commit a60bf3a
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 186 deletions.
2 changes: 1 addition & 1 deletion composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
keep_placeholders=True,
).lstrip('/')
assert state.fsdp_config is not None
remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
remote_prefix = state.fsdp_config.sharded_ckpt_prefix_dir
assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
Expand Down
48 changes: 25 additions & 23 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.devices import Device
from composer.utils import (
FSDPConfig,
ParallelismConfig,
ParallelismType,
TPConfig,
VersionedDeprecationWarning,
batch_get,
batch_set,
Expand Down Expand Up @@ -197,8 +200,8 @@ def _ensure_backwards_compatible_checkpointing(state_dict: dict[str, Any]):

def _create_device_mesh(
device: Device,
fsdp_config: Optional[dict[str, Any]],
tp_config: Optional[dict[str, Any]],
fsdp_config: Optional[FSDPConfig],
tp_config: Optional[TPConfig],
) -> Optional[DeviceMesh]:
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
# Device mesh has correctness issues before torch 2.3.0
Expand All @@ -210,13 +213,13 @@ def _create_device_mesh(
# Gather dimensions and names for the device mesh
dims: list[int] = []
names: list[str] = []
if fsdp_config['data_parallel_replicate_degree'] is not None:
dims.append(fsdp_config['data_parallel_replicate_degree'])
if fsdp_config.data_parallel_replicate_degree is not None:
dims.append(fsdp_config.data_parallel_replicate_degree)
names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
dims.append(fsdp_config['data_parallel_shard_degree'])
dims.append(fsdp_config.data_parallel_shard_degree)
names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
if tp_config is not None:
dims.append(tp_config['tensor_parallel_degree'])
dims.append(tp_config.tensor_parallel_degree)
names.append(ParallelismType.TENSOR_PARALLEL.value)

# Fill in the unspecified dimensions
Expand Down Expand Up @@ -329,7 +332,7 @@ class State(Serializable):
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
deepspeed_config (dict[str, Any], optional): The configuration dictionary for deepspeed.
parallelism_config (dict[str, Any], optional): The configuration dictionary for parallelism.
parallelism_config (ParallelismConfig, optional): The configuration dictionary for parallelism.
Attributes:
batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a
Expand Down Expand Up @@ -496,7 +499,7 @@ def __init__(

# Distributed training configs
deepspeed_config: Optional[dict[str, Any]] = None,
parallelism_config: Optional[dict[str, Any]] = None,
parallelism_config: Optional[ParallelismConfig] = None,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
Expand Down Expand Up @@ -540,9 +543,8 @@ def __init__(
self.profiler: Optional[Profiler] = None

self.deepspeed_config = deepspeed_config
parallelism_config = parallelism_config or {}
self.fsdp_config = parallelism_config.get('fsdp', None)
self.tp_config = parallelism_config.get('tp', None)
self.fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None
self.tp_config = parallelism_config.tp if parallelism_config is not None else None

self._validate_parallelism_configs()

Expand All @@ -552,9 +554,9 @@ def __init__(
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
self.fsdp_config.device_mesh = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]
self.tp_config.device_mesh = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
Expand Down Expand Up @@ -598,11 +600,11 @@ def _validate_parallelism_configs(self):
if self.fsdp_config is None:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP to be enabled. '
'An empty `fsdp_config` can be specified to enable FSDP with '
'default settings. Additionally, PyTorch currently errors if FSDP '
"An empty `parallelism_config['fsdp'] = {}` config can be specified to enable "
'FSDP with default settings. Additionally, PyTorch currently errors if FSDP '
'data_parallel_shard_degree is not at least 2.',
)
if not self.fsdp_config['use_orig_params']:
if not self.fsdp_config.use_orig_params:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
Expand All @@ -614,10 +616,10 @@ def _validate_parallelism_configs(self):
raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).')
assert self.fsdp_config is not None
error_message = ''
if self.fsdp_config['sync_module_states'] == False:
if self.fsdp_config.sync_module_states == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
"load_monolith_rank0_only requires parallelism_config['fsdp']['sync_module_states'] to be True. "
"Either set parallelism_config['fsdp']['sync_module_states'] = True or set load_monolith_rank0_only = False.",
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
Expand Down Expand Up @@ -654,7 +656,7 @@ def _validate_parallelism_configs(self):
textwrap.dedent(
'Saving metrics is not allowed with sharded state dict as metric tensors will '
'be sharded and break on load. If you wish to save metric state, set '
'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.',
"parallelism_config['fsdp']['state_dict_type'] = 'full' to disable sharded checkpoints.",
),
)

Expand Down Expand Up @@ -881,7 +883,7 @@ def fsdp_state_dict_type(self):
if not self.fsdp_enabled:
return None
if self.fsdp_config is not None:
return self.fsdp_config['state_dict_type']
return self.fsdp_config.state_dict_type
return 'full'

@property
Expand All @@ -906,8 +908,8 @@ def load_fsdp_monolith_rank0_only(self):
@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and
self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
self.fsdp_config is not None and self.fsdp_config.auto_wrap and
self.fsdp_config.state_dict_type == 'full' and self.fsdp_config.load_monolith_rank0_only == True
)

def _get_integrations_state_dict(self) -> dict[str, Any]:
Expand Down
2 changes: 0 additions & 2 deletions composer/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
'fix_batch_precision_for_deepspeed',
Expand All @@ -21,5 +20,4 @@
'prepare_ddp_module',
'prepare_fsdp_module',
'prepare_tp_module',
'set_fsdp_default',
]
74 changes: 38 additions & 36 deletions composer/distributed/dist_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
from composer.core import Precision, State
from composer.devices import Device
from composer.distributed.meta_safe_apply import meta_safe_apply
from composer.distributed.mosaic_fsdp import (
from composer.distributed.mosaic_parallelism import (
BACKWARD_PREFETCH_MAP,
SHARDING_MAP,
get_cpu_offload,
get_mixed_precision,
set_custom_fsdp_module_kwargs,
)
from composer.utils import StringEnum, dist, ensure_tuple
from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple

__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']

Expand Down Expand Up @@ -181,24 +181,24 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(

def prepare_tp_module(
model: torch.nn.Module,
tp_config: dict[str, Any],
tp_config: TPConfig,
) -> None:
"""Prepare a module (assumed ComposerModel) for use with tensor parallel."""
from torch.distributed.tensor.parallel import parallelize_module

device_mesh = tp_config['device_mesh']
layer_plan = tp_config['layer_plan']
device_mesh = tp_config.device_mesh
assert device_mesh is not None # For type checking, set in State.__init__
parallelize_module(
module=model,
device_mesh=device_mesh,
parallelize_plan=layer_plan,
parallelize_plan=tp_config.layer_plan,
)


def prepare_fsdp_module(
model: torch.nn.Module,
optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
fsdp_config: dict[str, Any],
fsdp_config: FSDPConfig,
precision: Precision,
device: Device,
auto_microbatching: bool,
Expand All @@ -216,7 +216,7 @@ def prepare_fsdp_module(
te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
"""
# Check sync_module_states is True for mixed initialization or HSDP
if fsdp_config['sync_module_states'] == False:
if fsdp_config.sync_module_states == False:
rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
Expand All @@ -226,7 +226,7 @@ def prepare_fsdp_module(
raise ValueError(
'Detected mixed initialization where some ranks have model on cpu or '
'gpu and some ranks are on meta. Either keep all ranks on the same '
"device or set fsdp_config['sync_module_states'] = True. Otherwise, "
"device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, "
'some weights may be randomly initialized when loading a checkpoint.',
)

Expand Down Expand Up @@ -263,7 +263,7 @@ def sync_hook(*args):

num_param_groups = len(optim.param_groups)
if num_param_groups > 1:
if not fsdp_config['use_orig_params']:
if not fsdp_config.use_orig_params:
raise RuntimeError(
'Multiple optimizer groups with FSDP are only supported with '
'use_orig_params=True.',
Expand Down Expand Up @@ -297,17 +297,19 @@ def sync_hook(*args):
optim.param_groups.clear()
optim.state.clear()

sharding_map_key = fsdp_config['sharding_strategy'].upper()
sharding_map_key = fsdp_config.sharding_strategy.upper()
sharding_strategy = SHARDING_MAP[sharding_map_key]

kwargs = {}
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0') and 'device_mesh' in fsdp_config:
if fsdp_config['process_group'] is not None:
if version.parse(
torch.__version__.split('.dev')[0],
) >= version.parse('2.2.0') and fsdp_config.device_mesh is not None:
if fsdp_config.process_group is not None:
warnings.warn(
'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.',
)
else:
ndim = fsdp_config['device_mesh'].ndim
ndim = fsdp_config.device_mesh.ndim
if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD:
sharding_strategy = ShardingStrategy.FULL_SHARD
warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.')
Expand All @@ -320,12 +322,12 @@ def sync_hook(*args):
elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
warnings.warn('FULL_SHARD is not supported with 2D device mesh. Using HYBRID_SHARD instead.')
kwargs['device_mesh'] = fsdp_config['device_mesh']
kwargs['device_mesh'] = fsdp_config.device_mesh

cpu_offload = get_cpu_offload(cpu_offload=fsdp_config['cpu_offload'])
cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.cpu_offload)

mixed_precision = fsdp_config['mixed_precision']
keep_low_precision_grads = fsdp_config['keep_low_precision_grads']
mixed_precision = fsdp_config.mixed_precision
keep_low_precision_grads = fsdp_config.keep_low_precision_grads
mixed_precision, param_dtype, _, _ = get_mixed_precision(
precision,
mixed_precision=mixed_precision,
Expand Down Expand Up @@ -357,22 +359,22 @@ def sync_hook(*args):
)

process_group = None
if fsdp_config['process_group'] is not None:
process_group_dict = {'process_group': fsdp_config['process_group']}
if fsdp_config.process_group is not None:
process_group_dict = {'process_group': fsdp_config.process_group}
process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()]
activation_checkpointing = fsdp_config['activation_checkpointing']
activation_cpu_offload = fsdp_config['activation_cpu_offload']
sync_module_states = fsdp_config['sync_module_states']
forward_prefetch = fsdp_config['forward_prefetch']
limit_all_gathers = fsdp_config['limit_all_gathers']
ignored_modules = fsdp_config['ignored_modules']
state_dict_type = fsdp_config['state_dict_type']
activation_checkpointing_reentrant = fsdp_config['activation_checkpointing_reentrant']
te_checkpoint_wrapper = fsdp_config['te_checkpoint_wrapper'] if precision == Precision.AMP_FP8 else False
te_shard_fp8_weight = fsdp_config['te_shard_fp8_weight'] if precision == Precision.AMP_FP8 else False
sharded_ckpt_prefix_dir = fsdp_config['sharded_ckpt_prefix_dir']
use_orig_params = fsdp_config['use_orig_params']
backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config.backward_prefetch.upper()]
activation_checkpointing = fsdp_config.activation_checkpointing
activation_cpu_offload = fsdp_config.activation_cpu_offload
sync_module_states = fsdp_config.sync_module_states
forward_prefetch = fsdp_config.forward_prefetch
limit_all_gathers = fsdp_config.limit_all_gathers
ignored_modules = fsdp_config.ignored_modules
state_dict_type = fsdp_config.state_dict_type
activation_checkpointing_reentrant = fsdp_config.activation_checkpointing_reentrant
te_checkpoint_wrapper = fsdp_config.te_checkpoint_wrapper if precision == Precision.AMP_FP8 else False
te_shard_fp8_weight = fsdp_config.te_shard_fp8_weight if precision == Precision.AMP_FP8 else False
sharded_ckpt_prefix_dir = fsdp_config.sharded_ckpt_prefix_dir
use_orig_params = fsdp_config.use_orig_params

# We choose to not wrap the ComposerModel directly, but instead wrap any submodules like `ComposerModel.model`
# This makes it safer to call ComposerModel-specific functions like 'eval_forward' that
Expand Down Expand Up @@ -591,15 +593,15 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num

if hasattr(fsdp_obj, '_exec_order_data'):
if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config.forward_prefetch_limit
else:
warnings.warn(
'FSDP._exec_order_data does not have attribute _forward_prefetch_limit '
'which is unexpected and will result in `forward_prefetch_limit` from FSDP '
'config being ignored. Please open an issue to Composer to report this.',
)
if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config.backward_prefetch_limit
else:
warnings.warn(
'FSDP._exec_order_data does not have attribute _backward_prefetch_limit '
Expand Down Expand Up @@ -712,7 +714,7 @@ def _check_fn(module: torch.nn.Module) -> bool:
setattr(model, obj_name, fsdp_obj)

# Print FSDP wrapped model and FSDP config if `verbose=True`
if fsdp_config['verbose']:
if fsdp_config.verbose:
log.info(f'FSDP: Wrapped model: {model}')
log.info(f'FSDP: Using sharding_strategy={sharding_strategy}')
log.info(f'FSDP: Using cpu_offload={cpu_offload}')
Expand Down
Loading

0 comments on commit a60bf3a

Please sign in to comment.