Update adapter saving logic to be compatible with save_weights_only #10466

Merged · 3 commits · Sep 16, 2024
Changes from all commits
28 changes: 13 additions & 15 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -100,7 +100,7 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
         super().setup(trainer, pl_module, stage=stage)

         trainer.strategy.trainer = trainer
-        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
+        self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io, self)
         trainer.strategy._init_model_parallel = False
         trainer.strategy._setup_optimizers = False
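
The only change in this hunk is the extra `self` argument: the PEFT callback now hands a reference to itself to the wrapping checkpoint IO, so the IO can call the callback's `adapter_key_filter` when a checkpoint is written. A minimal standalone sketch of that wiring, using hypothetical stand-in classes rather than the real NeMo/Lightning types:

```python
from typing import Any, Optional


class SketchWrappedIO:
    """Stand-in for WrappedAdapterIO: keeps a back-reference to the PEFT callback."""

    def __init__(self, checkpoint_io: Any, peft: Optional["SketchPEFT"] = None) -> None:
        self.checkpoint_io = checkpoint_io
        self.peft = peft


class SketchPEFT:
    """Stand-in for the PEFT callback: owns the adapter key filter."""

    def adapter_key_filter(self, key: str) -> bool:
        return ".adapter." in key or key.endswith(".adapters")

    def setup(self, strategy: Any) -> None:
        # Mirrors the diff: wrap the strategy's checkpoint IO and pass `self`
        # along so save-time filtering can reuse adapter_key_filter().
        self.wrapped_io = SketchWrappedIO(strategy.checkpoint_io, self)
        strategy._checkpoint_io = self.wrapped_io


class SketchStrategy:
    def __init__(self) -> None:
        self.checkpoint_io = object()  # pretend this is the underlying CheckpointIO
        self._checkpoint_io: Optional[SketchWrappedIO] = None


strategy = SketchStrategy()
SketchPEFT().setup(strategy)
assert strategy._checkpoint_io is not None and strategy._checkpoint_io.peft is not None
```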
@@ -137,22 +137,12 @@ def apply_transform(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)

-    def on_save_checkpoint(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
-    ) -> None:
-        # Filter out non-trainable parameters
-        trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
-        filtered_state_dict = {}
-        for name, value in trainer.strategy.megatron_parallel.sharded_state_dict().items():
-            if name in trainable_params:
-                filtered_state_dict[name] = value
-            elif self.adapter_key_filter(name):  # Include all adapter-related parameters
-                filtered_state_dict[name] = value
-
-        checkpoint['sharded_state_dict'] = filtered_state_dict
+        self.trainable_params = set(
+            name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
+        )

     def adapter_key_filter(self, key: str) -> bool:
-        return ".adapter." in key or key.endswith(".adapters")
+        return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters")


 class AdapterWrapper(nn.Module):
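
The `on_save_checkpoint` override is removed; instead, `apply_transform` records which parameter names are trainable, and `adapter_key_filter` now accepts those names in addition to the `.adapter.` naming convention, so the same filter can be applied at the checkpoint-IO level (next hunk) regardless of how the checkpoint is saved. A self-contained sketch of the new filter behaviour, using a hypothetical toy module rather than NeMo code:

```python
import torch.nn as nn


class ToyAdapterModel(nn.Module):
    """Hypothetical module: a frozen base layer plus a trainable adapter layer."""

    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.adapter = nn.Linear(4, 4)
        for p in self.base.parameters():
            p.requires_grad = False


model = ToyAdapterModel()

# Equivalent of the new lines in apply_transform(): remember the trainable names.
trainable_params = {name for name, p in model.named_parameters() if p.requires_grad}


def adapter_key_filter(key: str) -> bool:
    # Equivalent of the updated PEFT.adapter_key_filter().
    return key in trainable_params or ".adapter." in key or key.endswith(".adapters")


assert adapter_key_filter("adapter.weight")         # trainable, kept
assert adapter_key_filter("layers.0.adapter.bias")  # matches the ".adapter." convention
assert not adapter_key_filter("base.weight")        # frozen base weight, dropped
```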
@@ -269,13 +259,21 @@ def load_state_dict(self, state_dict, strict=True):


 class WrappedAdapterIO(_WrappingCheckpointIO):
+    peft: Optional[PEFT] = None
     model_ckpt_path: Optional[Path] = None
     adapter_ckpt_path: Optional[Path] = None

+    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None) -> None:
+        self.peft = peft
+        super().__init__(checkpoint_io)
+
     @override
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         assert self.checkpoint_io is not None

+        checkpoint['sharded_state_dict'] = dict(
+            filter(lambda item: self.peft.adapter_key_filter(item[0]), checkpoint['sharded_state_dict'].items())
+        )
         self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

         from nemo.utils.get_rank import is_global_rank_zero
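
The state-dict filtering now lives in the checkpoint IO itself, so it is applied on every save that goes through the strategy's checkpoint IO, which is what makes it compatible with `save_weights_only`. A runnable, standalone rehearsal of the filtering expression (the checkpoint contents and key names below are made up for illustration):

```python
from typing import Any, Dict


def adapter_key_filter(key: str) -> bool:
    # Same naming convention as PEFT.adapter_key_filter (trainable-name set omitted here).
    return ".adapter." in key or key.endswith(".adapters")


checkpoint: Dict[str, Any] = {
    "sharded_state_dict": {
        "module.decoder.layers.0.self_attention.adapter.linear_in.weight": "adapter tensor",
        "module.decoder.layers.0.self_attention.linear_qkv.weight": "frozen base tensor",
    }
}

# Same shape as the new code in WrappedAdapterIO.save_checkpoint(): drop every
# entry the filter rejects before delegating to the wrapped checkpoint IO.
checkpoint["sharded_state_dict"] = dict(
    filter(lambda item: adapter_key_filter(item[0]), checkpoint["sharded_state_dict"].items())
)

assert list(checkpoint["sharded_state_dict"]) == [
    "module.decoder.layers.0.self_attention.adapter.linear_in.weight"
]
```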