From 4f67ff4b4fd3db5ce8937cb6a26f6b7d2d1c132e Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 24 Sep 2024 20:48:17 -0400
Subject: [PATCH 1/5] contiguous

Signed-off-by: Chen Cui
---
 nemo/collections/llm/peft/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py
index 0d2a98fa3dfb..bdd23be4b029 100644
--- a/nemo/collections/llm/peft/lora.py
+++ b/nemo/collections/llm/peft/lora.py
@@ -50,7 +50,7 @@ def forward(self, x):
 
         linear_output, bias, layernorm_output = linear_output
         x = layernorm_output
-        adapter_output = self.adapter(x)
+        adapter_output = self.adapter(x.contiguous())
         return linear_output + adapter_output, bias

From 41a65288daca0f731816434dcca85da7135b1e02 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 24 Sep 2024 20:54:30 -0400
Subject: [PATCH 2/5] fix load

Signed-off-by: Chen Cui
---
 nemo/lightning/pytorch/callbacks/peft.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index a3542d9a2135..1e3cde0bbcde 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -107,6 +107,9 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str)
 
     def apply_transform(self, trainer):
         super().apply_transform(trainer)
+        self.trainable_params = set(
+            name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
+        )
 
         adapter_sharded_state_dict = {}
         if self.wrapped_io.adapter_ckpt_path is not None:
@@ -137,10 +140,6 @@ def apply_transform(self, trainer):
             if trainer.state.fn == TrainerFn.FITTING:
                 trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)
 
-        self.trainable_params = set(
-            name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
-        )
-
     def adapter_key_filter(self, key: str) -> bool:
         return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters")

From a04e517f48f0615ad6f19b317739f234ff125f82 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Wed, 25 Sep 2024 14:25:36 -0400
Subject: [PATCH 3/5] add test script

Signed-off-by: Chen Cui
---
 tests/collections/llm/gpt_finetuning.py | 117 ++++++++++++++++++++++++
 1 file changed, 117 insertions(+)
 create mode 100644 tests/collections/llm/gpt_finetuning.py

diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py
new file mode 100644
index 000000000000..28ada6a0e92b
--- /dev/null
+++ b/tests/collections/llm/gpt_finetuning.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from dataclasses import dataclass
+
+import argparse
+
+## NOTE: This script is present for github-actions testing only.
+
+from nemo import lightning as nl
+from nemo.collections import llm
+from megatron.core.optimizer import OptimizerConfig
+
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+
+@dataclass
+class Llama3Config96M(llm.Llama3Config8B):
+    seq_length: int = 2048
+    num_layers: int = 2
+    hidden_size: int = 768
+    ffn_hidden_size: int = 3072
+    num_attention_heads: int = 8
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Finetune a small GPT model using NeMo 2.0')
+    parser.add_argument('--restore_path', type=str, help="Path to model to be finetuned")
+    parser.add_argument('--experiment_dir', type=str, help="directory to write results and checkpoints to")
+    parser.add_argument('--devices', type=int, default=1, help="number of devices")
+    parser.add_argument('--mbs', type=int, default=1, help="micro batch size")
+    parser.add_argument('--tp_size', type=int, default=1, help="tensor parallel size")
+    parser.add_argument('--pp_size', type=int, default=1, help="pipeline parallel size")
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=args.tp_size,
+        pipeline_model_parallel_size=args.pp_size,
+    )
+
+    trainer = nl.Trainer(
+        devices=args.devices,
+        max_steps=2,
+        accelerator="gpu",
+        strategy=strategy,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+    )
+
+    ckpt = nl.ModelCheckpoint(
+        save_last=True,
+        monitor="reduced_train_loss",
+        save_top_k=1,
+        save_on_train_epoch_end=True,
+        save_optim_on_train_end=True,
+    )
+
+    logger = nl.NeMoLogger(
+        log_dir=args.experiment_dir,
+        use_datetime_version=False,  # must be false if using auto resume
+        ckpt=ckpt,
+    )
+
+    adam = nl.MegatronOptimizerModule(
+        config=OptimizerConfig(
+            optimizer="adam",
+            lr=0.0001,
+            adam_beta2=0.98,
+            use_distributed_optimizer=True,
+            clip_grad=1.0,
+            bf16=True,
+        ),
+    )
+
+    lora = llm.peft.LoRA()
+
+    squad = llm.SquadDataModule(seq_length=2048, micro_batch_size=args.mbs, global_batch_size=8, num_workers=0)
+
+    tokenizer = get_nmt_tokenizer(
+        tokenizer_model="/lustre/fsw/coreai_dlalgo_llm/nemo_home/models/llama_96M/dummy_tokenizer.model"
+    )
+    llama3_8b = llm.LlamaModel(Llama3Config96M(), tokenizer=tokenizer)
+
+    resume = nl.AutoResume(
+        restore_config=nl.RestoreConfig(path=args.restore_path),
+        resume_if_exists=True,
+    )
+
+    llm.finetune(
+        model=llama3_8b,
+        data=squad,
+        trainer=trainer,
+        peft=lora,
+        log=logger,
+        optim=adam,
+        resume=resume,
+    )
\ No newline at end of file

From 8c743c1f50fbb566e11499c40e7fe4591563ed44 Mon Sep 17 00:00:00 2001
From: cuichenx
Date: Wed, 25 Sep 2024 18:26:36 +0000
Subject: [PATCH 4/5] Apply isort and black reformatting

Signed-off-by: cuichenx
---
 tests/collections/llm/gpt_finetuning.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py
index 28ada6a0e92b..f1489d45e62b 100644
--- a/tests/collections/llm/gpt_finetuning.py
+++ b/tests/collections/llm/gpt_finetuning.py
@@ -11,20 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
 import os
-
 from dataclasses import dataclass
 
-import argparse
-
-## NOTE: This script is present for github-actions testing only.
+from megatron.core.optimizer import OptimizerConfig
 from nemo import lightning as nl
 from nemo.collections import llm
-from megatron.core.optimizer import OptimizerConfig
-
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
+## NOTE: This script is present for github-actions testing only.
+
+
 
 @dataclass
 class Llama3Config96M(llm.Llama3Config8B):
@@ -114,4 +114,4 @@ def get_args():
         log=logger,
         optim=adam,
         resume=resume,
-    )
\ No newline at end of file
+    )

From 80bf07d85e6b7d965d677f7818b431527ed726d3 Mon Sep 17 00:00:00 2001
From: artbataev
Date: Wed, 25 Sep 2024 18:27:18 +0000
Subject: [PATCH 5/5] Apply isort and black reformatting

Signed-off-by: artbataev
---
 tests/collections/llm/gpt_finetuning.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py
index f1489d45e62b..09050595aebe 100644
--- a/tests/collections/llm/gpt_finetuning.py
+++ b/tests/collections/llm/gpt_finetuning.py
@@ -24,8 +24,6 @@
 ## NOTE: This script is present for github-actions testing only.
 
-
-
 @dataclass
 class Llama3Config96M(llm.Llama3Config8B):
     seq_length: int = 2048
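
A note on PATCH 1/5: the x passed to the adapter is the layernorm_output returned by the wrapped fused linear layer, and the patch implies that tensor can arrive as a non-contiguous view; operations inside the adapter that require contiguous memory, such as .view(), then fail at runtime. Calling .contiguous() is a no-op when the tensor is already contiguous, so the fix costs nothing in the common case. Below is a minimal, self-contained sketch of the failure mode in plain PyTorch; ToyAdapter is a hypothetical stand-in, not the actual NeMo LoRA adapter.

import torch
import torch.nn as nn


class ToyAdapter(nn.Module):
    """Stand-in low-rank adapter that flattens its input, as adapter
    implementations commonly do; .view() requires contiguous memory."""

    def __init__(self, dim: int, rank: int = 8):
        super().__init__()
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat = x.view(-1, x.size(-1))  # raises on non-contiguous input
        return self.lora_b(self.lora_a(flat)).view_as(x)


seq, batch, dim = 4, 2, 16
adapter = ToyAdapter(dim)

# Simulate a layer output that is a strided view, e.g. a transpose.
x = torch.randn(batch, seq, dim).transpose(0, 1)
assert not x.is_contiguous()

try:
    adapter(x)  # fails: view size not compatible with input's strides
except RuntimeError as err:
    print(f"without .contiguous(): {err}")

out = adapter(x.contiguous())  # succeeds after an explicit copy
print(out.shape)

An alternative would be reshape(), which copies only when needed, but that would mean editing the adapter itself; making the copy explicit at the call site, as the patch does, leaves the adapter untouched.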
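A note on PATCH 2/5: adapter_key_filter reads self.trainable_params, but the old code only assigned that attribute after the adapter checkpoint and optimizer state had been restored, so any filtering performed during the restore ran against an attribute that did not exist yet. Moving the assignment to the top of apply_transform removes the ordering hazard. The following toy reproduction shows the same pattern; ToyPEFTCallback and its methods are illustrative stand-ins under assumed names, not the NeMo classes.

from typing import Dict, Iterable


class ToyPEFTCallback:
    def __init__(self, trainable: Iterable[str]):
        self._trainable = list(trainable)

    def key_filter(self, key: str) -> bool:
        # Like adapter_key_filter: depends on self.trainable_params.
        return key in self.trainable_params or ".adapter." in key

    def restore_buggy(self, state: Dict[str, float]) -> Dict[str, float]:
        # Filter runs during the load, before trainable_params exists.
        restored = {k: v for k, v in state.items() if self.key_filter(k)}
        self.trainable_params = set(self._trainable)  # too late
        return restored

    def restore_fixed(self, state: Dict[str, float]) -> Dict[str, float]:
        # The patch's ordering: populate the set first, then load.
        self.trainable_params = set(self._trainable)
        return {k: v for k, v in state.items() if self.key_filter(k)}


ckpt = {"layer0.adapter.lora_a": 0.1, "layer0.weight": 0.2, "head.bias": 0.3}

try:
    ToyPEFTCallback(trainable=["head.bias"]).restore_buggy(ckpt)
except AttributeError as err:
    print(f"old ordering: {err}")

print(ToyPEFTCallback(trainable=["head.bias"]).restore_fixed(ckpt))
# -> {'layer0.adapter.lora_a': 0.1, 'head.bias': 0.3}

The fixed ordering keeps adapter weights plus anything whose parameter has requires_grad set, which is what the real filter is meant to select.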