diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index c9d4967c5f6b0..473c2dfb185a0 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -71,8 +71,7 @@ def configure_sharded_model(self) -> None: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # when loading full state dict, we first need to create a new unwrapped model - if self.layer is None or isinstance(self.layer, FullyShardedDataParallel): - self._init_model() + self._init_model() def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1)