[NeMo-UX] Fix when optimizers are setup for PEFT #9619

Merged · 17 commits · Jul 9, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion nemo/collections/llm/api.py
@@ -279,7 +279,7 @@ def _setup(
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
) -> Any: # Return type is Any because app_state's type is not specified
_log = log or NeMoLogger()
if resume and resume.adapter_path and _log.ckpt:
if resume and isinstance(model_transform, PEFT) and _log.ckpt:
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
_log.ckpt.try_restore_best_ckpt = False

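
A minimal, self-contained sketch of the guard this hunk changes (the PEFT and Resume classes below are stand-ins, not the NeMo objects): a fresh PEFT run has no adapter checkpoint yet, so the old `resume.adapter_path` check let it slip through, while any PEFT fine-tuning run needs `try_restore_best_ckpt` disabled.

# Hypothetical illustration of the old vs. new condition.
class PEFT:  # stand-in for nemo.lightning.pytorch.callbacks.peft.PEFT
    pass

class Resume:  # stand-in resume config; adapter_path is only set when resuming adapters
    def __init__(self, adapter_path=None):
        self.adapter_path = adapter_path

resume, transform = Resume(adapter_path=None), PEFT()
old_condition = bool(resume and resume.adapter_path)          # False: fresh PEFT run slips through
new_condition = bool(resume and isinstance(transform, PEFT))  # True: any PEFT run disables best-ckpt restore
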
17 changes: 13 additions & 4 deletions nemo/collections/llm/fn/mixin.py
@@ -2,6 +2,7 @@
from typing_extensions import Self

from nemo.collections.llm.fn import base as fn
from nemo.utils import logging


class FNMixin:
@@ -114,8 +115,12 @@ def freeze(self) -> None:
"""
assert isinstance(self, nn.Module), "self is not a nn.Module"

for param in self.parameters():
param.requires_grad = False
params = list(self.parameters())
if not params:
logging.info(f"No parameters found in module {self.__class__.__name__}")
else:
for param in params:
param.requires_grad = False

def unfreeze(self) -> None:
"""
@@ -124,5 +129,9 @@ def unfreeze(self) -> None:
"""
assert isinstance(self, nn.Module), "self is not a nn.Module"

for param in self.parameters():
param.requires_grad = True
params = list(self.parameters())
if not params:
logging.info(f"No parameters found in module {self.__class__.__name__}")
else:
for param in params:
param.requires_grad = True
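
A short, hedged illustration of the new guard in freeze()/unfreeze(): a module that currently has no parameters now logs a message instead of silently iterating over nothing. The Wrapper class below is a toy stand-in for a module using the FNMixin pattern, not FNMixin itself.

from typing import Optional

import torch.nn as nn

class Wrapper(nn.Module):  # toy stand-in, illustrative only
    def __init__(self, inner: Optional[nn.Module] = None):
        super().__init__()
        self.inner = inner  # may be None, i.e. no parameters registered yet

    def freeze(self) -> None:
        params = list(self.parameters())
        if not params:
            print(f"No parameters found in module {self.__class__.__name__}")
        else:
            for param in params:
                param.requires_grad = False

Wrapper().freeze()                 # logs the "No parameters found" message instead of doing nothing silently
Wrapper(nn.Linear(4, 4)).freeze()  # disables gradients as before
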
3 changes: 3 additions & 0 deletions nemo/lightning/_strategy_lib.py
@@ -515,4 +515,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
elif count > n_nesting:
to_remove = "module." * (count - n_nesting)
_state_dict[key[len(to_remove) :]] = value
else:
_state_dict[key] = value

module.load_state_dict(_state_dict, strict=strict)
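
A standalone sketch of the prefix-normalization logic above, using hypothetical names; the real function also handles Megatron-specific state. The new else branch matters because keys that reach it were previously not copied into the rebuilt state dict at all.

from typing import Any, Dict, Mapping

def _strip_extra_module_prefixes(state_dict: Mapping[str, Any], n_nesting: int) -> Dict[str, Any]:
    """Illustrative only: align 'module.' nesting between checkpoint keys and the live module."""
    normalized: Dict[str, Any] = {}
    for key, value in state_dict.items():
        count = 0
        while key[count * len("module."):].startswith("module."):
            count += 1
        if count > n_nesting:
            to_remove = "module." * (count - n_nesting)
            normalized[key[len(to_remove):]] = value  # strip the surplus wrappers
        else:
            normalized[key] = value                   # new in this PR: keep the key instead of dropping it
    return normalized

# Example: checkpoint saved with two wrapper levels, live module expects one.
ckpt = {"module.module.linear.weight": 0, "linear.bias": 1}
print(_strip_extra_module_prefixes(ckpt, n_nesting=1))
# {'module.linear.weight': 0, 'linear.bias': 1}
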
8 changes: 2 additions & 6 deletions nemo/lightning/io/connector.py
@@ -160,12 +160,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
output_path (Path): The path where the model checkpoint will be saved.
trainer (pl.Trainer): The trainer with the strategy to save the model.
"""
_setup_kwargs = {}
setup_signature = inspect.signature(trainer.strategy.setup)
if 'setup_optimizers' in setup_signature.parameters:
_setup_kwargs["setup_optimizers"] = False

trainer.strategy.setup(trainer, **_setup_kwargs)
trainer.strategy._setup_optimizers = False
trainer.strategy.setup(trainer)
trainer.save_checkpoint(output_path)

def nemo_load(
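
Previously nemo_save() inspected the signature of strategy.setup() for a setup_optimizers keyword; now it simply flips a _setup_optimizers attribute on the strategy before calling setup(). Below is a hedged sketch of a strategy honoring such a flag; the actual MegatronStrategy internals may differ, and the names here are illustrative.

class SketchStrategy:
    """Illustrative only: a strategy whose setup() respects a _setup_optimizers flag."""

    def __init__(self) -> None:
        self._setup_optimizers = True  # default: build optimizers during setup()

    def setup(self, trainer) -> None:
        self.setup_model(trainer)
        if self._setup_optimizers:     # skipped when a caller (e.g. nemo_save) flips the flag
            self.setup_optimizers(trainer)

    def setup_model(self, trainer) -> None:
        ...

    def setup_optimizers(self, trainer) -> None:
        ...
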
159 changes: 95 additions & 64 deletions nemo/lightning/megatron_parallel.py
@@ -12,7 +12,6 @@
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
@@ -129,7 +128,6 @@ def __init__(
cpu: bool = False,
convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
) -> None:
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state

_pipeline: List[nn.Module]
@@ -152,67 +150,15 @@ def __init__(
_model.configure_model()
_pipeline.append(_model)

if convert_module_fn:
for i in range(len(_pipeline)):
_pipeline[i] = convert_module_fn(_pipeline[i])

if isinstance(ddp_config, DistributedDataParallelConfig):
for model_chunk_idx, model_chunk in enumerate(_pipeline):
module = model_chunk.module

ddp = DDP(
module.config,
ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
model_chunk.module = ddp
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is an attribute PyTorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore

# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline)
for module in _pipeline:
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func

for i, model_module in enumerate(_pipeline):
if not cpu:
model_module.cuda(torch.cuda.current_device())

for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)

if hasattr(model_module, "configure_model"):
if not hasattr(model_module, "set_input_tensor"):
if hasattr(model_module.module, "set_input_tensor"):
model_module.set_input_tensor = model_module.module.set_input_tensor
else:
# TODO: What to do here?
pass

# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
from nemo.utils import logging

msg = (
f" > number of parameters on (tensor, pipeline) model parallel rank "
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
f"{_calc_number_of_params(_pipeline)}"
)
logging.info(msg)

super().__init__(_pipeline)
self.precision_plugin = precision_plugin
self._cpu = cpu
self.callbacks = callbacks or CallbackConnector()
self.data_step = data_step or default_data_step
self.forward_step = forward_step or default_forward_step
self.loss_reduction: MegatronLossReduction = loss_reduction
self.ddp_config = ddp_config
self.convert_module_fn = convert_module_fn

def forward(
self,
@@ -475,6 +421,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat

raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")

def init_model_parallel(self):
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state

for model_module in self:
if not self._cpu:
model_module.cuda(torch.cuda.current_device())

for param in model_module.parameters():
set_defaults_if_not_set_tensor_model_parallel_attributes(param)

if hasattr(model_module, "configure_model"):
if not hasattr(model_module, "set_input_tensor"):
if hasattr(model_module.module, "set_input_tensor"):
model_module.set_input_tensor = model_module.module.set_input_tensor
else:
# TODO: What to do here?
pass

# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
from nemo.utils import logging

num_params = _calc_number_of_params(list(self))
num_trainable_params = _calc_number_of_trainable_params(list(self))

msg = (
f" > number of parameters on (tensor, pipeline) model parallel rank "
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
f"{num_params}"
)
logging.info(msg)

if num_params != num_trainable_params:
logging.info(
f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
)

if self.convert_module_fn:
self.apply_convert_module_fn()

self.init_ddp()

def apply_convert_module_fn(self):
for i in range(len(self)):
self[i] = self.convert_module_fn(self[i])

def init_ddp(self):
if not isinstance(self.ddp_config, DistributedDataParallelConfig):
return

from megatron.core import parallel_state

for model_chunk_idx, model_chunk in enumerate(self):
module = model_chunk.module

ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
model_chunk.module = ddp
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is an attribute PyTorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore

# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
for module in self:
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func

def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
if "self" in context:
del context["self"]
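
With this refactor the constructor no longer moves modules to the GPU, converts them, or wraps them in DDP; that work now lives in init_model_parallel() and init_ddp(), so a strategy can delay it until after model transforms (such as PEFT adapter injection) have run. The ordering sketch below uses placeholders, not the exact NeMo call sites.

# Rough ordering sketch (placeholders, not literal NeMo code):
#
#   megatron_parallel = MegatronParallel(pipeline, ddp_config=ddp_config, ...)  # cheap: builds the module list only
#   peft(megatron_parallel)                   # attach adapters while modules are still unwrapped
#   megatron_parallel.init_model_parallel()   # cuda(), tensor-parallel attrs, convert_module_fn, then init_ddp()
#   strategy.setup_optimizers(trainer)        # optimizer now sees the adapter parameters as trainable
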
@@ -565,18 +587,21 @@ def forward_backward_func(self) -> "MegatronStepProtocol":

@override
def __getattr__(self, item: Any) -> Any:
if len(self) == 0:
return super().__getattr__(item)

try:
# __getattr__ gets called as a last resort if the attribute does not exist
# call nn.Module's implementation first
# First, try to get the attribute from the superclass (nn.ModuleList)
return super().__getattr__(item)
except AttributeError:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
attr = getattr(self._modules[self._get_abs_string_index(0)], item)
# If not found in superclass, check if we have any modules
if len(self) == 0:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
)

return attr
# Try to get it from the first module
try:
return getattr(self._modules[self._get_abs_string_index(0)], item)
except AttributeError:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")


class _ModuleStepFunction:
@@ -915,6 +940,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])


def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
assert isinstance(model, list)

return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])


def is_list_of_iterators(var) -> bool:
if not isinstance(var, list):
return False
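
Two smaller pieces of this file are worth noting: __getattr__ now consults nn.ModuleList first and raises a clearer error when the container is empty, and _calc_number_of_trainable_params lets the rank-0 log report how much of the model is actually trainable, which for PEFT should be a small fraction. The snippet below reproduces that accounting with plain torch modules, as an illustration rather than NeMo code.

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first layer, as PEFT freezes the base model
    p.requires_grad = False

num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable: {num_trainable} / {num_params} ({num_trainable / num_params:.2%})")
# trainable: 18 / 90 (20.00%)
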
5 changes: 4 additions & 1 deletion nemo/lightning/pytorch/callbacks/model_transform.py
@@ -65,7 +65,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

def _maybe_apply_transform(self, trainer):
if self._needs_to_call:
self.model_transform(trainer.model)
self.apply_transform(trainer)

def apply_transform(self, trainer):
self.model_transform(trainer.model)

@property
def _needs_to_call(self) -> bool:
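
Splitting apply_transform out of _maybe_apply_transform gives subclasses a hook that runs exactly when the transform fires; the PEFT callback in the next file overrides it to restore adapter weights and finish strategy setup. As a hedged illustration (assuming ModelTransform is importable from nemo.lightning.pytorch.callbacks.model_transform), a subclass only needs:

from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform  # assumed import path

class LoggingTransform(ModelTransform):
    """Illustrative subclass: add behavior right after the transform is applied."""

    def apply_transform(self, trainer):
        super().apply_transform(trainer)  # runs self.model_transform(trainer.model)
        print("model transform applied to", type(trainer.model).__name__)
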
18 changes: 13 additions & 5 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -84,19 +84,27 @@ def __call__(self, model: nn.Module) -> nn.Module:
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
super().setup(trainer, pl_module, stage=stage)

trainer.strategy.trainer = trainer
self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
trainer.strategy._checkpoint_io = self.wrapped_io
trainer.strategy._init_model_parallel = False
trainer.strategy._setup_optimizers = False

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
needs_to_call = self._needs_to_call
self._maybe_apply_transform(trainer)
def apply_transform(self, trainer):
super().apply_transform(trainer)

# Check if we need to load the adapters
if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None:
if self.wrapped_io.adapter_ckpt_path is not None:
logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
trainer.strategy.load_model_state_dict(adapter_state, strict=False)

if hasattr(trainer.strategy, "init_model_parallel"):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()

logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)

def on_load_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
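
Taken together, the PEFT callback now defers model-parallel initialization and optimizer construction at setup() time and performs them itself inside apply_transform, once the adapters exist and any adapter checkpoint has been restored. A condensed, illustrative outline of the sequence:

# Condensed outline (illustrative, not literal NeMo code):
#
# PEFT.setup(trainer, ...):
#     trainer.strategy._init_model_parallel = False   # strategy leaves DDP/model-parallel init to the callback
#     trainer.strategy._setup_optimizers = False      # strategy leaves optimizer construction to the callback
#
# PEFT.apply_transform(trainer):                      # fired at the first train-epoch start
#     super().apply_transform(trainer)                # inject adapters into the frozen base model
#     if self.wrapped_io.adapter_ckpt_path is not None:
#         trainer.strategy.load_model_state_dict(adapter_state, strict=False)
#     trainer.strategy.init_model_parallel()          # now safe to wrap/distribute the full module
#     trainer.strategy.setup_optimizers(trainer)      # optimizer is built after adapters are attached
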
1 change: 0 additions & 1 deletion nemo/lightning/pytorch/optim/lr_scheduler.py
@@ -445,7 +445,6 @@ def scheduler(self, model, optimizer):

return {
"optimizer": optimizer,
"scheduler": lr_scheduler,
"lr_scheduler": {
# REQUIRED: The scheduler instance
"scheduler": lr_scheduler,
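
The stray top-level "scheduler" key is removed because PyTorch Lightning's configure_optimizers contract only looks for "optimizer" and a nested "lr_scheduler" dict. For reference, the expected shape looks roughly like this (the scheduler type and the interval/frequency values are examples, not what NeMo sets):

import torch

def configure_optimizers(model):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": lr_scheduler,  # REQUIRED: the scheduler instance
            "interval": "step",         # example: step every optimizer step
            "frequency": 1,
        },
    }
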
6 changes: 4 additions & 2 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -61,15 +61,17 @@ def convert_module(self, module: Module) -> Module:
This is optional and depends on the precision limitations during optimization.

"""
from megatron.core.distributed import DistributedDataParallel
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

if self.precision in ["16-mixed", "bf16-mixed"]:
config = get_model_config(module.module)
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
if not isinstance(module.module, Float16Module):
if isinstance(module.module, Float16Module):
new_float16_module = Float16Module(config, module.module.module)
module.module = new_float16_module
else:
module.module = Float16Module(config, module.module)

return module
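
The inverted isinstance check means already-wrapped modules are no longer skipped: the conversion now rebuilds a single Float16Module around the bare model, presumably so the current fp16/bf16 config takes effect without double-wrapping. A hedged sketch of just that re-wrapping step, assuming module is a wrapper whose .module attribute holds the model chunk:

from megatron.core.transformer.module import Float16Module

def ensure_float16_wrapper(module, config):
    """Illustrative only: keep exactly one Float16Module layer, built from the current config."""
    if isinstance(module.module, Float16Module):
        # Already wrapped: rebuild around the inner model instead of wrapping twice.
        module.module = Float16Module(config, module.module.module)
    else:
        module.module = Float16Module(config, module.module)
    return module
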