diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 5c9703497597..0bb8f5fa46af 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -279,7 +279,7 @@ def _setup(
     model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
 ) -> Any:  # Return type is Any because app_state's type is not specified
     _log = log or NeMoLogger()
-    if resume and resume.adapter_path and _log.ckpt:
+    if resume and isinstance(model_transform, PEFT) and _log.ckpt:
         logging.info("Disabling try_restore_best_ckpt restoration for adapters")
         _log.ckpt.try_restore_best_ckpt = False
 
diff --git a/nemo/collections/llm/fn/mixin.py b/nemo/collections/llm/fn/mixin.py
index b32f66366bfb..c566c6e9d392 100644
--- a/nemo/collections/llm/fn/mixin.py
+++ b/nemo/collections/llm/fn/mixin.py
@@ -2,6 +2,7 @@
 from typing_extensions import Self
 
 from nemo.collections.llm.fn import base as fn
+from nemo.utils import logging
 
 
 class FNMixin:
@@ -114,8 +115,12 @@ def freeze(self) -> None:
         """
         assert isinstance(self, nn.Module), "self is not a nn.Module"
 
-        for param in self.parameters():
-            param.requires_grad = False
+        params = list(self.parameters())
+        if not params:
+            logging.info(f"No parameters found in module {self.__class__.__name__}")
+        else:
+            for param in params:
+                param.requires_grad = False
 
     def unfreeze(self) -> None:
         """
@@ -124,5 +129,9 @@ def unfreeze(self) -> None:
         """
         assert isinstance(self, nn.Module), "self is not a nn.Module"
 
-        for param in self.parameters():
-            param.requires_grad = True
+        params = list(self.parameters())
+        if not params:
+            logging.info(f"No parameters found in module {self.__class__.__name__}")
+        else:
+            for param in params:
+                param.requires_grad = True
diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index e6452de16512..3bd62ddce24a 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -515,4 +515,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
         elif count > n_nesting:
             to_remove = "module." * (count - n_nesting)
             _state_dict[key[len(to_remove) :]] = value
+        else:
+            _state_dict[key] = value
+
     module.load_state_dict(_state_dict, strict=strict)
diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index 500d0203cfd4..8be630f163e0 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -160,12 +160,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
             output_path (Path): The path where the model checkpoint will be saved.
             trainer (pl.Trainer): The trainer with the strategy to save the model.
         """
-        _setup_kwargs = {}
-        setup_signature = inspect.signature(trainer.strategy.setup)
-        if 'setup_optimizers' in setup_signature.parameters:
-            _setup_kwargs["setup_optimizers"] = False
-
-        trainer.strategy.setup(trainer, **_setup_kwargs)
+        trainer.strategy._setup_optimizers = False
+        trainer.strategy.setup(trainer)
         trainer.save_checkpoint(output_path)
 
     def nemo_load(
diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 2f2308717004..ee41455544bb 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -12,7 +12,6 @@
     Iterable,
     Iterator,
     List,
-    Mapping,
     Optional,
     Protocol,
     Sequence,
@@ -129,7 +128,6 @@ def __init__(
         cpu: bool = False,
         convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
     ) -> None:
-        from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
         from megatron.core import parallel_state
 
         _pipeline: List[nn.Module]
@@ -152,67 +150,15 @@ def __init__(
                 _model.configure_model()
             _pipeline.append(_model)
 
-        if convert_module_fn:
-            for i in range(len(_pipeline)):
-                _pipeline[i] = convert_module_fn(_pipeline[i])
-
-        if isinstance(ddp_config, DistributedDataParallelConfig):
-            for model_chunk_idx, model_chunk in enumerate(_pipeline):
-                module = model_chunk.module
-
-                ddp = DDP(
-                    module.config,
-                    ddp_config,
-                    module,
-                    data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
-                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
-                    # model chunks is overlapped with compute anyway.
-                    disable_bucketing=(model_chunk_idx > 0),
-                )
-                model_chunk.module = ddp
-                model_chunk.buffers = ddp.buffers  # We need to do this explicitly since this is a attr pytorch uses
-                model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore
-
-            # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
-            no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline)
-            for module in _pipeline:
-                module.config.no_sync_func = no_sync_func
-                module.config.grad_sync_func = grad_sync_func
-
-        for i, model_module in enumerate(_pipeline):
-            if not cpu:
-                model_module.cuda(torch.cuda.current_device())
-
-            for param in model_module.parameters():
-                set_defaults_if_not_set_tensor_model_parallel_attributes(param)
-
-            if hasattr(model_module, "configure_model"):
-                if not hasattr(model_module, "set_input_tensor"):
-                    if hasattr(model_module.module, "set_input_tensor"):
-                        model_module.set_input_tensor = model_module.module.set_input_tensor
-                    else:
-                        # TODO: What to do here?
-                        pass
-
-            # Print number of parameters.
-            if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
-                from nemo.utils import logging
-
-                msg = (
-                    f" > number of parameters on (tensor, pipeline) model parallel rank "
-                    f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
-                    f"{_calc_number_of_params(_pipeline)}"
-                )
-                logging.info(msg)
-
         super().__init__(_pipeline)
         self.precision_plugin = precision_plugin
+        self._cpu = cpu
         self.callbacks = callbacks or CallbackConnector()
         self.data_step = data_step or default_data_step
         self.forward_step = forward_step or default_forward_step
         self.loss_reduction: MegatronLossReduction = loss_reduction
         self.ddp_config = ddp_config
+        self.convert_module_fn = convert_module_fn
 
     def forward(
         self,
@@ -475,6 +421,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
 
         raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")
 
+    def init_model_parallel(self):
+        from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
+        from megatron.core import parallel_state
+
+        for model_module in self:
+            if not self._cpu:
+                model_module.cuda(torch.cuda.current_device())
+
+            for param in model_module.parameters():
+                set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+
+            if hasattr(model_module, "configure_model"):
+                if not hasattr(model_module, "set_input_tensor"):
+                    if hasattr(model_module.module, "set_input_tensor"):
+                        model_module.set_input_tensor = model_module.module.set_input_tensor
+                    else:
+                        # TODO: What to do here?
+                        pass
+
+            # Print number of parameters.
+            if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
+                from nemo.utils import logging
+
+                num_params = _calc_number_of_params(list(self))
+                num_trainable_params = _calc_number_of_trainable_params(list(self))
+
+                msg = (
+                    f" > number of parameters on (tensor, pipeline) model parallel rank "
+                    f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
+                    f"{num_params}"
+                )
+                logging.info(msg)
+
+                if num_params != num_trainable_params:
+                    logging.info(
+                        f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
+                    )
+
+        if self.convert_module_fn:
+            self.apply_convert_module_fn()
+
+        self.init_ddp()
+
+    def apply_convert_module_fn(self):
+        for i in range(len(self)):
+            self[i] = self.convert_module_fn(self[i])
+
+    def init_ddp(self):
+        if not isinstance(self.ddp_config, DistributedDataParallelConfig):
+            return
+
+        from megatron.core import parallel_state
+
+        for model_chunk_idx, model_chunk in enumerate(self):
+            module = model_chunk.module
+
+            ddp = DDP(
+                module.config,
+                self.ddp_config,
+                module,
+                data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+                # Turn off bucketing for model_chunk 2 onwards, since communication for these
+                # model chunks is overlapped with compute anyway.
+                disable_bucketing=(model_chunk_idx > 0),
+            )
+            model_chunk.module = ddp
+            model_chunk.buffers = ddp.buffers  # We need to do this explicitly since this is a attr pytorch uses
+            model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore
+
+        # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
+        no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
+        for module in self:
+            module.config.no_sync_func = no_sync_func
+            module.config.grad_sync_func = grad_sync_func
+
     def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
         if "self" in context:
             del context["self"]
@@ -565,18 +587,21 @@ def forward_backward_func(self) -> "MegatronStepProtocol":
 
     @override
     def __getattr__(self, item: Any) -> Any:
-        if len(self) == 0:
-            return super().__getattr__(item)
-
         try:
-            # __getattr__ gets called as a last resort if the attribute does not exist
-            # call nn.Module's implementation first
+            # First, try to get the attribute from the superclass (nn.ModuleList)
             return super().__getattr__(item)
         except AttributeError:
-            # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
-            attr = getattr(self._modules[self._get_abs_string_index(0)], item)
+            # If not found in superclass, check if we have any modules
+            if len(self) == 0:
+                raise AttributeError(
+                    f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
+                )
 
-            return attr
+            # Try to get it from the first module
+            try:
+                return getattr(self._modules[self._get_abs_string_index(0)], item)
+            except AttributeError:
+                raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
 
 
 class _ModuleStepFunction:
@@ -915,6 +940,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
     return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
 
 
+def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
+    assert isinstance(model, list)
+
+    return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])
+
+
 def is_list_of_iterators(var) -> bool:
     if not isinstance(var, list):
         return False
diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py
index 68b3db16f473..512324940133 100644
--- a/nemo/lightning/pytorch/callbacks/model_transform.py
+++ b/nemo/lightning/pytorch/callbacks/model_transform.py
@@ -65,7 +65,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
     def _maybe_apply_transform(self, trainer):
         if self._needs_to_call:
-            self.model_transform(trainer.model)
+            self.apply_transform(trainer)
+
+    def apply_transform(self, trainer):
+        self.model_transform(trainer.model)
 
     @property
     def _needs_to_call(self) -> bool:
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 26325bf549d0..f8fa76110288 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -84,19 +84,27 @@ def __call__(self, model: nn.Module) -> nn.Module:
     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
         super().setup(trainer, pl_module, stage=stage)
 
+        trainer.strategy.trainer = trainer
         self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
         trainer.strategy._checkpoint_io = self.wrapped_io
+        trainer.strategy._init_model_parallel = False
+        trainer.strategy._setup_optimizers = False
 
-    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        needs_to_call = self._needs_to_call
-        self._maybe_apply_transform(trainer)
+    def apply_transform(self, trainer):
+        super().apply_transform(trainer)
 
-        # Check if we need to load the adapters
-        if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None:
+        if self.wrapped_io.adapter_ckpt_path is not None:
             logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
             adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
             trainer.strategy.load_model_state_dict(adapter_state, strict=False)
 
+        if hasattr(trainer.strategy, "init_model_parallel"):
+            logging.info("Initializing model parallel")
+            trainer.strategy.init_model_parallel()
+
+        logging.info("Setting up optimizers")
+        trainer.strategy.setup_optimizers(trainer)
+
     def on_load_checkpoint(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
     ) -> None:
diff --git a/nemo/lightning/pytorch/optim/lr_scheduler.py b/nemo/lightning/pytorch/optim/lr_scheduler.py
index 298a6e7a7f45..9374328190a6 100644
--- a/nemo/lightning/pytorch/optim/lr_scheduler.py
+++ b/nemo/lightning/pytorch/optim/lr_scheduler.py
@@ -445,7 +445,6 @@ def scheduler(self, model, optimizer):
 
         return {
             "optimizer": optimizer,
-            "scheduler": lr_scheduler,
             "lr_scheduler": {
                 # REQUIRED: The scheduler instance
                 "scheduler": lr_scheduler,
diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py
index 751141d8111b..5e43e09c0420 100644
--- a/nemo/lightning/pytorch/plugins/mixed_precision.py
+++ b/nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -61,7 +61,6 @@ def convert_module(self, module: Module) -> Module:
         This is optional and depends on the precision limitations during optimization.
 
         """
-        from megatron.core.distributed import DistributedDataParallel
         from megatron.core.transformer.module import Float16Module
         from megatron.core.utils import get_model_config
 
@@ -69,7 +68,10 @@ def convert_module(self, module: Module) -> Module:
             config = get_model_config(module.module)
             config.fp16 = self.precision == "16-mixed"
             config.bf16 = self.precision == "bf16-mixed"
-            if not isinstance(module.module, Float16Module):
+            if isinstance(module.module, Float16Module):
+                new_float16_module = Float16Module(config, module.module.module)
+                module.module = new_float16_module
+            else:
                 module.module = Float16Module(config, module.module)
 
         return module
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index 6a84319b4fa2..d0e502839f2f 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -110,6 +110,8 @@ def __init__(
         ckpt_parallel_save=True,
         ckpt_parallel_load=False,
         ckpt_parallel_save_optim=True,
+        setup_optimizers: bool = True,
+        init_model_parallel: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -132,6 +134,8 @@ def __init__(
         self.lazy_init = lazy_init
         self.ckpt_include_optimizer = ckpt_include_optimizer
         self.pipeline_dtype = pipeline_dtype
+        self._setup_optimizers = setup_optimizers
+        self._init_model_parallel = init_model_parallel
         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
         self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
 
@@ -144,7 +148,7 @@ def __init__(
 
         self._ddp = ddp
         if ddp == "megatron":
-            self.ddp_config = DistributedDataParallelConfig()
+            self.ddp_config = DistributedDataParallelConfig(check_for_nan_in_grad=True)
         elif isinstance(ddp, DistributedDataParallelConfig):
             self.ddp_config = ddp
         elif ddp == "pytorch":
@@ -180,7 +184,7 @@ def connect(self, model: pl.LightningModule) -> None:
             ddp_config.use_distributed_optimizer = mcore_opt_config.use_distributed_optimizer
 
     @override
-    def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
+    def setup(self, trainer: pl.Trainer) -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
         self.trainer = trainer
@@ -204,7 +208,7 @@ def setup(self, trainer: pl.Trainer) -> None:
             self.data_sampler.connect(trainer)
 
         self._fix_progress_bar(trainer)
-        self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)
+        self.setup_megatron_parallel(trainer)
         self.setup_precision_plugin()
 
         if getattr(self.lightning_module, "model_transform", None):
@@ -271,7 +275,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
 
         return dataloader
 
-    def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
+    def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
         assert self.model is not None, "Model is not set"
 
         convert_module_fn = None
@@ -286,6 +290,10 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             ddp_config=self.ddp_config,
             convert_module_fn=convert_module_fn,
         )
+
+        if self._init_model_parallel:
+            self.init_model_parallel()
+
         self.megatron_parallel.trainer = trainer
 
         # check signature-def of self.model.configure_optimizers to check if there's an optional arg: megatron_parallel
@@ -295,18 +303,9 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
                 self.model.configure_optimizers, megatron_parallel=self.megatron_parallel
             )
 
-        if setup_optimizers:
+        if self._setup_optimizers:
             self.setup_optimizers(trainer)
 
-        # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
-
-        if hasattr(self.precision_plugin, "convert_optimizer"):
-            _optimizers = [*self.optimizers]
-            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
-            self.optimizers = _optimizers
-
-        _optimizers_to_device(self.optimizers, self.root_device)
-
         self.model = self.megatron_parallel
         self.model.callbacks.add(getattr(trainer, "callbacks"))
 
@@ -317,6 +316,9 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
         if datamodule:
             self.model.callbacks.add(datamodule)
 
+    def init_model_parallel(self):
+        self.megatron_parallel.init_model_parallel()
+
     @override
     def configure_ddp(self) -> None:
         logging.debug(f"{self.__class__.__name__}: configuring MegatronParallel")
@@ -349,6 +351,16 @@ def _setup_model(self, model: nn.Module) -> nn.Module:
 
         return model
 
+    @override
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        super().setup_optimizers(trainer)
+        if hasattr(self.precision_plugin, "convert_optimizer"):
+            _optimizers = [*self.optimizers]
+            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
+            self.optimizers = _optimizers
+
+        _optimizers_to_device(self.optimizers, self.root_device)
+
     def _setup_parallel_ranks(self) -> None:
         self.set_world_ranks()
         env = cast(ClusterEnvironment, self.cluster_environment)
diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py
index 81dc7f85bc08..e64ee7bd0ba3 100644
--- a/tests/lightning/pytorch/callbacks/test_peft.py
+++ b/tests/lightning/pytorch/callbacks/test_peft.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch
 
 import torch.nn as nn
 from nemo.collections.llm import fn
@@ -54,8 +54,22 @@ def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
         peft.wrapped_io.load_checkpoint.return_value = {"dummy_state": "dummy_value"}
         peft.on_train_epoch_start(trainer, pl_module)
 
-        mock_logging.info.assert_called_once_with("Loading adapters from dummy_path")
+        # Check for all expected log messages
+        mock_logging.info.assert_has_calls(
+            [
+                call("Loading adapters from dummy_path"),
+                call("Initializing model parallel"),
+                call("Setting up optimizers"),
+            ],
+            any_order=True,
+        )
+
+        # Verify the number of calls
+        assert mock_logging.info.call_count == 3
+
         trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False)
+        trainer.strategy.init_model_parallel.assert_called_once()
+        trainer.strategy.setup_optimizers.assert_called_once_with(trainer)
 
     def test_peft_on_load_checkpoint(self):
         peft = self.DummyPEFT()
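Usage sketch (illustrative only, not part of the diff): the snippet below shows how the two new MegatronStrategy constructor flags added above (setup_optimizers, init_model_parallel) are intended to drive the deferred-initialization flow that connector.nemo_save and the PEFT callback now rely on. Only the flags, init_model_parallel(), and setup_optimizers() come from this change; the surrounding usage is a minimal assumed example.

    # Minimal sketch, assuming NeMo is installed; import path follows the file
    # location changed in this PR (nemo/lightning/pytorch/strategies.py).
    from nemo.lightning.pytorch.strategies import MegatronStrategy

    # Default behaviour is unchanged: setup() still initializes model parallelism
    # and builds optimizers as part of trainer setup.
    strategy = MegatronStrategy()

    # Deferred path used by PEFT and nemo_save: skip both steps during setup(),
    # then run them explicitly once the module list has reached its final shape
    # (e.g. after adapter injection in PEFT.apply_transform).
    deferred = MegatronStrategy(setup_optimizers=False, init_model_parallel=False)
    # ... later, after transforms have been applied:
    #     deferred.init_model_parallel()
    #     deferred.setup_optimizers(trainer)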