[NeMo-UX] Fix when optimizers are setup for PEFT (#9619)
* Fix when optimizers are setup for PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Init DDP inside PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Some fixes, loss seems to become nan with peft for some reason

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Loss goes down on fp32

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Simplifying FNMixin

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix bug with new checkpoint-io

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix failing test: test_peft_on_train_epoch_start_with_adapter

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

---------

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
3 people authored and dimapihtar committed Aug 27, 2024
1 parent 49f13fb commit 3d32926
Showing 7 changed files with 26 additions and 25 deletions.
1 change: 0 additions & 1 deletion nemo/lightning/io/connector.py
@@ -171,7 +171,6 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
dump_io (bool): If True, the IO configuration will be saved to the output path.
"""
trainer.strategy._setup_optimizers = False
trainer.strategy._init_model_parallel = False
trainer.strategy.setup(trainer)
trainer.save_checkpoint(output_path)

2 changes: 1 addition & 1 deletion nemo/lightning/megatron_parallel.py
@@ -476,8 +476,8 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")

def init_model_parallel(self):
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state
from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes

for model_module in self:
if not self._cpu:
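The only substantive change in this hunk is the import path: set_defaults_if_not_set_tensor_model_parallel_attributes now comes from Megatron-Core instead of Apex (consistent with the "1 addition & 1 deletion" summary), so init_model_parallel no longer needs Apex installed. A minimal sketch of how that helper is typically applied; the loop below is illustrative, not copied from the NeMo code:

import torch
from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes


def tag_tensor_parallel_defaults(model: torch.nn.Module) -> None:
    # Parameters created outside Megatron's tensor-parallel layers carry no
    # tensor-model-parallel attributes; give them the default (replicated) ones.
    for param in model.parameters():
        set_defaults_if_not_set_tensor_model_parallel_attributes(param)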
5 changes: 0 additions & 5 deletions nemo/lightning/pytorch/callbacks/model_transform.py
@@ -71,11 +71,6 @@ def _maybe_apply_transform(self, trainer):

def apply_transform(self, trainer):
self.model_transform(trainer.model)
from pytorch_lightning.utilities import model_summary

logging.info(
f"After applying model_transform:\n" f"{model_summary.summarize(trainer.lightning_module, max_depth=1)}"
)

@property
def _needs_to_call(self) -> bool:
7 changes: 3 additions & 4 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -103,11 +103,10 @@ def apply_transform(self, trainer):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()

if trainer.state.fn == TrainerFn.FITTING:
logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)
logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)

def on_save_checkpoint(
def on_load_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
# Filter out non-trainable parameters
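Two things appear to change in this hunk: setup_optimizers is no longer guarded by trainer.state.fn == TrainerFn.FITTING, and the checkpoint hook defined here is on_load_checkpoint rather than on_save_checkpoint. Judging from the test added in test_peft.py further down, the load hook relaxes strict loading so that an adapter-only state dict can be restored on top of the full model. A minimal sketch of that idea, assuming a standard PyTorch Lightning callback; this is an illustration, not the actual NeMo implementation:

from typing import Any, Dict

import pytorch_lightning as pl


class AdapterLoadCallback(pl.Callback):
    # Hypothetical callback mirroring the behaviour exercised by
    # test_peft_on_load_checkpoint: adapter checkpoints hold only the trainable
    # (PEFT) parameters, so the base model must tolerate missing keys on load.
    def on_load_checkpoint(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
    ) -> None:
        pl_module.strict_loading = False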
9 changes: 6 additions & 3 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -134,9 +134,12 @@ def convert_module(self, module: Module) -> Module:
if self.dtype_config.fp16 or self.dtype_config.bf16:
# Patch config options
config = get_model_config(module.module)
config.fp16 = self.dtype_config.fp16
config.bf16 = self.dtype_config.bf16
if hasattr(module, 'module'):
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
if isinstance(module.module, Float16Module):
new_float16_module = Float16Module(config, module.module.module)
module.module = new_float16_module
else:
module.module = Float16Module(config, module.module)
else:
module = Float16Module(config, module)
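The notable case in this hunk is a module whose inner module is already a Float16Module, for example when convert_module runs again after a restore: rather than wrapping twice, the original inner module (module.module.module) is rewrapped with the freshly patched config. A standalone sketch of that unwrap-and-rewrap pattern, using a stand-in class instead of Megatron's Float16Module:

class HalfWrapper:
    # Stand-in for megatron.core.transformer.module.Float16Module: wraps a
    # module together with the precision config used for forward casting.
    def __init__(self, config, module):
        self.config = config
        self.module = module


def rewrap(outer, config):
    # Avoid HalfWrapper(HalfWrapper(...)): if the inner module is already
    # wrapped, unwrap one level and rebuild the wrapper with the new config.
    if isinstance(outer.module, HalfWrapper):
        outer.module = HalfWrapper(config, outer.module.module)
    else:
        outer.module = HalfWrapper(config, outer.module)
    return outer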
17 changes: 6 additions & 11 deletions nemo/lightning/pytorch/strategies.py
@@ -144,19 +144,14 @@ def __init__(
ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
lazy_init: bool = False,
pipeline_dtype: Optional[torch.dtype] = None,
save_ckpt_format: str = 'torch_dist',
ckpt_async_save: bool = False,
ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure: bool = False,
ckpt_parallel_save: bool = True,
ckpt_parallel_save_within_dp: bool = False,
ckpt_parallel_load: bool = False,
ckpt_parallel_save_optim: bool = True,
ckpt_load_directly_on_device: bool = True,
save_ckpt_format='torch_dist',
ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure=False,
ckpt_parallel_save=True,
ckpt_parallel_load=False,
ckpt_parallel_save_optim=True,
setup_optimizers: bool = True,
init_model_parallel: bool = True,
replace_progress_bar: bool = True,
progress_interval: int = 1,
**kwargs,
) -> None:
super().__init__(
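For context, the setup_optimizers and init_model_parallel flags in this constructor are what let both steps be deferred to the PEFT callback above. A hypothetical instantiation; the MegatronStrategy class name and the "from nemo import lightning as nl" alias are assumptions, as neither appears in this hunk:

from nemo import lightning as nl  # assumed import path

# Defer optimizer setup and model-parallel init so the PEFT callback can run
# them after the adapter transform has been applied to the model.
strategy = nl.MegatronStrategy(
    ddp="megatron",
    setup_optimizers=False,
    init_model_parallel=False,
)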
10 changes: 10 additions & 0 deletions tests/lightning/pytorch/callbacks/test_peft.py
@@ -72,3 +72,13 @@ def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False)
trainer.strategy.init_model_parallel.assert_called_once()
trainer.strategy.setup_optimizers.assert_called_once_with(trainer)

def test_peft_on_load_checkpoint(self):
peft = self.DummyPEFT()
trainer = MagicMock()
pl_module = MagicMock()
checkpoint = {}

peft.on_load_checkpoint(trainer, pl_module, checkpoint)

assert pl_module.strict_loading == False
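The new test can be run on its own with pytest's -k filter, for example:

pytest tests/lightning/pytorch/callbacks/test_peft.py -k test_peft_on_load_checkpoint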
