From 3d32926f160d2abeb20425d6bb9d08013871b537 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Tue, 9 Jul 2024 06:16:04 +0200
Subject: [PATCH] [NeMo-UX] Fix when optimizers are setup for PEFT (#9619)

* Fix when optimizers are setup for PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Init DDP inside PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Some fixes, loss seems to become nan with peft for some reason

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Loss goes down on fp32

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Simplifying FNMixin

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Fix bug with new checkpoint-io

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Fix failing test: test_peft_on_train_epoch_start_with_adapter

* Apply isort and black reformatting

Signed-off-by: marcromeyn

---------

Signed-off-by: marcromeyn
Co-authored-by: marcromeyn
Co-authored-by: Chen Cui
---
 nemo/lightning/io/connector.py                  |  1 -
 nemo/lightning/megatron_parallel.py             |  2 +-
 .../pytorch/callbacks/model_transform.py        |  5 -----
 nemo/lightning/pytorch/callbacks/peft.py        |  7 +++----
 .../pytorch/plugins/mixed_precision.py          |  9 ++++++---
 nemo/lightning/pytorch/strategies.py            | 17 ++++++-----------
 tests/lightning/pytorch/callbacks/test_peft.py  | 10 ++++++++++
 7 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index 512f3bc4f12e..66452aa570ad 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -171,7 +171,6 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
             dump_io (bool): If True, the IO configuration will be saved to the output path.
         """
         trainer.strategy._setup_optimizers = False
-        trainer.strategy._init_model_parallel = False
         trainer.strategy.setup(trainer)
         trainer.save_checkpoint(output_path)
 
diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index dd10a726e67a..c7ae644e4826 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -476,8 +476,8 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
         raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")
 
     def init_model_parallel(self):
+        from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
         from megatron.core import parallel_state
-        from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
 
         for model_module in self:
             if not self._cpu:
diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py
index 8a07566f92c3..5d48851843fc 100644
--- a/nemo/lightning/pytorch/callbacks/model_transform.py
+++ b/nemo/lightning/pytorch/callbacks/model_transform.py
@@ -71,11 +71,6 @@ def _maybe_apply_transform(self, trainer):
 
     def apply_transform(self, trainer):
         self.model_transform(trainer.model)
-        from pytorch_lightning.utilities import model_summary
-
-        logging.info(
-            f"After applying model_transform:\n" f"{model_summary.summarize(trainer.lightning_module, max_depth=1)}"
-        )
 
     @property
     def _needs_to_call(self) -> bool:
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index c7983af26752..c12e2bc574f4 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -103,11 +103,10 @@ def apply_transform(self, trainer):
             logging.info("Initializing model parallel")
             trainer.strategy.init_model_parallel()
 
-        if trainer.state.fn == TrainerFn.FITTING:
-            logging.info("Setting up optimizers")
-            trainer.strategy.setup_optimizers(trainer)
+        logging.info("Setting up optimizers")
+        trainer.strategy.setup_optimizers(trainer)
 
-    def on_save_checkpoint(
+    def on_load_checkpoint(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
     ) -> None:
         # Filter out non-trainable parameters
diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py
index 79394cc4bbb1..3e366db393b7 100644
--- a/nemo/lightning/pytorch/plugins/mixed_precision.py
+++ b/nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -134,9 +134,12 @@ def convert_module(self, module: Module) -> Module:
         if self.dtype_config.fp16 or self.dtype_config.bf16:
             # Patch config options
             config = get_model_config(module.module)
-            config.fp16 = self.dtype_config.fp16
-            config.bf16 = self.dtype_config.bf16
-            if hasattr(module, 'module'):
+            config.fp16 = self.precision == "16-mixed"
+            config.bf16 = self.precision == "bf16-mixed"
+            if isinstance(module.module, Float16Module):
+                new_float16_module = Float16Module(config, module.module.module)
+                module.module = new_float16_module
+            else:
                 module.module = Float16Module(config, module.module)
         else:
             module = Float16Module(config, module)
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index d6ef18770fa4..f6504238e75a 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -144,19 +144,14 @@ def __init__(
         ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
-        save_ckpt_format: str = 'torch_dist',
-        ckpt_async_save: bool = False,
-        ckpt_torch_dist_multiproc: int = None,  ## TODO(ashors): put elsewhere?
-        ckpt_assume_constant_structure: bool = False,
-        ckpt_parallel_save: bool = True,
-        ckpt_parallel_save_within_dp: bool = False,
-        ckpt_parallel_load: bool = False,
-        ckpt_parallel_save_optim: bool = True,
-        ckpt_load_directly_on_device: bool = True,
+        save_ckpt_format='torch_dist',
+        ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
+        ckpt_assume_constant_structure=False,
+        ckpt_parallel_save=True,
+        ckpt_parallel_load=False,
+        ckpt_parallel_save_optim=True,
         setup_optimizers: bool = True,
         init_model_parallel: bool = True,
-        replace_progress_bar: bool = True,
-        progress_interval: int = 1,
         **kwargs,
     ) -> None:
         super().__init__(
diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py
index 99a22f82fa50..3c3e4c4347aa 100644
--- a/tests/lightning/pytorch/callbacks/test_peft.py
+++ b/tests/lightning/pytorch/callbacks/test_peft.py
@@ -72,3 +72,13 @@ def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
         trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False)
         trainer.strategy.init_model_parallel.assert_called_once()
         trainer.strategy.setup_optimizers.assert_called_once_with(trainer)
+
+    def test_peft_on_load_checkpoint(self):
+        peft = self.DummyPEFT()
+        trainer = MagicMock()
+        pl_module = MagicMock()
+        checkpoint = {}
+
+        peft.on_load_checkpoint(trainer, pl_module, checkpoint)
+
+        assert pl_module.strict_loading == False