Skip to content

Commit

Permalink
[DCP] Add test for planner option for load_sharded_optimizer_state_dict (#112930)
Browse files Browse the repository at this point in the history

Add test for a user submitted PR: #112259
Cherry-pick of #112891 into `release/2.1` branch
  • Loading branch information
wz337 authored Nov 7, 2023
1 parent 4b4c012 commit 33106b7
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions test/distributed/checkpoint/test_fsdp_optim_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
Expand Down Expand Up @@ -53,8 +57,10 @@ def backend(self):
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_load_sharded_optimizer_state_dict(self) -> None:
@parametrize("pass_planner", [True, False])
def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
CHECKPOINT_DIR = self.temp_dir
planner = DCP.DefaultLoadPlanner() if pass_planner else None

model = self._create_model()
model = FSDP(model)
Expand Down Expand Up @@ -105,6 +111,7 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
model_state_dict=state_dict["model"],
optimizer_key="optim",
storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
planner=planner,
)
flattened_osd = FSDP.optim_state_dict_to_load(
model_2, optim_2, optim_state["optim"]
Expand All @@ -126,5 +133,6 @@ def test_load_sharded_optimizer_state_dict(self) -> None:
self.assertEqual(state, state2)


instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
if __name__ == "__main__":
run_tests()

0 comments on commit 33106b7

Please sign in to comment.