Hsdp + MoE CI tests (#3378)
* fold ema fsdp state

* debug

* debug

* more debug

* keep debugging

* debug

* sanity check

* debug

* debug

* use ema

* debug

* debug

* debug

* debug

* debug

* debug

* more fix

* filename test

* revert test

* fully parameterize

* hsdp test

* revert testing

* typo

* typo

* hsdp

* split off test

* precommit

* float to int

* pyright

* oom

* print

* rm tp

* tp cfg

* tp?

* rm tp line

* type annotation

* revert

* readd tp

* type

* world size

* revert

* revert monolithic cpkt + include sharded cpkt

* enumerate

* precommit

* precommit

* sharded

* sync

* only sync on first trainer

* typo

* hsdp

* xfail

* explicit sync

* test

* revert test

* sync, docker issue

* pre-commit

* sync

* pytest

* xfail

* rm world_size param

* im so sorry pls forgive me king

* the kings comments

* Update tests/trainer/test_fsdp_checkpoint.py

fix formatting

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* precommit

---------

Co-authored-by: v-chen_data <v-chen_data@example.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
4 people authored Jun 24, 2024
1 parent 8b32fbc commit 072758e
Showing 1 changed file with 59 additions and 27 deletions.
86 changes: 59 additions & 27 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -289,21 +289,21 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
 @pytest.mark.gpu
 @pytest.mark.filterwarnings(r'ignore:.*scatter_full_optim_state_dict``is being deprecated.*:UserWarning')
 @pytest.mark.parametrize(
-    'world_size,optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp',
+    'optimizer,autoresume,precision,save_weights_only,load_weights_only,load_monolith_rank0_only,use_tp,use_hsdp',
     [
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adamw', False, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', True, 'amp_bf16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_fp16', False, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', True, True, False, False,
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adamw', False, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', True, 'amp_bf16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_fp16', False, False, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', True, True, False, False, False,
                      marks=pytest.mark.world_size(2)),  # save_weights_only requires load_weights_only
-        pytest.param(2, 'adam', False, 'amp_bf16', False, True, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, 'adam', False, 'amp_bf16', False, False, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(4, 'adam', False, 'amp_bf16', False, False, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, True, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param('adam', False, 'amp_bf16', False, False, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 def test_fsdp_full_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
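The hunk above drops the explicit world_size test argument; the number of ranks a case needs is now carried only by pytest.mark.world_size on each pytest.param. A minimal sketch of that pattern, assuming a conftest-provided fixture (illustrative only, not Composer's actual test harness):

    import pytest

    @pytest.fixture
    def world_size(request) -> int:
        # Read the marker attached via pytest.param(..., marks=pytest.mark.world_size(N));
        # fall back to a single process when no marker is present.
        mark = request.node.get_closest_marker('world_size')
        return 1 if mark is None else int(mark.args[0])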
@@ -312,19 +312,31 @@ def test_fsdp_full_state_dict_load(
     load_weights_only: bool,
     load_monolith_rank0_only: bool,
     use_tp: bool,
+    use_hsdp: bool,
 ):
+    if use_hsdp:
+        pytest.xfail('Known Pytorch issue with HSDP, waiting for pytorch patch')
     if autoresume:
         run_name = 'my-cool-autoresume-run'
     else:
         run_name = None
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(
-        sharded_ckpt_prefix_dir='ba{batch}',
-        sync_module_states=load_monolith_rank0_only,
-        load_monolith_rank0_only=load_monolith_rank0_only,
-    )
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(
+            sharded_ckpt_prefix_dir='ba{batch}',
+            sync_module_states=load_monolith_rank0_only,
+            load_monolith_rank0_only=load_monolith_rank0_only,
+        )
     tp_config = None
     if use_tp:
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
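In the HYBRID_SHARD branch above, parameters are sharded within each replica group and replicated across groups, so data_parallel_shard_degree times data_parallel_replicate_degree must equal the number of launched ranks; that is why the HSDP cases carry marks=pytest.mark.world_size(4). A small sanity-check sketch (the helper name is illustrative, not part of the test file):

    def hsdp_ranks_required(shard_degree: int, replicate_degree: int) -> int:
        # Ranks needed for a hybrid-sharded run: shards per replica times replica count.
        return shard_degree * replicate_degree

    # The configs above use shard_degree=2 and replicate_degree=2, i.e. a 2x2 device mesh on 4 GPUs.
    assert hsdp_ranks_required(2, 2) == 4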
@@ -778,23 +790,33 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
-    'world_size,weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp',
+    'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, True, 'adamw', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adam', 'amp_bf16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_fp16', False, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', True, None, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, ['rng'], False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, True, False, marks=pytest.mark.world_size(2)),
-        pytest.param(2, False, 'adamw', 'amp_bf16', False, None, False, True, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(
+            False,
+            'adamw',
+            'amp_bf16',
+            False,
+            ['rng'],
+            False,
+            False,
+            False,
+            marks=pytest.mark.world_size(2),
+        ),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
+        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
 @pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
 def test_fsdp_partitioned_state_dict_load(
-    world_size,
     tmp_path: pathlib.Path,
     autoresume: bool,
     precision: str,
@@ -803,6 +825,7 @@ def test_fsdp_partitioned_state_dict_load(
     load_ignore_keys: Union[list[str], None],
     use_symlink: bool,
     use_tp: bool,
+    use_hsdp: bool,
     use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
@@ -829,10 +852,19 @@ def test_fsdp_partitioned_state_dict_load(
 
     save_filename = 'ba{batch}-rank{rank}.pt'
 
-    fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
+    if use_hsdp:
+        fsdp_config = FSDPConfig(
+            sharding_strategy='HYBRID_SHARD',
+            sharded_ckpt_prefix_dir='ba{batch}',
+            state_dict_type='sharded',
+            data_parallel_shard_degree=2,
+            data_parallel_replicate_degree=2,
+            sync_module_states=True,
+        )
+    else:
+        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
     tp_config = None
     if use_tp:
-        fsdp_config = FSDPConfig(state_dict_type='sharded', sharded_ckpt_prefix_dir='ba{batch}')
         from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
         tp_config = {
             'tensor_parallel_degree': 2,
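The checkpoint name templates visible in the hunks above ('ba{batch}' and 'ba{batch}-rank{rank}.pt') are str.format-style placeholders that Composer substitutes at save time. A quick illustration of how they resolve (batch and rank values chosen arbitrarily):

    sharded_ckpt_prefix_dir = 'ba{batch}'
    save_filename = 'ba{batch}-rank{rank}.pt'

    # For a checkpoint saved at batch 2 by rank 3 of a 4-GPU run:
    print(sharded_ckpt_prefix_dir.format(batch=2))  # -> 'ba2'
    print(save_filename.format(batch=2, rank=3))    # -> 'ba2-rank3.pt'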
