From 57373a500fdcd8dc908d52dc075a588902f79598 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 8 Jul 2024 15:33:20 -0700 Subject: [PATCH] make 'load_directly_on_device' configurable Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index d0e502839f2f..1155dafef293 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -110,6 +110,7 @@ def __init__( ckpt_parallel_save=True, ckpt_parallel_load=False, ckpt_parallel_save_optim=True, + ckpt_load_directly_on_device=True, setup_optimizers: bool = True, init_model_parallel: bool = True, **kwargs, @@ -145,6 +146,7 @@ def __init__( self.parallel_save = ckpt_parallel_save self.parallel_load = ckpt_parallel_load self.parallel_save_optim = ckpt_parallel_save_optim + self.load_directly_on_device = ckpt_load_directly_on_device self._ddp = ddp if ddp == "megatron": @@ -579,6 +581,7 @@ def checkpoint_io(self) -> CheckpointIO: assume_constant_structure=self.assume_constant_structure, parallel_save=self.parallel_save, parallel_load=self.parallel_load, + load_directly_on_device=self.load_directly_on_device, ) if async_save: self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)