Adds DTensor Support #2821

Merged: 76 commits, Jan 9, 2024
Changes shown are from 61 of the 76 commits.

Commits:
bda5bab  fixes to get dtensor to work (eracah, Nov 22, 2023)
60d02b4  more fixes (eracah, Dec 15, 2023)
8897c52  Change state dict materialization for new version of torch (eracah, Dec 20, 2023)
a88b430  get load working for new set_state_dict api (eracah, Dec 21, 2023)
8deef8d  merge hsdp (abhi-databricks, Dec 26, 2023)
888bec0  use device_mesh (abhi-databricks, Dec 27, 2023)
24ba9ec  Add fsdp init monkeypatch for DTensor (b-chu, Dec 28, 2023)
d07b5af  Add checkpoint profiling logs (b-chu, Dec 28, 2023)
9326b0e  attempt (abhi-mosaic, Dec 29, 2023)
221e58e  working single node (abhi-databricks, Dec 29, 2023)
b60d0ab  fix optimizer (Dec 30, 2023)
28aa784  allow 3d device mesh (abhi-databricks, Jan 3, 2024)
730368b  attempt to use different pg during 3d mesh save (abhi-databricks, Jan 3, 2024)
ddb336f  undo 3d mesh changes (abhi-databricks, Jan 5, 2024)
f184e31  load_state_dict -> load (eracah, Jan 5, 2024)
dd1c3c4  allow parent mesh in FSDP init (abhi-databricks, Jan 5, 2024)
9d9eeb7  merge dev (abhi-databricks, Jan 5, 2024)
d5f9951  allow override of force_sync_module_states (abhi-databricks, Jan 5, 2024)
0aed0bf  remove unnecessary exit (abhi-databricks, Jan 5, 2024)
dd936ea  ignore _validate_and_get_shard_state() (abhi-databricks, Jan 5, 2024)
f864a26  save/load hsdp-moe working (abhi-databricks, Jan 5, 2024)
77a1417  remove prints (abhi-databricks, Jan 5, 2024)
355d535  v1 (mvpatel2000, Jan 7, 2024)
1033fdf  v2 (mvpatel2000, Jan 7, 2024)
06b60a3  lint (mvpatel2000, Jan 7, 2024)
e1cd8cf  add more tests (mvpatel2000, Jan 7, 2024)
de6a568  switch to PRs (mvpatel2000, Jan 7, 2024)
0b721f3  ignore warning (mvpatel2000, Jan 8, 2024)
893bcb7  fix lint (mvpatel2000, Jan 8, 2024)
d66da7b  version error (mvpatel2000, Jan 8, 2024)
02faba7  fix version (mvpatel2000, Jan 8, 2024)
eb05a68  Merge branch 'dev' into mvpatel2000/hsdp-dtensor-v2 (mvpatel2000, Jan 8, 2024)
7f8e08c  fix state dict (mvpatel2000, Jan 8, 2024)
56eb06d  Merge branch 'mvpatel2000/hsdp-dtensor-v2' of github.com:mosaicml/com… (mvpatel2000, Jan 8, 2024)
3c557bf  update versions (mvpatel2000, Jan 8, 2024)
92e38ca  lint (mvpatel2000, Jan 8, 2024)
3ee188f  lint (mvpatel2000, Jan 8, 2024)
91cd547  disable lint for mosaic fsdp utils (mvpatel2000, Jan 8, 2024)
b902b1b  remove bad line (mvpatel2000, Jan 8, 2024)
74f459d  move around for legacy (mvpatel2000, Jan 8, 2024)
0dc5068  merge with dev (mvpatel2000, Jan 8, 2024)
0e0cefc  device mesh (mvpatel2000, Jan 8, 2024)
37941f5  ignore warning (mvpatel2000, Jan 8, 2024)
db9f54d  fix import (mvpatel2000, Jan 8, 2024)
52e7b32  always init (mvpatel2000, Jan 8, 2024)
81dfad0  fix error (mvpatel2000, Jan 8, 2024)
956ce55  fix load planner (mvpatel2000, Jan 8, 2024)
b40b8a5  remove (mvpatel2000, Jan 8, 2024)
1ec3774  fix lint (mvpatel2000, Jan 8, 2024)
f59bdaa  lint (mvpatel2000, Jan 8, 2024)
ffaa7f4  delay state dict (mvpatel2000, Jan 8, 2024)
5309476  Merge branch 'dev' into mvpatel2000/hsdp-dtensor-v2 (mvpatel2000, Jan 9, 2024)
51f6e3b  test checkpoint (Jan 9, 2024)
13b86ce  checkpoint (Jan 9, 2024)
1bb1e4a  fix cpu tests (Jan 9, 2024)
9899826  fix rotate tests (Jan 9, 2024)
7efbd32  fix precision (Jan 9, 2024)
5d185d1  lint (mvpatel2000, Jan 9, 2024)
575fe35  fix alibi (Jan 9, 2024)
f3ce8f8  cleanup (Jan 9, 2024)
6e9fa0c  cleanup (mvpatel2000, Jan 9, 2024)
743ef10  remove force sync (mvpatel2000, Jan 9, 2024)
4042d76  fix type (mvpatel2000, Jan 9, 2024)
02a0b20  merge (Jan 9, 2024)
e6e4e73  Merge branch 'mvpatel2000/hsdp-dtensor-v2' of github.com:mosaicml/com… (Jan 9, 2024)
b5622b3  lint (mvpatel2000, Jan 9, 2024)
adc611a  fix gpt (Jan 9, 2024)
b5691dd  comment (mvpatel2000, Jan 9, 2024)
04279ca  fix test (mvpatel2000, Jan 9, 2024)
50fd5e2  lint (mvpatel2000, Jan 9, 2024)
fa2f112  minor optimizations (Jan 9, 2024)
c110223  Merge branch 'mvpatel2000/hsdp-dtensor-v2' of github.com:mosaicml/com… (Jan 9, 2024)
8170144  Update composer/core/state.py (mvpatel2000, Jan 9, 2024)
ad0c8f6  revert tests (mvpatel2000, Jan 9, 2024)
3155297  Merge branch 'mvpatel2000/hsdp-dtensor-v2' of github.com:mosaicml/com… (mvpatel2000, Jan 9, 2024)
cad212c  Merge branch 'dev' into mvpatel2000/hsdp-dtensor-v2 (mvpatel2000, Jan 9, 2024)
5 changes: 5 additions & 0 deletions .github/workflows/pr-cpu.yaml
@@ -27,6 +27,11 @@ jobs:
markers: 'not daily and not remote and not gpu and not vision and not doctest'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
- name: 'cpu-3.10-2.2'
container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04
markers: 'not daily and not remote and not gpu and not vision and not doctest'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
- name: 'cpu-vision'
container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not daily and not remote and not gpu and vision and not doctest'
7 changes: 6 additions & 1 deletion .github/workflows/pr-gpu.yaml
@@ -1,6 +1,6 @@
name: PR GPU tests
on:
pull_request_target:
pull_request:
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
@@ -17,6 +17,11 @@ jobs:
markers: 'not daily and not remote and gpu and (doctest or not doctest)'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
- name: 'gpu-3.10-2.2'
container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04
markers: 'not daily and not remote and gpu and (doctest or not doctest)'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -125,13 +125,8 @@ def zero_and_freeze_expand_position_embeddings(
if not isinstance(old_weight, torch.nn.Parameter):
raise TypeError(f'Module {module._get_name()}, position embedding {position_embedding_attribute}, '
f"'weight' attribute must be of type torch.nn.Module")
new_weight = torch.nn.Parameter(
torch.zeros((max_sequence_length, old_weight.shape[1]),
dtype=old_weight.dtype,
layout=old_weight.layout,
device=old_weight.device))
new_weight.requires_grad = False
setattr(pos_embedding_module, 'weight', new_weight)
old_weight.requires_grad = False
old_weight.zero_()

log.info(f' Position embedding expanded to sequence length {max_sequence_length}, zeroed, and frozen')

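This hunk zeroes and freezes the existing position-embedding weight in place instead of building a replacement torch.nn.Parameter, presumably so that a parameter FSDP or DTensor already manages is mutated rather than swapped out. A minimal sketch of the in-place pattern, with a plain torch.nn.Embedding standing in as a hypothetical position-embedding module:

# Minimal sketch of the in-place pattern adopted above; the Embedding is a hypothetical
# stand-in for the model's position-embedding module. Zeroing and freezing the existing
# parameter keeps any FSDP/DTensor wrapping attached to it, whereas constructing a fresh
# torch.nn.Parameter (the old behavior) would discard that sharding metadata.
import torch

pos_embedding = torch.nn.Embedding(num_embeddings=4096, embedding_dim=128)

old_weight = pos_embedding.weight
old_weight.requires_grad = False  # freeze the expanded position embedding
old_weight.zero_()                # zero it in place instead of replacing it
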
179 changes: 121 additions & 58 deletions composer/core/state.py
@@ -10,7 +10,7 @@
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
@@ -792,6 +792,15 @@ def fsdp_state_dict_type(self):
def fsdp_sharded_state_dict_enabled(self):
return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type in ['sharded', 'local']

@property
def fsdp_device_mesh(self):
if self.fsdp_enabled:
if not hasattr(self.model, 'model'):
return None
return self.model.model._device_mesh
else:
return None

@property
def load_fsdp_monolith_rank0_only(self):
return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
@@ -864,6 +873,9 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
return self.get_model_and_optimizer_state_dict(model_only=True)[0]

def _legacy_get_model_state_dict(self) -> Dict[str, Any]:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
model_state_dict = self.model.state_dict()
@@ -876,30 +888,60 @@ def get_model_state_dict(self) -> Dict[str, Any]:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')
return model_state_dict

def _legacy_get_optim_state_dict(self) -> Dict[str, Any]:
optimizer = ensure_tuple(self.optimizers)[0] # Let's stop pretending. We don't support more than one optimizer.
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
return optim_state_dict

def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if version.parse(torch.__version__) > version.parse('2.1.3'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {version.parse(torch.__version__)} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".'))

optimizer = ensure_tuple(self.optimizers)[0]
model_state_dict, optim_state_dict = get_state_dict(
model=self.model,
optimizers=([] if model_only else optimizer),
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type != 'sharded',
cpu_offload=True,
),
)
optim_state_dict = {type(optimizer).__qualname__: optim_state_dict}
else:
model_state_dict = self._legacy_get_model_state_dict()
optim_state_dict = self._legacy_get_optim_state_dict()

return model_state_dict, optim_state_dict

def state_dict(self) -> Dict[str, Any]:
"""Collect the state dicts of our serializable attributes.

Returns:
Dict[str, Any]: The state dict.
"""
state_dict = {}

model_state_dict, optim_state_dict = None, None
if 'model' in self.serialized_attributes or 'optimizers' in self.serialized_attributes:
model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict()
for attribute_name in self.serialized_attributes:
attribute_value = getattr(self, attribute_name)
if attribute_name == 'dataset_state':
serialized_value = self._dataset_state_dict()
elif attribute_name == 'model':
serialized_value = self.get_model_state_dict()
serialized_value = model_state_dict
elif attribute_name == 'optimizers':
optimizer = ensure_tuple(attribute_value)[
0] # Let's stop pretending. We don't support more than one optimizer.
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
serialized_value = optim_state_dict
elif attribute_name == 'algorithms':
# Store as list to preserve order in which algorithms were applied
@@ -1058,49 +1100,34 @@ def _apply_required_algorithms(
'have undergone surgery, the following algorithms may be excluded using '
f'`load_exclude_algorithms`, e.g. `load_exclude_algorithms=[{missing_algo_names}]`.')) from e

def load_model_state(
def _legacy_load_model_state(
self,
state_dict: Dict[str, Any],
logger: Logger,
strict: bool,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
):
"""Loads the model's state from a ``state_dict``.

Args:
state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`.
logger (Logger): The logger.
strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
perfectly match the keys in the model instance.
exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
"""
if 'algorithms' in state_dict:
self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes)

if state_dict.get('is_model_ddp', False) and not self.is_model_ddp:
# This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
# with the `module.` prefix
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')

# For FSDP monolith checkpoints, the model does not exist on ranks > 0
model_on_rank = state_dict['model'] is not None
if state_dict['model'] is None:
return

missing_keys, unexpected_keys = [], []
try:
# Load model if it exists. For FSDP monolith checkpoints, the model does not exist on ranks > 0
if model_on_rank:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}'
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
else:
log.debug(f'Loading model state dict with strict={strict}')
# Load model if it exists
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}'
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
else:
log.debug(f'Loading model state dict with strict={strict}')
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
except RuntimeError as e:
if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e):
raise RuntimeError(
@@ -1110,9 +1137,9 @@ def load_model_state(
else:
raise e

if model_on_rank and len(missing_keys) > 0:
if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if model_on_rank and len(unexpected_keys) > 0:
if len(unexpected_keys) > 0:
if self.fsdp_config is not None and self.fsdp_config[
'use_orig_params'] and self.fsdp_state_dict_type == 'local':
log.warning(
Expand All @@ -1122,16 +1149,7 @@ def load_model_state(
'was still loaded correctly.')
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_fsdp_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
log.debug('Finished wrapping model with FSDP.')

def load_optim_state(self, state_dict: Dict[str, Any]):
def _legacy_load_optim_state(self, state_dict: Dict[str, Any]):
"""Load the optimizer state.

Args:
@@ -1205,6 +1223,55 @@ def _load_dataset_state(self, obj: Dict[str, Any]) -> None:
# starts. This avoids "CUDA error: initialization error" -- its not clear why.
# self.dataset_resumption['eval'][evaluator.label] = True

def load_model_and_optimizer_state(
self,
state_dict: Dict[str, Any],
logger: Logger,
strict: bool,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
load_model_only: bool = False,
):
if 'algorithms' in state_dict:
self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes)

if state_dict.get('is_model_ddp', False) and not self.is_model_ddp:
# This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
# with the `module.` prefix
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')

# Load model and optimizer state
use_state_dict_fns = version.parse(torch.__version__) > version.parse('2.1.3')
if use_state_dict_fns:
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict
optimizer = ensure_tuple(self.optimizers)[0]
model_state_dict = state_dict.get('model', {})
optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {})
if load_model_only:
optimizer, optim_state_dict = [], {}
set_state_dict(
self.model,
optimizers=optimizer,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
options=StateDictOptions(strict=strict, cpu_offload=True),
)
else:
self._legacy_load_model_state(state_dict, strict)

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_fsdp_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
log.debug('Finished wrapping model with FSDP.')

# Legacy optimizer state load must happen after FSDP monolith
if not use_state_dict_fns and not load_model_only:
self._legacy_load_optim_state(state_dict)

def load_state_dict(
self,
state: Dict[str, Any],
@@ -1228,28 +1295,26 @@ def load_state_dict(

# Call load_model_state first since it applies required algorithms
if 'model' in state:
self.load_model_state(
self.load_model_and_optimizer_state(
state,
logger,
strict=strict,
exclude_algorithms=exclude_algorithms,
algorithm_passes=algorithm_passes,
load_model_only=(not 'optimizers' in state),
)

for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order
serialized_value = state[attribute_name]
# Skip removed attributes as well as algorithms and model, which was already loaded
if attribute_name not in self.serialized_attributes or attribute_name == 'model':
if attribute_name not in self.serialized_attributes or attribute_name in ['model', 'optimizers']:
continue

# Integrations are extra information about other libraries (e.g. huggingface) and not attributes to be loaded here
if attribute_name == 'integrations':
continue

# Skip metadata, which is not an attribute on State
if attribute_name == 'metadata':
continue

log.debug(f'Loading {attribute_name} into state.')

# Restructure algorithms serialized_value from list to dict
@@ -1258,8 +1323,6 @@

if attribute_name == 'dataset_state':
self._load_dataset_state(serialized_value)
elif attribute_name == 'optimizers':
self.load_optim_state(state)
elif attribute_name == 'train_metrics':
# Get current metrics object and populate each metric present
# in serialization with serialized data via load_state_dict()
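On torch > 2.1.3, the state.py changes above route model and optimizer (de)serialization through torch.distributed.checkpoint.state_dict, so full and sharded (DTensor-backed) checkpoints share one code path. A standalone sketch of that get_state_dict / set_state_dict round trip, assuming a toy unsharded model and a single-process gloo group rather than Composer's wrapped objects:

# Standalone sketch of the torch.distributed.checkpoint.state_dict round trip that the new
# get_model_and_optimizer_state_dict / load_model_and_optimizer_state paths rely on for
# torch > 2.1.3. The toy model, optimizer, and single-process gloo group are assumptions
# for illustration, not Composer's actual wrapping.
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict

if not dist.is_initialized():
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500', world_size=1, rank=0)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save path: a full (unsharded), CPU-offloaded state dict, mirroring
# StateDictOptions(full_state_dict=..., cpu_offload=True) in the diff above.
model_sd, optim_sd = get_state_dict(
    model,
    optimizers=optimizer,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)

# Load path: set_state_dict restores both the module and the optimizer from those dicts.
set_state_dict(
    model,
    optimizers=optimizer,
    model_state_dict=model_sd,
    optim_state_dict=optim_sd,
    options=StateDictOptions(strict=True),
)
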
9 changes: 8 additions & 1 deletion composer/trainer/dist_strategy.py
@@ -232,7 +232,7 @@ def prepare_fsdp_module(
set_fsdp_default(fsdp_config)

# Check sync_module_states is True for mixed initialization or HSDP
if fsdp_config['sync_module_states'] == False:
if fsdp_config['sync_module_states'] == False and not fsdp_config.get('force_sync_module_states', False):
rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
@@ -273,6 +273,13 @@ def sync_hook(*args):
# `nn.Module.named_parameters`.
# Setting it to `True` is mandatory when using `torch.compile()`.
kwargs['use_orig_params'] = fsdp_config['use_orig_params']
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
if 'device_mesh' in fsdp_config:
from torch.distributed._tensor import init_device_mesh
kwargs['device_mesh'] = init_device_mesh(
'cuda',
tuple(fsdp_config['device_mesh']),
)

# necessary variables for optimizers with multiple param groups in FSDP
num_param_groups = None
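The torch >= 2.2 branch above turns fsdp_config['device_mesh'] into a DeviceMesh and forwards it to FSDP. A sketch of how a two-dimensional mesh maps onto HSDP-style replicate-plus-shard wrapping; the 2x4 mesh shape, toy module, and torchrun launch are assumptions for illustration:

# Sketch of what the torch >= 2.2 branch builds from fsdp_config['device_mesh'].
# Assumed launch: torchrun --nproc_per_node=8 this_script.py (2 replica groups x 4 shards).
# With a 2-D mesh and HYBRID_SHARD, FSDP derives the replication and sharding process
# groups from the mesh itself.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group(backend='nccl')
torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

fsdp_config = {'device_mesh': [2, 4]}  # [replicate dim, shard dim], as in the Composer config
mesh = init_device_mesh('cuda', tuple(fsdp_config['device_mesh']))

model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
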
4 changes: 4 additions & 0 deletions composer/trainer/mosaic_fsdp.py
@@ -81,3 +81,7 @@ def patch_pytorch():
_runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
_runtime_utils._root_pre_forward = _root_pre_forward
FullyShardedDataParallel.forward = forward

# Monkeypatch dtensor support
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore
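
patch_pytorch() swaps FullyShardedDataParallel.__init__ for init_fn_t2p2p0 from mosaic_fsdp_utils. A generic sketch of that monkeypatch pattern; the wrapper body below is a hypothetical placeholder, not the actual init_fn_t2p2p0:

# Generic sketch of the monkeypatch pattern used above: swap in a wrapper for
# FullyShardedDataParallel.__init__ that can adjust arguments before delegating to the
# original constructor.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel

_original_fsdp_init = FullyShardedDataParallel.__init__

@functools.wraps(_original_fsdp_init)
def _patched_fsdp_init(self, module, *args, **kwargs):
    kwargs.setdefault('use_orig_params', True)  # hypothetical tweak for illustration
    _original_fsdp_init(self, module, *args, **kwargs)

FullyShardedDataParallel.__init__ = _patched_fsdp_init  # type: ignore[method-assign]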