From 0f8a5319c6f625025544fff960aa5d9a8ac83460 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 16 Sep 2024 10:33:22 -0400
Subject: [PATCH] Update adapter saving logic to be compatible with
 `save_weights_only` (#10466)

* update adapter save logic to be compatible with `save_weights_only`

Signed-off-by: Chen Cui

* Apply isort and black reformatting

Signed-off-by: cuichenx

---------

Signed-off-by: Chen Cui
Signed-off-by: cuichenx
Co-authored-by: cuichenx
Co-authored-by: Pablo Garay
---
 nemo/lightning/pytorch/callbacks/peft.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 100f1df3f9ab..a3542d9a2135 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -100,7 +100,7 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
         super().setup(trainer, pl_module, stage=stage)
 
         trainer.strategy.trainer = trainer
-        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
+        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
         trainer.strategy._checkpoint_io = self.wrapped_io
         trainer.strategy._init_model_parallel = False
         trainer.strategy._setup_optimizers = False
@@ -137,22 +137,12 @@ def apply_transform(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)
 
-    def on_save_checkpoint(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
-    ) -> None:
-        # Filter out non-trainable parameters
-        trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
-        filtered_state_dict = {}
-        for name, value in trainer.strategy.megatron_parallel.sharded_state_dict().items():
-            if name in trainable_params:
-                filtered_state_dict[name] = value
-            elif self.adapter_key_filter(name):  # Include all adapter-related parameters
-                filtered_state_dict[name] = value
-
-        checkpoint['sharded_state_dict'] = filtered_state_dict
+        self.trainable_params = set(
+            name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
+        )
 
     def adapter_key_filter(self, key: str) -> bool:
-        return ".adapter." in key or key.endswith(".adapters")
+        return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters")
 
 
 class AdapterWrapper(nn.Module):
@@ -269,13 +259,21 @@ def load_state_dict(self, state_dict, strict=True):
 
 
 class WrappedAdapterIO(_WrappingCheckpointIO):
+    peft: Optional[PEFT] = None
     model_ckpt_path: Optional[Path] = None
     adapter_ckpt_path: Optional[Path] = None
 
+    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None) -> None:
+        self.peft = peft
+        super().__init__(checkpoint_io)
+
     @override
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         assert self.checkpoint_io is not None
 
+        checkpoint['sharded_state_dict'] = dict(
+            filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
+        )
         self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
 
         from nemo.utils.get_rank import is_global_rank_zero