Skip to content

Commit

Permalink
Add checkpoint BC tests for 0.27.0 and 0.28.0 (#3735)
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 authored Dec 5, 2024
1 parent 6c20da5 commit 06fb4a8
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ def test_fsdp_mixed_with_sync(
'0.24.0',
'0.25.0',
'0.26.0',
'0.27.0',
'0.28.0',
],
)
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
Expand All @@ -534,9 +536,12 @@ def test_fsdp_load_old_checkpoint(
if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD':
pytest.skip('TODO: This checkpoint is missing')

if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')) or (
composer_version in ['0.24.0', '0.25.0'] and version.parse(torch.__version__) < version.parse('2.4.0')
) or (composer_version in '0.26.0' and version.parse(torch.__version__) < version.parse('2.5.0')):
if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')
) or (composer_version in ['0.24.0', '0.25.0'] and
version.parse(torch.__version__) < version.parse('2.4.0')) or (
composer_version in ['0.26.0', '0.27.0', '0.28.0'] and
version.parse(torch.__version__) < version.parse('2.5.0')
):
pytest.skip('Current torch version is older than torch version that checkpoint was written with.')

if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
Expand Down

0 comments on commit 06fb4a8

Please sign in to comment.