Skip to content

Commit

Permalink
[NeMo-UX] Fix when optimizers are setup for PEFT (#9619)
Browse files Browse the repository at this point in the history
* Fix when optimizers are setup for PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Init DDP inside PEFT

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Some fixes, loss seems to become nan with peft for some reason

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Loss goes down on fp32

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Simplifying FNMixin

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix bug with new checkpoint-io

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fix failing test: test_peft_on_train_epoch_start_with_adapter

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

---------

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
  • Loading branch information
3 people authored and maanug-nv committed Jul 14, 2024
1 parent 693c55f commit 13345c5
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 100 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def _setup(
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
) -> Any: # Return type is Any because app_state's type is not specified
_log = log or NeMoLogger()
if resume and resume.adapter_path and _log.ckpt:
if resume and isinstance(model_transform, PEFT) and _log.ckpt:
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
_log.ckpt.try_restore_best_ckpt = False

Expand Down
17 changes: 13 additions & 4 deletions nemo/collections/llm/fn/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing_extensions import Self

from nemo.collections.llm.fn import base as fn
from nemo.utils import logging


class FNMixin:
Expand Down Expand Up @@ -114,8 +115,12 @@ def freeze(self) -> None:
"""
assert isinstance(self, nn.Module), "self is not a nn.Module"

for param in self.parameters():
param.requires_grad = False
params = list(self.parameters())
if not params:
logging.info(f"No parameters found in module {self.__class__.__name__}")
else:
for param in params:
param.requires_grad = False

def unfreeze(self) -> None:
"""
Expand All @@ -124,5 +129,9 @@ def unfreeze(self) -> None:
"""
assert isinstance(self, nn.Module), "self is not a nn.Module"

for param in self.parameters():
param.requires_grad = True
params = list(self.parameters())
if not params:
logging.info(f"No parameters found in module {self.__class__.__name__}")
else:
for param in params:
param.requires_grad = True
3 changes: 3 additions & 0 deletions nemo/lightning/_strategy_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,4 +516,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
elif count > n_nesting:
to_remove = "module." * (count - n_nesting)
_state_dict[key[len(to_remove) :]] = value
else:
_state_dict[key] = value

module.load_state_dict(_state_dict, strict=strict)
8 changes: 2 additions & 6 deletions nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
output_path (Path): The path where the model checkpoint will be saved.
trainer (pl.Trainer): The trainer with the strategy to save the model.
"""
_setup_kwargs = {}
setup_signature = inspect.signature(trainer.strategy.setup)
if 'setup_optimizers' in setup_signature.parameters:
_setup_kwargs["setup_optimizers"] = False

trainer.strategy.setup(trainer, **_setup_kwargs)
trainer.strategy._setup_optimizers = False
trainer.strategy.setup(trainer)
trainer.save_checkpoint(output_path)

def nemo_load(
Expand Down
159 changes: 95 additions & 64 deletions nemo/lightning/megatron_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Expand Down Expand Up @@ -129,7 +128,6 @@ def __init__(
cpu: bool = False,
convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
) -> None:
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state

_pipeline: List[nn.Module]
Expand All @@ -152,67 +150,15 @@ def __init__(
_model.configure_model()
_pipeline.append(_model)

if convert_module_fn:
for i in range(len(_pipeline)):
_pipeline[i] = convert_module_fn(_pipeline[i])

if isinstance(ddp_config, DistributedDataParallelConfig):
for model_chunk_idx, model_chunk in enumerate(_pipeline):
module = model_chunk.module

ddp = DDP(
module.config,
ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
model_chunk.module = ddp
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is a attr pytorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore

# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline)
for module in _pipeline:
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func

for i, model_module in enumerate(_pipeline):
if not cpu:
model_module.cuda(torch.cuda.current_device())

for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)

if hasattr(model_module, "configure_model"):
if not hasattr(model_module, "set_input_tensor"):
if hasattr(model_module.module, "set_input_tensor"):
model_module.set_input_tensor = model_module.module.set_input_tensor
else:
# TODO: What to do here?
pass

# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
from nemo.utils import logging

msg = (
f" > number of parameters on (tensor, pipeline) model parallel rank "
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
f"{_calc_number_of_params(_pipeline)}"
)
logging.info(msg)

super().__init__(_pipeline)
self.precision_plugin = precision_plugin
self._cpu = cpu
self.callbacks = callbacks or CallbackConnector()
self.data_step = data_step or default_data_step
self.forward_step = forward_step or default_forward_step
self.loss_reduction: MegatronLossReduction = loss_reduction
self.ddp_config = ddp_config
self.convert_module_fn = convert_module_fn

def forward(
self,
Expand Down Expand Up @@ -475,6 +421,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat

raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")

def init_model_parallel(self):
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state

for model_module in self:
if not self._cpu:
model_module.cuda(torch.cuda.current_device())

for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)

if hasattr(model_module, "configure_model"):
if not hasattr(model_module, "set_input_tensor"):
if hasattr(model_module.module, "set_input_tensor"):
model_module.set_input_tensor = model_module.module.set_input_tensor
else:
# TODO: What to do here?
pass

# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
from nemo.utils import logging

num_params = _calc_number_of_params(list(self))
num_trainable_params = _calc_number_of_trainable_params(list(self))

msg = (
f" > number of parameters on (tensor, pipeline) model parallel rank "
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
f"{num_params}"
)
logging.info(msg)

if num_params != num_trainable_params:
logging.info(
f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
)

if self.convert_module_fn:
self.apply_convert_module_fn()

self.init_ddp()

def apply_convert_module_fn(self):
for i in range(len(self)):
self[i] = self.convert_module_fn(self[i])

def init_ddp(self):
if not isinstance(self.ddp_config, DistributedDataParallelConfig):
return

from megatron.core import parallel_state

for model_chunk_idx, model_chunk in enumerate(self):
module = model_chunk.module

ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
model_chunk.module = ddp
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is a attr pytorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore

# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
for module in self:
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func

def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
if "self" in context:
del context["self"]
Expand Down Expand Up @@ -565,18 +587,21 @@ def forward_backward_func(self) -> "MegatronStepProtocol":

@override
def __getattr__(self, item: Any) -> Any:
if len(self) == 0:
return super().__getattr__(item)

try:
# __getattr__ gets called as a last resort if the attribute does not exist
# call nn.Module's implementation first
# First, try to get the attribute from the superclass (nn.ModuleList)
return super().__getattr__(item)
except AttributeError:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
attr = getattr(self._modules[self._get_abs_string_index(0)], item)
# If not found in superclass, check if we have any modules
if len(self) == 0:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
)

return attr
# Try to get it from the first module
try:
return getattr(self._modules[self._get_abs_string_index(0)], item)
except AttributeError:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")


class _ModuleStepFunction:
Expand Down Expand Up @@ -915,6 +940,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])


def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
assert isinstance(model, list)

return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])


def is_list_of_iterators(var) -> bool:
if not isinstance(var, list):
return False
Expand Down
5 changes: 4 additions & 1 deletion nemo/lightning/pytorch/callbacks/model_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

def _maybe_apply_transform(self, trainer):
if self._needs_to_call:
self.model_transform(trainer.model)
self.apply_transform(trainer)

def apply_transform(self, trainer):
self.model_transform(trainer.model)

@property
def _needs_to_call(self) -> bool:
Expand Down
18 changes: 13 additions & 5 deletions nemo/lightning/pytorch/callbacks/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,27 @@ def __call__(self, model: nn.Module) -> nn.Module:
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
super().setup(trainer, pl_module, stage=stage)

trainer.strategy.trainer = trainer
self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
trainer.strategy._checkpoint_io = self.wrapped_io
trainer.strategy._init_model_parallel = False
trainer.strategy._setup_optimizers = False

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
needs_to_call = self._needs_to_call
self._maybe_apply_transform(trainer)
def apply_transform(self, trainer):
super().apply_transform(trainer)

# Check if we need to load the adapters
if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None:
if self.wrapped_io.adapter_ckpt_path is not None:
logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
trainer.strategy.load_model_state_dict(adapter_state, strict=False)

if hasattr(trainer.strategy, "init_model_parallel"):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()

logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)

def on_load_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
Expand Down
1 change: 0 additions & 1 deletion nemo/lightning/pytorch/optim/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,6 @@ def scheduler(self, model, optimizer):

return {
"optimizer": optimizer,
"scheduler": lr_scheduler,
"lr_scheduler": {
# REQUIRED: The scheduler instance
"scheduler": lr_scheduler,
Expand Down
6 changes: 4 additions & 2 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,17 @@ def convert_module(self, module: Module) -> Module:
This is optional and depends on the precision limitations during optimization.
"""
from megatron.core.distributed import DistributedDataParallel
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

if self.precision in ["16-mixed", "bf16-mixed"]:
config = get_model_config(module.module)
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
if not isinstance(module.module, Float16Module):
if isinstance(module.module, Float16Module):
new_float16_module = Float16Module(config, module.module.module)
module.module = new_float16_module
else:
module.module = Float16Module(config, module.module)

return module
Expand Down
Loading

0 comments on commit 13345c5

Please sign in to comment.