
Commit

Update adapter saving logic to be compatible with save_weights_only (#10466)

* update adapter save logic to be compatible with `save_weights_only`

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
3 people committed Sep 16, 2024
1 parent 9621be2 commit 0f8a531
Showing 1 changed file with 13 additions and 15 deletions.
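
For context, save_weights_only is the standard Lightning ModelCheckpoint flag that writes model weights without optimizer or loop state. The sketch below shows how the flag is typically enabled in a plain PyTorch Lightning setup; it is illustrative only, and the dirpath, save interval, and model are placeholders rather than anything from NeMo's recipes.

# Illustrative only: enabling save_weights_only on a plain Lightning ModelCheckpoint.
# The dirpath, interval, and `model` below are placeholders, not NeMo recipe settings.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",   # where checkpoint files are written
    save_weights_only=True,   # drop optimizer and loop state; keep only model weights
    every_n_train_steps=100,  # save periodically during training
)
trainer = pl.Trainer(max_steps=1000, callbacks=[checkpoint_callback])
# trainer.fit(model)  # `model` would be a LightningModule with PEFT adapters applied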
nemo/lightning/pytorch/callbacks/peft.py (28 changes: 13 additions & 15 deletions)
@@ -100,7 +100,7 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
         super().setup(trainer, pl_module, stage=stage)

         trainer.strategy.trainer = trainer
-        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
+        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
         trainer.strategy._checkpoint_io = self.wrapped_io
         trainer.strategy._init_model_parallel = False
         trainer.strategy._setup_optimizers = False
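
The one-line change above hands the PEFT callback itself to WrappedAdapterIO, so the checkpoint IO can consult the callback's key filter whenever a checkpoint is written. The stand-in sketch below shows that wiring with simplified, hypothetical classes (TinyAdapterIO, TinyPEFTCallback), not the NeMo implementations.

# Simplified stand-in for the wiring above (hypothetical classes, not the NeMo code):
# the IO wrapper stores a reference to the PEFT callback so it can ask it which keys
# to keep at save time.
class TinyAdapterIO:
    def __init__(self, inner_io, peft=None) -> None:
        self.inner_io = inner_io  # the wrapped IO that actually writes files
        self.peft = peft          # callback consulted for key filtering on save


class TinyPEFTCallback:
    def setup(self, strategy) -> None:
        # mirrors `self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)`
        self.wrapped_io = TinyAdapterIO(strategy.checkpoint_io, peft=self)
        strategy.checkpoint_io = self.wrapped_io

    def adapter_key_filter(self, key: str) -> bool:
        return ".adapter." in key or key.endswith(".adapters")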
@@ -137,22 +137,12 @@ def apply_transform(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)

-    def on_save_checkpoint(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
-    ) -> None:
-        # Filter out non-trainable parameters
-        trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
-        filtered_state_dict = {}
-        for name, value in trainer.strategy.megatron_parallel.sharded_state_dict().items():
-            if name in trainable_params:
-                filtered_state_dict[name] = value
-            elif self.adapter_key_filter(name):  # Include all adapter-related parameters
-                filtered_state_dict[name] = value
-
-        checkpoint['sharded_state_dict'] = filtered_state_dict
+        self.trainable_params = set(
+            name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
+        )

     def adapter_key_filter(self, key: str) -> bool:
-        return ".adapter." in key or key.endswith(".adapters")
+        return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters")


 class AdapterWrapper(nn.Module):
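
With this hunk, apply_transform records which parameter names were trainable after the adapters are applied, and adapter_key_filter accepts those names in addition to the existing ".adapter." / ".adapters" naming rules. Below is a small self-contained illustration of the broadened filter; the parameter names are invented for the example, not taken from a real NeMo model.

# Self-contained illustration of the broadened adapter_key_filter.
# The parameter names are made up for the example.
trainable_params = {
    "module.decoder.layers.0.linear_qkv.adapter.linear_in.weight",
    "module.embedding.word_embeddings.custom_tuned.weight",
}


def adapter_key_filter(key: str) -> bool:
    # trainable params recorded in apply_transform, plus the old naming-based rules
    return key in trainable_params or ".adapter." in key or key.endswith(".adapters")


assert adapter_key_filter("module.decoder.layers.1.mlp.adapter.linear_out.weight")  # ".adapter." in name
assert adapter_key_filter("module.decoder.layers.1.self_attention.adapters")        # ends with ".adapters"
assert adapter_key_filter("module.embedding.word_embeddings.custom_tuned.weight")   # trainable, no adapter naming
assert not adapter_key_filter("module.decoder.layers.1.mlp.linear_fc1.weight")      # frozen base weight, dropped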
@@ -269,13 +259,21 @@ def load_state_dict(self, state_dict, strict=True):


 class WrappedAdapterIO(_WrappingCheckpointIO):
+    peft: Optional[PEFT] = None
     model_ckpt_path: Optional[Path] = None
     adapter_ckpt_path: Optional[Path] = None

+    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None) -> None:
+        self.peft = peft
+        super().__init__(checkpoint_io)
+
     @override
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         assert self.checkpoint_io is not None

+        checkpoint['sharded_state_dict'] = dict(
+            filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
+        )
         self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

         from nemo.utils.get_rank import is_global_rank_zero
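
The added lines apply the key filter inside save_checkpoint itself, so only adapter and trainable entries of the sharded state dict reach the wrapped checkpoint IO, regardless of which code path assembled the checkpoint. The toy run below exercises the same dict(filter(...)) idiom with made-up keys and values.

# Toy demonstration of the dict(filter(...)) step added above; keys and values are made up.
sharded_state_dict = {
    "module.decoder.layers.0.mlp.linear_fc1.weight": "frozen base weight",
    "module.decoder.layers.0.mlp.adapter.linear_in.weight": "adapter weight",
    "module.decoder.layers.0.self_attention.adapters": "adapter container",
}


def adapter_key_filter(key: str) -> bool:
    return ".adapter." in key or key.endswith(".adapters")


filtered = dict(filter(lambda item: adapter_key_filter(item[0]), sharded_state_dict.items()))
print(sorted(filtered))
# ['module.decoder.layers.0.mlp.adapter.linear_in.weight',
#  'module.decoder.layers.0.self_attention.adapters']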

