diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 093bfeee30b7..57cd33a612ae 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -111,6 +111,7 @@ def __init__( ckpt_parallel_save_within_dp=False, ckpt_parallel_load=False, ckpt_parallel_save_optim=True, + ckpt_load_directly_on_device=True, setup_optimizers: bool = True, init_model_parallel: bool = True, **kwargs, @@ -147,6 +148,7 @@ def __init__( self.parallel_save_within_dp = ckpt_parallel_save_within_dp self.parallel_load = ckpt_parallel_load self.parallel_save_optim = ckpt_parallel_save_optim + self.load_directly_on_device = ckpt_load_directly_on_device self._ddp = ddp if ddp == "megatron": @@ -582,6 +584,7 @@ def checkpoint_io(self) -> CheckpointIO: parallel_save=self.parallel_save, parallel_save_within_dp=self.parallel_save_within_dp, parallel_load=self.parallel_load, + load_directly_on_device=self.load_directly_on_device, ) if async_save: self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)