From 5374cd848035a9647963bee3a2e9853f82fc11be Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Wed, 1 May 2024 02:58:08 +0000 Subject: [PATCH 01/13] update lora.py with dora parameters --- torchtune/modules/peft/lora.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 100329ff49..a7ea740f30 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -32,6 +32,7 @@ class LoRALinear(nn.Module, AdapterModule): rank (int): rank of the low-rank approximation alpha (float): scaling factor for the low-rank approximation dropout (float): dropout probability. Default: 0.0 + use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation). Default: False use_bias (bool): whether to include bias in the original linear layer. Default: False quantize_base (bool): Whether to quantize base linear weight or not. @@ -45,6 +46,7 @@ def __init__( rank: int, alpha: float, dropout: float = 0.0, + use_dora: bool = False, use_bias: bool = False, quantize_base: bool = False, ): @@ -54,6 +56,7 @@ def __init__( self.alpha = alpha self.out_dim = out_dim self.use_bias = use_bias + self.use_dora = use_dora self._quantize_base = quantize_base weight, bias = self._create_weight_and_bias() # 'self.disabled' is a flag showing whether to turn off LoRA adapters, @@ -67,6 +70,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.m = nn.Parameter(F.ones(1, out_dim)) self.merged = False # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. no child @@ -128,6 +132,11 @@ def forward(self, x: Tensor) -> Tensor: return out lora_out = self.lora_a(self.dropout(x)) lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + # Adding 1e-6 to avoid division by zero + if self.use_dora: + return out + self.m * lora_out / ( + lora_out.norm(p=2, dim=-1, keepdim=True) + 1e-6 + ) return out + lora_out From aefb8cbb02712177d690ca65cbac480fcb8ac429 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Wed, 1 May 2024 03:02:13 +0000 Subject: [PATCH 02/13] updated dora parameter initialization --- torchtune/modules/peft/lora.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index a7ea740f30..dccef196e9 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -32,7 +32,8 @@ class LoRALinear(nn.Module, AdapterModule): rank (int): rank of the low-rank approximation alpha (float): scaling factor for the low-rank approximation dropout (float): dropout probability. Default: 0.0 - use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation). Default: False + use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation). + Default: False use_bias (bool): whether to include bias in the original linear layer. Default: False quantize_base (bool): Whether to quantize base linear weight or not. 
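A minimal sketch (editorial, not from the patch series) of what this first-cut forward computes, using the illustrative name `dora_v0_forward`: the frozen base output is kept as-is, while the LoRA branch is L2-normalized per token and rescaled by the learnable magnitude `m`. Later commits in this series replace this with the weight-level normalization used in the DoRA paper.

import torch

def dora_v0_forward(base_out: torch.Tensor, lora_out: torch.Tensor,
                    m: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # base_out, lora_out: (..., out_dim); m: (1, out_dim)
    norm = lora_out.norm(p=2, dim=-1, keepdim=True) + eps  # avoid division by zero
    return base_out + m * lora_out / norm

base = torch.randn(2, 8, 32)   # (batch, seq_len, out_dim)
lora = torch.randn(2, 8, 32)
m = torch.ones(1, 32)
print(dora_v0_forward(base, lora, m).shape)  # torch.Size([2, 8, 32])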
@@ -70,7 +71,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.m = nn.Parameter(F.ones(1, out_dim)) + self.m = nn.Parameter(F.ones(1, out_dim)) if self.use_dora else None self.merged = False # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. no child From dffb2a321262d409eaf2a9e745b40b0d762bdbda Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Sun, 5 May 2024 05:54:59 +0000 Subject: [PATCH 03/13] updated ones initialization and test --- tests/torchtune/modules/peft/test_lora.py | 19 +++++++++++++++++++ torchtune/modules/peft/lora.py | 3 ++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index d48f2bea82..73ab510724 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -60,6 +60,19 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear: fixed_init_model(lora_linear) return lora_linear + @pytest.fixture + def dora_linear(self, in_dim, out_dim) -> LoRALinear: + lora_linear = LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=RANK, + alpha=ALPHA, + use_bias=False, + use_dora=True, + ) + fixed_init_model(lora_linear) + return lora_linear + @pytest.fixture def qlora_linear(self, in_dim, out_dim) -> LoRALinear: with utils.set_default_dtype(torch.bfloat16): @@ -97,6 +110,12 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None: assert actual.shape == (BSZ, SEQ_LEN, out_dim) torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) + def test_dora_forward(self, inputs, dora_linear, out_dim) -> None: + expected = torch.tensor(EXPECTED_VAL) + actual = dora_linear(inputs) + assert actual.shape == (BSZ, SEQ_LEN, out_dim) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) + def test_lora_weight_nf4_when_quantized(self, qlora_linear): assert isinstance(qlora_linear.weight, NF4Tensor) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index dccef196e9..5505c15680 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -6,6 +6,7 @@ import math from typing import List +import torch import torch.nn.functional as F from torch import nn, Tensor @@ -71,7 +72,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.m = nn.Parameter(F.ones(1, out_dim)) if self.use_dora else None + self.m = nn.Parameter(torch.ones(1, out_dim)) if self.use_dora else None self.merged = False # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. 
no child From cbafe85b56925fc51f75e338810e060af5c83270 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Sun, 5 May 2024 06:38:13 +0000 Subject: [PATCH 04/13] updated dora update based on author recommendataion --- torchtune/modules/peft/lora.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 5505c15680..75d493fb9a 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -134,11 +134,17 @@ def forward(self, x: Tensor) -> Tensor: return out lora_out = self.lora_a(self.dropout(x)) lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) - # Adding 1e-6 to avoid division by zero + # Author mentions this method is faster for the computation purpose: + # https://github.com/huggingface/peft/pull/1474#issuecomment-1963402710 if self.use_dora: - return out + self.m * lora_out / ( - lora_out.norm(p=2, dim=-1, keepdim=True) + 1e-6 - ) + weight_norm = torch.linalg.norm( + self.weight + + (self.alpha / self.rank) + * (self.lora_a.weight.T @ self.lora_b.weight.T).T, + dim=1, + ).to(self.weight.dtype) + mag_norm_scale = (self.m / weight_norm - 1).view(1, -1) + return mag_norm_scale * out + mag_norm_scale * lora_out return out + lora_out From f76e076d79823459387184623b3db8b0aa46c121 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Tue, 7 May 2024 04:17:22 +0000 Subject: [PATCH 05/13] fixed bugs and magnitude intialization --- torchtune/modules/peft/lora.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 75d493fb9a..6b0f03561d 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -34,6 +34,7 @@ class LoRALinear(nn.Module, AdapterModule): alpha (float): scaling factor for the low-rank approximation dropout (float): dropout probability. Default: 0.0 use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation). + link to the paper: https://arxiv.org/pdf/2402.09353 Default: False use_bias (bool): whether to include bias in the original linear layer. Default: False @@ -73,7 +74,7 @@ def __init__( self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) self.m = nn.Parameter(torch.ones(1, out_dim)) if self.use_dora else None - self.merged = False + self.dora_initialized = False # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. no child # params are initialized, as is done in initialize_parameters below). 
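The forward rewrite above, together with the hunks that follow, converges on the weight decomposition from the DoRA paper: the adapted weight is a direction term W0 + (alpha/r)·B·A normalized per output channel, times a learnable magnitude vector. A self-contained sketch of that computation (editorial, with illustrative names; not the torchtune API, and dropout is omitted):

import torch

def dora_linear_forward(x, w0, lora_a, lora_b, m, alpha, rank):
    # w0: (out_dim, in_dim) frozen base weight; lora_a: (rank, in_dim); lora_b: (out_dim, rank)
    # m:  (out_dim,) learnable magnitude, initialized to the row norms of the adapted weight
    scaling = alpha / rank
    adapted = w0 + scaling * (lora_b @ lora_a)              # direction: W0 + (alpha/r) B A
    weight_norm = adapted.norm(p=2, dim=1).detach()         # per-output-channel norm, no grad
    mag_scale = (m / weight_norm).view(1, -1)
    base_out = x @ w0.t()
    lora_out = scaling * (x @ lora_a.t()) @ lora_b.t()
    return mag_scale * (base_out + lora_out)                # == x @ (m * adapted / ||adapted||).T

torch.manual_seed(0)
x, w0 = torch.randn(2, 16), torch.randn(32, 16)
a, b = torch.randn(4, 16), torch.zeros(32, 4)               # B starts at zero, as in LoRA
m = w0.norm(p=2, dim=1)                                     # DoRA init: m = ||W0 + (alpha/r) B A||
out = dora_linear_forward(x, w0, a, b, m, alpha=16, rank=4)
print(out.shape)                                            # torch.Size([2, 32])
print(torch.allclose(out, x @ w0.t()))                      # True: matches the base layer at init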
@@ -117,6 +118,19 @@ def adapter_params(self) -> List[str]: adapter_params = ["lora_a.weight", "lora_b.weight"] return adapter_params + def dora_init(self) -> None: + weight_norm = self._dora_weight_norm + self.m = nn.Parameter(weight_norm, requires_grad=True) + self.dora_initialized = True + + @property + def _dora_weight_norm(self) -> Tensor: + return torch.linalg.norm( + self.weight + + (self.alpha / self.rank) * (self.lora_b.weight @ self.lora_a.weight), + dim=1, + ).to(self.weight.dtype) + def forward(self, x: Tensor) -> Tensor: """ Args: @@ -137,14 +151,13 @@ def forward(self, x: Tensor) -> Tensor: # Author mentions this method is faster for the computation purpose: # https://github.com/huggingface/peft/pull/1474#issuecomment-1963402710 if self.use_dora: - weight_norm = torch.linalg.norm( - self.weight - + (self.alpha / self.rank) - * (self.lora_a.weight.T @ self.lora_b.weight.T).T, - dim=1, - ).to(self.weight.dtype) - mag_norm_scale = (self.m / weight_norm - 1).view(1, -1) - return mag_norm_scale * out + mag_norm_scale * lora_out + # intialize the magnitude vector. + if not self.dora_initialized: + self.dora_init() + weight_norm = self._dora_weight_norm.detach() + mag_norm_scale = (self.m / weight_norm).view(1, -1) + # PEFT uses: out + (mag_norm_scale - 1) * out + mag_norm_scale * lora_b(lora_a(x)) * scaling. + return (out + lora_out) * mag_norm_scale return out + lora_out From fe08a068803ea505c83ca8f6b36a443b580bad6b Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Mon, 27 May 2024 20:41:20 +0000 Subject: [PATCH 06/13] working dora training with - default true --- .../configs/llama3/8B_Dora_single_device.yml | 85 +++ recipes/dora_finetune_single_device.py | 532 ++++++++++++++++++ torchtune/_recipe_registry.py | 11 + torchtune/modules/peft/lora.py | 47 +- torchtune/modules/peft/peft_utils.py | 7 + 5 files changed, 667 insertions(+), 15 deletions(-) create mode 100644 recipes/configs/llama3/8B_Dora_single_device.yml create mode 100644 recipes/dora_finetune_single_device.py diff --git a/recipes/configs/llama3/8B_Dora_single_device.yml b/recipes/configs/llama3/8B_Dora_single_device.yml new file mode 100644 index 0000000000..0f5ecd0ee5 --- /dev/null +++ b/recipes/configs/llama3/8B_Dora_single_device.yml @@ -0,0 +1,85 @@ +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama3 8b Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8b-Instruct --output-dir /tmp/Meta-Llama-3-8b-Instruct --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
+ +# Model Arguments +model: + _component_: torchtune.models.llama3.lora_llama3_8b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/tokenizer.model + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True +batch_size: 1 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/dora_finetune_output/ +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.profiler + enabled: False diff --git a/recipes/dora_finetune_single_device.py b/recipes/dora_finetune_single_device.py new file mode 100644 index 0000000000..fd2faee89e --- /dev/null +++ b/recipes/dora_finetune_single_device.py @@ -0,0 +1,532 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.modules.peft.peft_utils import ( + activate_dora_parms, + get_adapter_params, + get_merged_lora_ckpt, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. 
Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported.g + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + # For CUDA devices, check if the HW supports bf16 if bf16 is specified. 
+ if ( + self._dtype == torch.bfloat16 + and self._device != torch.device("cpu") + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError("Full bf16 training is not supported on this hardware.") + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.total_training_steps = 0 + + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + if utils.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + # If seed, total_epoch or max_steps_per_epoch don't match, + # warn the user and overwrite + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY] + ): + warn( + message="""Configured value for seed, epochs or max_steps_per_epoch + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. 
+ """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._model_compile = cfg.compile + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=cfg.compile, + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[utils.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.total_training_steps = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.total_training_steps - 1, + ) + + self._profiler_enabled = cfg.profiler.enabled + self._profiler = config.instantiate(cfg.profiler) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) + if lora_weights_state_dict: + lora_missing, lora_unexpected = model.load_state_dict( + lora_weights_state_dict, strict=False + ) + else: + lora_missing, lora_unexpected = None, None + + # initialize dora (this will run init_dora in the lora_linear layer. if not self.lora_m will be 0. i.e just lora.) 
+ activate_dora_parms(model) + + validate_missing_and_unexpected_for_lora( + lora_attn_modules=cfg_model.lora_attn_modules, + apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, + apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Validate model adapter params were loaded in with the expected dtype + # TODO (rohan-varma): Further validation to ensure the appropriate base params + # are NF4 vs bf16 based on the quantization config. + utils.validate_expected_param_dtype( + self.adapter_params.items(), dtype=self._dtype + ) + + log.info(f"Model is initialized with precision {self._dtype}.") + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + model = utils.wrap_compile(model) + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + ds = config.instantiate( + cfg_dataset, + tokenizer=self._tokenizer, + ) + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ), + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the merged weights, adapter weights and recipe state in + different checkpoint files. To correctly resume from training, the adapter weights + and recipe state must be provided along with the base model weights. 
+ """ + ckpt_dict = {} + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + ckpt_dict.update( + { + utils.OPT_KEY: self._optimizer.state_dict(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + # Construct the full state dict with LoRA weights merged into base LLM weights + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) + + # Construct the adapter weights + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) + } + ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def train(self) -> None: + """ + The core training loop. + """ + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + # Optionally profile the training loop + with self._profiler: + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + if self._profiler_enabled: + self._profiler.step() + + input_ids, labels = batch + input_ids = input_ids.to(self._device) + num_tokens += input_ids.numel() + labels = labels.to(self._device) + + logits = self._model(input_ids) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.total_training_steps += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch+1}|{self.total_training_steps}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.total_training_steps % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second": num_tokens / time_per_step, + } + if ( + self._device.type == "cuda" + and self._log_peak_memory_stats + ): + log_dict.update( + utils.get_memory_stats(device=self._device) + ) + self._metric_logger.log_dict( + log_dict, + step=self.total_training_steps, + ) + + # Reset running stats for 
the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="DoRAFinetuneRecipeSingleDevice", cfg=cfg) + recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 281cfbaec8..8b99e80e30 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -109,6 +109,17 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="dora_finetune_single_device", + file_path="dora_finetune_single_device.py", + configs=[ + Config( + name="llama3/*B_dora_single_device", + file_path="llama3/8B_dora_single_device.yaml", + ), + ], + supports_distributed=False, + ), Recipe( name="lora_dpo_single_device", file_path="lora_dpo_single_device.py", diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 6b0f03561d..baf7fa26bf 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -49,7 +49,7 @@ def __init__( rank: int, alpha: float, dropout: float = 0.0, - use_dora: bool = False, + use_dora: bool = True, # TODO(prakyath): add this at each models inference, Do Not make this aas default True. use_bias: bool = False, quantize_base: bool = False, ): @@ -73,8 +73,7 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.m = nn.Parameter(torch.ones(1, out_dim)) if self.use_dora else None - self.dora_initialized = False + self.lora_m = nn.Parameter(torch.zeros(1, out_dim)) # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. no child # params are initialized, as is done in initialize_parameters below). @@ -90,6 +89,7 @@ def initialize_parameters(self): # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119 _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) + _dora_m_init_params(self.lora_m) def _create_weight_and_bias(self): """ @@ -116,20 +116,33 @@ def adapter_params(self) -> List[str]: # NOTE: this function has to be updated if the names of "lora_a" and "lora_b" # in this module change. adapter_params = ["lora_a.weight", "lora_b.weight"] + if self.use_dora: + adapter_params.append("lora_m") return adapter_params - def dora_init(self) -> None: + def init_dora(self) -> None: + # this is a seperate function because, + # this should be called after model state dict is called. + # But We verify and initialize the model arch first before the loading weights. 
weight_norm = self._dora_weight_norm - self.m = nn.Parameter(weight_norm, requires_grad=True) - self.dora_initialized = True + self.lora_m.data = weight_norm.data # Update the data of 'm' directly @property def _dora_weight_norm(self) -> Tensor: - return torch.linalg.norm( - self.weight - + (self.alpha / self.rank) * (self.lora_b.weight @ self.lora_a.weight), - dim=1, - ).to(self.weight.dtype) + if self._quantize_base: + # Convert NF4Tensor to regular Tensor for computation TODO(prakyath): Fix this. + weight = to_regular_tensor(self.weight) + else: + weight = self.weight + + # Perform the operation with regular tensors + result = weight + (self.alpha / self.rank) * ( + self.lora_b.weight @ self.lora_a.weight + ) + norm = torch.linalg.norm(result, dim=1) + + # Convert back if necessary (depending on your requirements) + return norm def forward(self, x: Tensor) -> Tensor: """ @@ -151,11 +164,8 @@ def forward(self, x: Tensor) -> Tensor: # Author mentions this method is faster for the computation purpose: # https://github.com/huggingface/peft/pull/1474#issuecomment-1963402710 if self.use_dora: - # intialize the magnitude vector. - if not self.dora_initialized: - self.dora_init() weight_norm = self._dora_weight_norm.detach() - mag_norm_scale = (self.m / weight_norm).view(1, -1) + mag_norm_scale = (self.lora_m / weight_norm).view(1, -1) # PEFT uses: out + (mag_norm_scale - 1) * out + mag_norm_scale * lora_b(lora_a(x)) * scaling. return (out + lora_out) * mag_norm_scale return out + lora_out @@ -173,3 +183,10 @@ def _lora_b_init_params(x: nn.Linear) -> None: Initialize LoRA B weight to zeros. """ nn.init.zeros_(x.weight) + + +def _dora_m_init_params(x: nn.Parameter) -> None: + """ + Initialize DORA m to ones. + """ + nn.init.zeros_(x) diff --git a/torchtune/modules/peft/peft_utils.py b/torchtune/modules/peft/peft_utils.py index 0eea5bd8fb..06a7ea9063 100644 --- a/torchtune/modules/peft/peft_utils.py +++ b/torchtune/modules/peft/peft_utils.py @@ -63,6 +63,13 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params +def activate_dora_parms(model: nn.Module) -> nn.Module: + for k, v in model.named_modules(): + if hasattr(v, "adapter_params") and callable(v.adapter_params): + current_adapter_params = v.adapter_params() + v.init_dora() # TODO(prakyath) check if module is LoraLinear and use_dora is true and then apply. 
+ + @functools.lru_cache() def _get_base_model_params(model: nn.Module) -> Dict[str, Any]: """ From 4a58dba01010df53ac9be7b5419d4699916564b5 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Tue, 28 May 2024 02:47:08 +0000 Subject: [PATCH 07/13] updated config with seed and the print for testing and comparision --- .../configs/llama3/8B_Dora_single_device.yml | 4 +-- recipes/dora_finetune_single_device.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/recipes/configs/llama3/8B_Dora_single_device.yml b/recipes/configs/llama3/8B_Dora_single_device.yml index 0f5ecd0ee5..4962491578 100644 --- a/recipes/configs/llama3/8B_Dora_single_device.yml +++ b/recipes/configs/llama3/8B_Dora_single_device.yml @@ -44,7 +44,7 @@ resume_from_checkpoint: False dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset train_on_input: True -seed: null +seed: 12345678 shuffle: True batch_size: 1 @@ -69,7 +69,7 @@ compile: False # Logging output_dir: /tmp/dora_finetune_output/ metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.WandBLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/recipes/dora_finetune_single_device.py b/recipes/dora_finetune_single_device.py index fd2faee89e..ff3df9be52 100644 --- a/recipes/dora_finetune_single_device.py +++ b/recipes/dora_finetune_single_device.py @@ -293,6 +293,31 @@ def _setup_model( utils.validate_expected_param_dtype( self.adapter_params.items(), dtype=self._dtype ) + # This is for comparing with peft, remove after comparision. + # 1. Get number of trainable parameters. + # 2. Get number of total parameters. + # 3. If possible extract the name of the trainable parameters + # and name of the non-traiable parameter and store in the wandb summary. + + # Calculate total parameters + total_params = sum(p.numel() for p in model.parameters()) + log.info(f"Total parameters: {total_params}") + + # Calculate trainable parameters + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + log.info(f"Trainable parameters: {trainable_params}") + + # Extract names of trainable and non-trainable parameters + trainable_names = [ + name for name, param in model.named_parameters() if param.requires_grad + ] + non_trainable_names = [ + name for name, param in model.named_parameters() if not param.requires_grad + ] + + # Print names + log.info(f"Trainable parameter names: {trainable_names}") + log.info(f"Non-trainable parameter names: {non_trainable_names}") log.info(f"Model is initialized with precision {self._dtype}.") # Compile model, if enabled. 
From f096e4b788c8fb192ebf1a075d6899f557038b91 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Fri, 31 May 2024 22:53:03 +0000 Subject: [PATCH 08/13] update dora initialization and llama3 loading --- torchtune/models/llama3/__init__.py | 2 ++ .../models/llama3/_component_builders.py | 15 +++++++++- torchtune/models/llama3/_model_builders.py | 12 ++++++++ torchtune/modules/peft/lora.py | 28 +++++++++++-------- 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/torchtune/models/llama3/__init__.py b/torchtune/models/llama3/__init__.py index 975147d7be..11fb633b96 100644 --- a/torchtune/models/llama3/__init__.py +++ b/torchtune/models/llama3/__init__.py @@ -7,6 +7,7 @@ from ._component_builders import llama3, lora_llama3 from ._model_builders import ( # noqa + dora_llama3_8b, llama3_70b, llama3_8b, llama3_tokenizer, @@ -26,4 +27,5 @@ "lora_llama3_70b", "qlora_llama3_8b", "scale_hidden_dim_for_mlp", + "dora_llama3_8b", ] diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 0828285645..678a0e92b8 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -150,6 +150,8 @@ def lora_llama3( lora_rank: int, lora_alpha: float, lora_dropout: float = 0.0, + # dora args + use_dora: bool = False, # Quantization args quantize_base: bool = False, ) -> TransformerDecoder: @@ -204,6 +206,7 @@ def lora_llama3( lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + use_dora=use_dora, quantize_base=quantize_base, ) @@ -214,6 +217,7 @@ def lora_llama3( hidden_dim=hidden_dim, lora_rank=lora_rank, lora_alpha=lora_alpha, + use_dora=use_dora, quantize_base=quantize_base, ) else: @@ -230,7 +234,7 @@ def lora_llama3( # TODO: quantize_base is not applied to final output_proj currently. 
output_proj = ( - LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha) + LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, use_dora=use_dora) if apply_lora_to_output else nn.Linear(embed_dim, vocab_size, bias=False) ) @@ -270,6 +274,7 @@ def lora_llama3_self_attention( lora_alpha: float, lora_dropout: float = 0.0, quantize_base: bool = False, + use_dora: bool = False, ) -> CausalSelfAttention: """ Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA @@ -316,6 +321,7 @@ def lora_llama3_self_attention( rank=lora_rank, alpha=lora_alpha, quantize_base=quantize_base, + use_dora=use_dora, ) if "q_proj" in lora_modules else nn.Linear(embed_dim, num_heads * head_dim, bias=False) @@ -327,6 +333,7 @@ def lora_llama3_self_attention( rank=lora_rank, alpha=lora_alpha, quantize_base=quantize_base, + use_dora=use_dora, ) if "k_proj" in lora_modules else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) @@ -338,6 +345,7 @@ def lora_llama3_self_attention( rank=lora_rank, alpha=lora_alpha, quantize_base=quantize_base, + use_dora=use_dora, ) if "v_proj" in lora_modules else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) @@ -349,6 +357,7 @@ def lora_llama3_self_attention( rank=lora_rank, alpha=lora_alpha, quantize_base=quantize_base, + use_dora=use_dora, ) if "output_proj" in lora_modules else nn.Linear(embed_dim, embed_dim, bias=False) @@ -377,6 +386,7 @@ def lora_llama3_mlp( lora_rank: int, lora_alpha: float, lora_dropout: float = 0.0, + use_dora: bool = False, quantize_base: bool = False, ) -> FeedForward: gate_proj = LoRALinear( @@ -386,6 +396,7 @@ def lora_llama3_mlp( alpha=lora_alpha, dropout=lora_dropout, quantize_base=quantize_base, + use_dora=use_dora, ) down_proj = LoRALinear( in_dim=hidden_dim, @@ -394,6 +405,7 @@ def lora_llama3_mlp( alpha=lora_alpha, dropout=lora_dropout, quantize_base=quantize_base, + use_dora=use_dora, ) up_proj = LoRALinear( in_dim=dim, @@ -402,6 +414,7 @@ def lora_llama3_mlp( alpha=lora_alpha, dropout=lora_dropout, quantize_base=quantize_base, + use_dora=use_dora, ) return FeedForward( gate_proj=gate_proj, diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 9369a03cb7..f5cfd2e93f 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -77,6 +77,7 @@ def lora_llama3_8b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + use_dora: bool = False, quantize_base: bool = False, ) -> TransformerDecoder: """ @@ -118,7 +119,9 @@ def lora_llama3_8b( lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=0.05, + use_dora=use_dora, quantize_base=quantize_base, + use_dora=use_dora, ) @@ -180,3 +183,12 @@ def lora_llama3_70b( that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama3_8b` for full API arguments. """ + +dora_llama3_8b = partial(lora_llama3_8b, use_dora=True) + +dora_llama3_8b.__doc__ = """ +Builder for creating a Llama3 model with DORA enabled. Base model weights in linear layers +that DORA is applied to are quantized per the Dora paper: https://arxiv.org/abs/2402.09353. +In addition to the lora adaptor weights, DORA also adds a trainable magnitude parameters. +Please see `lora_llama3_8b` for full API arguments. 
+""" diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index baf7fa26bf..52307ce7de 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -49,7 +49,7 @@ def __init__( rank: int, alpha: float, dropout: float = 0.0, - use_dora: bool = True, # TODO(prakyath): add this at each models inference, Do Not make this aas default True. + use_dora: bool = False, # TODO(prakyath): add this at each models inference, Do Not make this aas default True. use_bias: bool = False, quantize_base: bool = False, ): @@ -73,7 +73,8 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.lora_m = nn.Parameter(torch.zeros(1, out_dim)) + if self.use_dora: + self.lora_m = nn.Parameter(torch.zeros(1, out_dim)) # Note: FSDP's meta device initialization contract assumes that a module's # reset_parameters method only initializes its own parameters (i.e. no child # params are initialized, as is done in initialize_parameters below). @@ -89,7 +90,8 @@ def initialize_parameters(self): # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119 _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) - _dora_m_init_params(self.lora_m) + if self.use_dora: + _dora_m_init_params(self.lora_m) def _create_weight_and_bias(self): """ @@ -129,11 +131,12 @@ def init_dora(self) -> None: @property def _dora_weight_norm(self) -> Tensor: - if self._quantize_base: - # Convert NF4Tensor to regular Tensor for computation TODO(prakyath): Fix this. - weight = to_regular_tensor(self.weight) - else: - weight = self.weight + """ + Compute the norm of the linear weight and lora adaptor weights. + If the base model is quantized, dequantize the weights before computing the norm. + Return the norm in NF4 format if the base model is quantized. + """ + weight = self.weight.dequantize() if self._quantize_base else self.weight # Perform the operation with regular tensors result = weight + (self.alpha / self.rank) * ( @@ -141,8 +144,11 @@ def _dora_weight_norm(self) -> Tensor: ) norm = torch.linalg.norm(result, dim=1) - # Convert back if necessary (depending on your requirements) - return norm + # Clamp the norm to avoid division by zero + # TODO(Prakyath): Check with torchtune team whether this should be a parameter ? + norm = torch.clamp(norm, min=1e-6) + # Return the norm in NF4 format. + return to_nf4(norm) if self._quantize_base else norm def forward(self, x: Tensor) -> Tensor: """ @@ -189,4 +195,4 @@ def _dora_m_init_params(x: nn.Parameter) -> None: """ Initialize DORA m to ones. 
""" - nn.init.zeros_(x) + nn.init.ones_(x) From a55a962e510c259e5b5d2572c2a9b2c654453d84 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Fri, 31 May 2024 22:53:55 +0000 Subject: [PATCH 09/13] removed dora specific recipe and merge recipes --- ...e_device.yml => 8B_dora_single_device.yml} | 3 +- recipes/dora_finetune_single_device.py | 557 ------------------ recipes/lora_finetune_single_device.py | 8 + 3 files changed, 10 insertions(+), 558 deletions(-) rename recipes/configs/llama3/{8B_Dora_single_device.yml => 8B_dora_single_device.yml} (97%) delete mode 100644 recipes/dora_finetune_single_device.py diff --git a/recipes/configs/llama3/8B_Dora_single_device.yml b/recipes/configs/llama3/8B_dora_single_device.yml similarity index 97% rename from recipes/configs/llama3/8B_Dora_single_device.yml rename to recipes/configs/llama3/8B_dora_single_device.yml index 4962491578..eff726a665 100644 --- a/recipes/configs/llama3/8B_Dora_single_device.yml +++ b/recipes/configs/llama3/8B_dora_single_device.yml @@ -17,12 +17,13 @@ # Model Arguments model: - _component_: torchtune.models.llama3.lora_llama3_8b + _component_: torchtune.models.llama3.dora_llama3_8b lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] apply_lora_to_mlp: True apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + use_dora: True # Tokenizer tokenizer: diff --git a/recipes/dora_finetune_single_device.py b/recipes/dora_finetune_single_device.py deleted file mode 100644 index ff3df9be52..0000000000 --- a/recipes/dora_finetune_single_device.py +++ /dev/null @@ -1,557 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import sys -import time - -from functools import partial -from typing import Any, Dict, Optional, Tuple -from warnings import warn - -import torch -from omegaconf import DictConfig - -from torch import nn -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, utils -from torchtune.modules.peft.peft_utils import ( - activate_dora_parms, - get_adapter_params, - get_merged_lora_ckpt, - set_trainable_params, - validate_missing_and_unexpected_for_lora, -) -from torchtune.recipe_interfaces import FTRecipeInterface -from tqdm import tqdm - -log = utils.get_logger("DEBUG") - - -class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): - """ - LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized - for single GPU training. Training on CPU is not supported. - - Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` - flag. Activation checkpointing helps reduce the memory footprint since we no longer keep - activations in memory and instead recompute them during the backward pass. This is especially - helpful for larger batch sizes when you're memory constrained. But these savings in memory - come at the cost of training performance. In most cases training can slow-down quite a bit as - a result of this activation recomputation. - - - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` - flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. 
In - most cases this should halve the memory footprint of full precision (fp32) training, without - loss in model quality (will depend on the model, training data and other settings). For - GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 - precision are currently not supported.g - - - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is - controlled using the ``gradient_accumulation_steps`` flag. - - Total Batch Size = batch_size * gradient accumulation steps. - - For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. - - Gradient accumulation is especially useful when you are memory constrained. In this case, - accumulating gradients might give you better training speed than enabling activation - checkpointing. - - - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes - library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with - 8-bit AdamW and Paged AdamW. - - - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of - training. Currently we checkpoint both the adapter weights (trainable params only) and the - complete merged weights (adapter weights added back to the base model). For more details - please take a look at our LoRA tutorial - (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). - - Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are - only saved at the end of a given epoch and used in case of resuming training. Resuming - training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is - currently not supported. - - For more details on the checkpointer, please take a look at - our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). - - - Logging. Terminal, Disk, WandB and TensorBoard are all supported. - - For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config - has example commands for how to kick-off training. - - Args: - cfg (DictConfig): OmegaConf object parsed from yaml file - - Raises: - ValueError: If ``dtype`` is set to fp16. - RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. - - """ - - def __init__(self, cfg: DictConfig) -> None: - - self._device = utils.get_device(device=cfg.device) - # Reduced precision logic - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) - # fp16 precision is explicitly disabled as it is not supported in this - # recipe (for example, no gradient scaling). - if self._dtype == torch.float16: - raise ValueError( - "fp16 precision is not supported in this recipe. Please use fp32 or bf16." - ) - # For CUDA devices, check if the HW supports bf16 if bf16 is specified. 
- if ( - self._dtype == torch.bfloat16 - and self._device != torch.device("cpu") - and not torch.cuda.is_bf16_supported() - ): - raise RuntimeError("Full bf16 training is not supported on this hardware.") - # logging attributes - self._output_dir = cfg.output_dir - self._log_every_n_steps = cfg.get("log_every_n_steps", 1) - self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - - # These are public properties which are updated by the checkpoint loader - # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch - self.total_training_steps = 0 - - self._resume_from_checkpoint = cfg.resume_from_checkpoint - self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - - def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: - """ - Extract the checkpoint state from file and validate. This includes the - base model weights. If resume_from_checkpoint is True, this also includes - the adapter weights and recipe state - """ - self._checkpointer = config.instantiate( - cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, - ) - checkpoint_dict = self._checkpointer.load_checkpoint() - - if self._resume_from_checkpoint: - if utils.ADAPTER_KEY not in checkpoint_dict: - raise ValueError( - "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." - ) - # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded - # no need to check here - self._update_recipe_state(checkpoint_dict) - return checkpoint_dict - - def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: - """ - Updates the recipe state from checkpoint. - """ - # If seed, total_epoch or max_steps_per_epoch don't match, - # warn the user and overwrite - if ( - self.seed != ckpt_dict[utils.SEED_KEY] - or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] - or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY] - ): - warn( - message="""Configured value for seed, epochs or max_steps_per_epoch - does not match the value stored in checkpoint.""" - ) - self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) - self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] - self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] - self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] - - def setup(self, cfg: DictConfig) -> None: - """ - Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), - model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. 
- """ - self._metric_logger = config.instantiate(cfg.metric_logger) - - # log config with parameter override - self._metric_logger.log_config(cfg) - - self._model_compile = cfg.compile - checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - - self._model = self._setup_model( - cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=cfg.compile, - base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], - lora_weights_state_dict=( - checkpoint_dict[utils.ADAPTER_KEY] - if self._resume_from_checkpoint - else None - ), - ) - - self._tokenizer = config.instantiate(cfg.tokenizer) - log.info("Tokenizer is initialized from file.") - - self._optimizer = self._setup_optimizer( - cfg_optimizer=cfg.optimizer, - opt_state_dict=( - checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None - ), - ) - - self._loss_fn = config.instantiate(cfg.loss) - log.info("Loss is initialized.") - - # Dataloader depends on the tokenizer and loss_fn and should be - # setup after all of these are setup - self._sampler, self._dataloader = self._setup_data( - cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, - batch_size=cfg.batch_size, - ) - - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader and the max_steps_per_epoch param set by the user and is used - # for logging and tracking training state. This should be computed after the dataloader - # has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - self.total_training_steps = self.epochs_run * self._steps_per_epoch - - # Learning rate scheduler can only be set up after number of steps - # has been computed - self._lr_scheduler = self._setup_lr_scheduler( - cfg_lr_scheduler=cfg.lr_scheduler, - num_training_steps=self.total_epochs * self._steps_per_epoch, - last_epoch=self.total_training_steps - 1, - ) - - self._profiler_enabled = cfg.profiler.enabled - self._profiler = config.instantiate(cfg.profiler) - - def _setup_model( - self, - cfg_model: DictConfig, - enable_activation_checkpointing: bool, - compile_model: bool, - base_model_state_dict: Dict[str, Any], - lora_weights_state_dict: Optional[Dict[str, Any]] = None, - ) -> nn.Module: - with utils.set_default_dtype(self._dtype), self._device: - model = config.instantiate(cfg_model) - - self._lora_rank = cfg_model.lora_rank - self._lora_alpha = cfg_model.lora_alpha - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) - - if enable_activation_checkpointing: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} - ) - - base_missing, base_unexpected = model.load_state_dict( - base_model_state_dict, strict=False - ) - if lora_weights_state_dict: - lora_missing, lora_unexpected = model.load_state_dict( - lora_weights_state_dict, strict=False - ) - else: - lora_missing, lora_unexpected = None, None - - # initialize dora (this will run init_dora in the lora_linear layer. if not self.lora_m will be 0. i.e just lora.) 
- activate_dora_parms(model) - - validate_missing_and_unexpected_for_lora( - lora_attn_modules=cfg_model.lora_attn_modules, - apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, - apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), - base_missing=base_missing, - base_unexpected=base_unexpected, - lora_missing=lora_missing, - lora_unexpected=lora_unexpected, - ) - # Validate model adapter params were loaded in with the expected dtype - # TODO (rohan-varma): Further validation to ensure the appropriate base params - # are NF4 vs bf16 based on the quantization config. - utils.validate_expected_param_dtype( - self.adapter_params.items(), dtype=self._dtype - ) - # This is for comparing with peft, remove after comparision. - # 1. Get number of trainable parameters. - # 2. Get number of total parameters. - # 3. If possible extract the name of the trainable parameters - # and name of the non-traiable parameter and store in the wandb summary. - - # Calculate total parameters - total_params = sum(p.numel() for p in model.parameters()) - log.info(f"Total parameters: {total_params}") - - # Calculate trainable parameters - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - log.info(f"Trainable parameters: {trainable_params}") - - # Extract names of trainable and non-trainable parameters - trainable_names = [ - name for name, param in model.named_parameters() if param.requires_grad - ] - non_trainable_names = [ - name for name, param in model.named_parameters() if not param.requires_grad - ] - - # Print names - log.info(f"Trainable parameter names: {trainable_names}") - log.info(f"Non-trainable parameter names: {non_trainable_names}") - - log.info(f"Model is initialized with precision {self._dtype}.") - # Compile model, if enabled. - if compile_model: - log.info("Compiling model with torch.compile...") - model = utils.wrap_compile(model) - if self._device.type == "cuda": - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) - return model - - def _setup_optimizer( - self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None - ) -> Optimizer: - optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - optimizer.load_state_dict(opt_state_dict) - - log.info("Optimizer and loss are initialized.") - return optimizer - - def _setup_lr_scheduler( - self, - cfg_lr_scheduler: DictConfig, - num_training_steps: int, - last_epoch: int, - ) -> Optimizer: - lr_scheduler = config.instantiate( - cfg_lr_scheduler, - self._optimizer, - num_training_steps=num_training_steps, - last_epoch=last_epoch, - ) - - log.info("Learning rate scheduler is initialized.") - return lr_scheduler - - def _setup_data( - self, - cfg_dataset: DictConfig, - shuffle: bool, - batch_size: int, - ) -> Tuple[DistributedSampler, DataLoader]: - """ - All data related setup happens here. Currently this recipe only supports - Map-style Datasets which fit into memory and an option for random shuffling. - Samplers, iterable datasets, and streaming datasets are not supported. 
- """ - ds = config.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) - sampler = DistributedSampler( - ds, - num_replicas=1, - rank=0, - shuffle=shuffle, - seed=0, - ) - dataloader = DataLoader( - dataset=ds, - sampler=sampler, - batch_size=batch_size, - collate_fn=partial( - utils.padded_collate, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ), - ) - - log.info("Dataset and Sampler are initialized.") - - return sampler, dataloader - - def save_checkpoint(self, epoch: int) -> None: - """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Merged weights with key MODEL_KEY - - Adapter weights with key ADAPTER_KEY - - Relevant recipe state if training is not complete - - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. - """ - ckpt_dict = {} - # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: - ckpt_dict.update( - { - utils.OPT_KEY: self._optimizer.state_dict(), - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, - } - ) - - # Move to CPU to avoid a copy on GPU - state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} - - # Construct the full state dict with LoRA weights merged into base LLM weights - merged_state_dict = get_merged_lora_ckpt( - state_dict, - rank=self._lora_rank, - alpha=self._lora_alpha, - ) - ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) - - # Construct the adapter weights - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) - } - ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) - self._checkpointer.save_checkpoint( - ckpt_dict, - epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), - ) - - def train(self) -> None: - """ - The core training loop. - """ - - if self._model_compile: - log.info( - "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
- ) - - # Initialize tokens count and running loss (for grad accumulation) - t0 = time.perf_counter() - running_loss = 0 - num_tokens = 0 - - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - # Update the sampler to ensure data is correctly shuffled across epochs - # in case shuffle is True - self._sampler.set_epoch(curr_epoch) - - # Optionally profile the training loop - with self._profiler: - pbar = tqdm(total=self._steps_per_epoch) - for idx, batch in enumerate(self._dataloader): - if ( - self.max_steps_per_epoch is not None - and (idx // self._gradient_accumulation_steps) - == self.max_steps_per_epoch - ): - break - - if self._profiler_enabled: - self._profiler.step() - - input_ids, labels = batch - input_ids = input_ids.to(self._device) - num_tokens += input_ids.numel() - labels = labels.to(self._device) - - logits = self._model(input_ids) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() - - # Step with optimizer - if (idx + 1) % self._gradient_accumulation_steps == 0: - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) - self._lr_scheduler.step() - # Update the number of steps when the weights are updated - self.total_training_steps += 1 - - loss_to_log = running_loss.item() - pbar.update(1) - pbar.set_description( - f"{curr_epoch+1}|{self.total_training_steps}|Loss: {loss_to_log}" - ) - - # Log per-step metrics - if self.total_training_steps % self._log_every_n_steps == 0: - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": self._optimizer.param_groups[0]["lr"], - "tokens_per_second": num_tokens / time_per_step, - } - if ( - self._device.type == "cuda" - and self._log_peak_memory_stats - ): - log_dict.update( - utils.get_memory_stats(device=self._device) - ) - self._metric_logger.log_dict( - log_dict, - step=self.total_training_steps, - ) - - # Reset running stats for the next step - running_loss = 0 - num_tokens = 0 - t0 = time.perf_counter() - - self.epochs_run += 1 - self.save_checkpoint(epoch=curr_epoch) - - def cleanup(self) -> None: - self._metric_logger.close() - - -@config.parse -def recipe_main(cfg: DictConfig) -> None: - """ - Entry point for the recipe. 
- - Configurable parameters are read in the following order: - - Parameters specified in config (see available configs through ``tune ls``) - - Overwritten by arguments from the command-line - """ - config.log_config(recipe_name="DoRAFinetuneRecipeSingleDevice", cfg=cfg) - recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg) - recipe.setup(cfg=cfg) - recipe.train() - recipe.cleanup() - - -if __name__ == "__main__": - sys.exit(recipe_main()) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index bef234e55f..83ba7b1aa4 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -19,6 +19,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils from torchtune.modules.peft.peft_utils import ( + activate_dora_params, get_adapter_params, get_merged_lora_ckpt, set_trainable_params, @@ -256,6 +257,7 @@ def _setup_model( self._lora_rank = cfg_model.lora_rank self._lora_alpha = cfg_model.lora_alpha + self.adapter_params = get_adapter_params(model) set_trainable_params(model, self.adapter_params) @@ -274,6 +276,12 @@ def _setup_model( else: lora_missing, lora_unexpected = None, None + if cfg_model.get("use_dora"): + # magnitude vectors for dora are initialized as ones. + # Once the weights are loaded, they are replaced by obtaining the norm of the + # linear weights. Refer https://arxiv.org/pdf/2402.09353 for more details. + activate_dora_params(model) + validate_missing_and_unexpected_for_lora( lora_attn_modules=cfg_model.lora_attn_modules, apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, From 9110f0e1e3f69c61a71f66fe85f617de35dd8300 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Sat, 1 Jun 2024 02:50:53 +0000 Subject: [PATCH 10/13] fixed model loading bugs and tested training --- ...B_dora_single_device.yml => 8B_dora_single_device.yaml} | 0 torchtune/_recipe_registry.py | 7 ------- torchtune/models/llama3/_model_builders.py | 1 - torchtune/modules/peft/peft_utils.py | 4 ++-- 4 files changed, 2 insertions(+), 10 deletions(-) rename recipes/configs/llama3/{8B_dora_single_device.yml => 8B_dora_single_device.yaml} (100%) diff --git a/recipes/configs/llama3/8B_dora_single_device.yml b/recipes/configs/llama3/8B_dora_single_device.yaml similarity index 100% rename from recipes/configs/llama3/8B_dora_single_device.yml rename to recipes/configs/llama3/8B_dora_single_device.yaml diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 8b99e80e30..a80169b932 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -106,13 +106,6 @@ class Recipe: name="gemma/2B_qlora_single_device", file_path="gemma/2B_qlora_single_device.yaml", ), - ], - supports_distributed=False, - ), - Recipe( - name="dora_finetune_single_device", - file_path="dora_finetune_single_device.py", - configs=[ Config( name="llama3/*B_dora_single_device", file_path="llama3/8B_dora_single_device.yaml", diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index f5cfd2e93f..0b3fdaf7e0 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -121,7 +121,6 @@ def lora_llama3_8b( lora_dropout=0.05, use_dora=use_dora, quantize_base=quantize_base, - use_dora=use_dora, ) diff --git a/torchtune/modules/peft/peft_utils.py b/torchtune/modules/peft/peft_utils.py index 06a7ea9063..b7724a652f 100644 --- a/torchtune/modules/peft/peft_utils.py +++ 
b/torchtune/modules/peft/peft_utils.py @@ -63,11 +63,11 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params -def activate_dora_parms(model: nn.Module) -> nn.Module: +def activate_dora_params(model: nn.Module) -> nn.Module: for k, v in model.named_modules(): if hasattr(v, "adapter_params") and callable(v.adapter_params): current_adapter_params = v.adapter_params() - v.init_dora() # TODO(prakyath) check if module is LoraLinear and use_dora is true and then apply. + v.init_dora() @functools.lru_cache() From 733aff4f8c49dc59282c7d6f57cbe6c41e94ff3b Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Sun, 2 Jun 2024 18:45:51 +0000 Subject: [PATCH 11/13] update the lora finetune recipe for lora init --- recipes/lora_finetune_single_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 83ba7b1aa4..84ecb66b6f 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -276,7 +276,7 @@ def _setup_model( else: lora_missing, lora_unexpected = None, None - if cfg_model.get("use_dora"): + if cfg_model.get("use_dora", False): # magnitude vectors for dora are initialized as ones. # Once the weights are loaded, they are replaced by obtaining the norm of the # linear weights. Refer https://arxiv.org/pdf/2402.09353 for more details. From 0ad3b8492bceab5f58bc8446b1840b7fca4b5da5 Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Sun, 2 Jun 2024 18:54:17 +0000 Subject: [PATCH 12/13] updated llama3 docstring with use_dora --- torchtune/models/llama3/_component_builders.py | 2 ++ torchtune/models/llama3/_model_builders.py | 1 + 2 files changed, 3 insertions(+) diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 678a0e92b8..8f773e8cca 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -185,6 +185,7 @@ def lora_llama3( lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DORA. Default is ``False``. quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. @@ -299,6 +300,7 @@ def lora_llama3_self_attention( lora_dropout (float): LoRA dropout probability. Default: 0.0 quantize_base (bool): Whether to quantize base model parameters for linear layers LoRA is being applied to. Default is ``False``. + use_dora (bool): Whether to use DORA. Default is ``False``. Returns: CausalSelfAttention: instantiation of self-attention module with LoRA diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 0b3fdaf7e0..8772dcf9b8 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -97,6 +97,7 @@ def lora_llama3_8b( Default: False lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation + use_dora (bool): Whether to use DORA. Default is ``False``. 
quantize_base (bool): Whether to quantize base model weights Returns: From 7b4b8a4dbae45f5b1b436fe902c6ec5006e3e28c Mon Sep 17 00:00:00 2001 From: Prakyath Kantharaju Date: Tue, 11 Jun 2024 02:50:03 +0000 Subject: [PATCH 13/13] changed property to function and fixed typos in the docs --- recipes/configs/llama3/8B_dora_single_device.yaml | 8 ++++---- torchtune/_recipe_registry.py | 2 +- torchtune/models/llama3/_model_builders.py | 6 +++--- torchtune/modules/peft/lora.py | 6 ++---- torchtune/modules/peft/peft_utils.py | 1 - 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index eff726a665..3c4b2c45b1 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -1,4 +1,4 @@ -# Config for single device QLoRA with lora_finetune_single_device.py +# Config for single device DoRA with lora_finetune_single_device.py # using a Llama3 8b Instruct model # # This config assumes that you've run the following command before launching @@ -6,12 +6,12 @@ # tune download meta-llama/Meta-Llama-3-8b-Instruct --output-dir /tmp/Meta-Llama-3-8b-Instruct --hf-token # # To launch on a single device, run the following command from root: -# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device +# tune run lora_finetune_single_device --config llama3/8b_dora_single_device # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device checkpointer.checkpoint_dir= +# tune run lora_finetune_single_device --config llama3/8b_dora_single_device checkpointer.checkpoint_dir= # # This config works only for training on single device. @@ -70,7 +70,7 @@ compile: False # Logging output_dir: /tmp/dora_finetune_output/ metric_logger: - _component_: torchtune.utils.metric_logging.WandBLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index a80169b932..e817be4108 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -107,7 +107,7 @@ class Recipe: file_path="gemma/2B_qlora_single_device.yaml", ), Config( - name="llama3/*B_dora_single_device", + name="llama3/8B_dora_single_device", file_path="llama3/8B_dora_single_device.yaml", ), ], diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index 8772dcf9b8..f2e4d3fac0 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -187,8 +187,8 @@ def lora_llama3_70b( dora_llama3_8b = partial(lora_llama3_8b, use_dora=True) dora_llama3_8b.__doc__ = """ -Builder for creating a Llama3 model with DORA enabled. Base model weights in linear layers -that DORA is applied to are quantized per the Dora paper: https://arxiv.org/abs/2402.09353. -In addition to the lora adaptor weights, DORA also adds a trainable magnitude parameters. +Builder for creating a Llama3 model with DoRA enabled. Base model weights in linear layers +that DoRA is applied to are quantized per the DoRA paper: https://arxiv.org/abs/2402.09353. +In addition to the lora adaptor weights, DoRA also adds a trainable magnitude parameters. Please see `lora_llama3_8b` for full API arguments. 
""" diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 52307ce7de..3bb6b7f1e7 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -126,10 +126,9 @@ def init_dora(self) -> None: # this is a seperate function because, # this should be called after model state dict is called. # But We verify and initialize the model arch first before the loading weights. - weight_norm = self._dora_weight_norm + weight_norm = self._dora_weight_norm() self.lora_m.data = weight_norm.data # Update the data of 'm' directly - @property def _dora_weight_norm(self) -> Tensor: """ Compute the norm of the linear weight and lora adaptor weights. @@ -145,7 +144,6 @@ def _dora_weight_norm(self) -> Tensor: norm = torch.linalg.norm(result, dim=1) # Clamp the norm to avoid division by zero - # TODO(Prakyath): Check with torchtune team whether this should be a parameter ? norm = torch.clamp(norm, min=1e-6) # Return the norm in NF4 format. return to_nf4(norm) if self._quantize_base else norm @@ -170,7 +168,7 @@ def forward(self, x: Tensor) -> Tensor: # Author mentions this method is faster for the computation purpose: # https://github.com/huggingface/peft/pull/1474#issuecomment-1963402710 if self.use_dora: - weight_norm = self._dora_weight_norm.detach() + weight_norm = self._dora_weight_norm().detach() mag_norm_scale = (self.lora_m / weight_norm).view(1, -1) # PEFT uses: out + (mag_norm_scale - 1) * out + mag_norm_scale * lora_b(lora_a(x)) * scaling. return (out + lora_out) * mag_norm_scale diff --git a/torchtune/modules/peft/peft_utils.py b/torchtune/modules/peft/peft_utils.py index b7724a652f..083671f7f3 100644 --- a/torchtune/modules/peft/peft_utils.py +++ b/torchtune/modules/peft/peft_utils.py @@ -66,7 +66,6 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: def activate_dora_params(model: nn.Module) -> nn.Module: for k, v in model.named_modules(): if hasattr(v, "adapter_params") and callable(v.adapter_params): - current_adapter_params = v.adapter_params() v.init_dora()