[WIP] Addition of Dora #936

Closed · wants to merge 13 commits
86 changes: 86 additions & 0 deletions recipes/configs/llama3/8B_dora_single_device.yaml
@@ -0,0 +1,86 @@
# Config for single device DoRA finetuning with lora_finetune_single_device.py
# using a Llama3 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
#   tune run lora_finetune_single_device --config llama3/8B_dora_single_device
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training,
# you can run:
#   tune run lora_finetune_single_device --config llama3/8B_dora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Model Arguments
model:
  _component_: torchtune.models.llama3.dora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16
  use_dora: True

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/tokenizer.model
Contributor:

No rush on this, but before merging make sure to change the paths to /tmp/, the metric logger to DiskLogger, etc.

Author:

I have updated the logging to DiskLogger; I will update this path to /tmp/ before the merge.


checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: 12345678
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False

# Logging
output_dir: /tmp/dora_finetune_output/
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
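As a quick orientation, the sketch below shows how the `model` section of this config could be instantiated on its own (a minimal sketch, assuming torchtune's standard `config.instantiate` and OmegaConf machinery; the YAML path is the one this PR adds):

```python
from omegaconf import OmegaConf
from torchtune import config

# Load the YAML above and build the DoRA-enabled model from its `model` section.
cfg = OmegaConf.load("recipes/configs/llama3/8B_dora_single_device.yaml")
model = config.instantiate(cfg.model)  # resolves _component_ to torchtune.models.llama3.dora_llama3_8b
```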
8 changes: 8 additions & 0 deletions recipes/lora_finetune_single_device.py
@@ -19,6 +19,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.modules.peft.peft_utils import (
    activate_dora_params,
    get_adapter_params,
    get_merged_lora_ckpt,
    set_trainable_params,
@@ -256,6 +257,7 @@ def _setup_model(

        self._lora_rank = cfg_model.lora_rank
        self._lora_alpha = cfg_model.lora_alpha

        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

@@ -274,6 +276,12 @@ def _setup_model(
        else:
            lora_missing, lora_unexpected = None, None

        if cfg_model.get("use_dora", False):
            # Magnitude vectors for DoRA are initialized as ones. Once the base
            # weights are loaded, they are replaced by the norm of the linear
            # weights. See https://arxiv.org/pdf/2402.09353 for details.
            activate_dora_params(model)

        validate_missing_and_unexpected_for_lora(
            lora_attn_modules=cfg_model.lora_attn_modules,
            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
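For context on the hunk above: a rough sketch of what `activate_dora_params` is expected to do once the checkpoint weights are loaded (illustrative only; the attribute names `use_dora`, `lora_a`, `lora_b`, `magnitude`, `alpha`, and `rank` are assumptions, and the real helper in `torchtune.modules.peft.peft_utils` may differ):

```python
import torch

@torch.no_grad()
def activate_dora_params_sketch(model: torch.nn.Module) -> None:
    """Replace the ones-initialized DoRA magnitude with the norm of the loaded weights."""
    for module in model.modules():
        if getattr(module, "use_dora", False):
            scaling = module.alpha / module.rank
            # Norm of the merged weight (base + scaled low-rank update), one entry per output feature.
            merged = module.weight + scaling * (module.lora_b.weight @ module.lora_a.weight)
            module.magnitude.data = merged.norm(p=2, dim=1)
```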
19 changes: 19 additions & 0 deletions tests/torchtune/modules/peft/test_lora.py
@@ -60,6 +60,19 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear:
        fixed_init_model(lora_linear)
        return lora_linear

    @pytest.fixture
    def dora_linear(self, in_dim, out_dim) -> LoRALinear:
        lora_linear = LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=RANK,
            alpha=ALPHA,
            use_bias=False,
            use_dora=True,
        )
        fixed_init_model(lora_linear)
        return lora_linear

    @pytest.fixture
    def qlora_linear(self, in_dim, out_dim) -> LoRALinear:
        with utils.set_default_dtype(torch.bfloat16):
@@ -97,6 +110,12 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None:
        assert actual.shape == (BSZ, SEQ_LEN, out_dim)
        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

    def test_dora_forward(self, inputs, dora_linear, out_dim) -> None:
        expected = torch.tensor(EXPECTED_VAL)
Member:

Does this mean the expected value for DoRA is the same as for LoRA? Why is that? Intuitively, if the result is exactly the same, I don't see how DoRA results in different training than LoRA. I'm probably missing something basic here, but it would be good to clarify.

Author:

Here is the explanation:

$$\mathrm{DoRA}(x) = Wx + \left(\frac{m}{\text{weight\_norm}} - 1\right) Wx + \frac{m}{\text{weight\_norm}} \cdot \mathrm{lora}_b(\mathrm{lora}_a(x)) \cdot \text{scaling}$$

The magnitude vector is initialized so that self.m == weight_norm, so the ratio $\frac{m}{\text{weight\_norm}}$ is 1 on the first forward pass, and LoRA == DoRA for that pass. Training then updates m, which is where DoRA starts to diverge from LoRA.

Better explanation from the author: huggingface/peft#1474 (comment)
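To make the initialization argument above concrete, here is a small hand-rolled numerical check (a sketch only, not the torchtune `LoRALinear` implementation; all tensor names and shapes are assumptions):

```python
import torch

torch.manual_seed(0)
in_dim, out_dim, rank, alpha = 16, 8, 4, 8.0
scaling = alpha / rank

W = torch.randn(out_dim, in_dim)            # frozen base weight
lora_a = torch.randn(rank, in_dim) * 0.01   # LoRA "A"
lora_b = torch.randn(out_dim, rank) * 0.01  # LoRA "B" (nonzero, to show the point holds in general)
x = torch.randn(2, in_dim)

# Plain LoRA forward: base output plus scaled low-rank update.
lora_out = x @ W.T + scaling * (x @ lora_a.T @ lora_b.T)

# DoRA forward: scale the merged weight by m / weight_norm, where the magnitude m
# is initialized to exactly that norm.
merged = W + scaling * (lora_b @ lora_a)
weight_norm = merged.norm(p=2, dim=1)   # one norm per output feature
m = weight_norm.clone()                 # initialization: m == weight_norm
dora_out = (m / weight_norm) * (x @ merged.T)

print(torch.allclose(lora_out, dora_out, atol=1e-6))  # True: identical on the first pass
```

Because m is trainable, the ratio m / weight_norm drifts away from 1 after the first optimizer step, which is where DoRA's behavior departs from LoRA's.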

        actual = dora_linear(inputs)
        assert actual.shape == (BSZ, SEQ_LEN, out_dim)
        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

    def test_lora_weight_nf4_when_quantized(self, qlora_linear):
        assert isinstance(qlora_linear.weight, NF4Tensor)

4 changes: 4 additions & 0 deletions torchtune/_recipe_registry.py
@@ -106,6 +106,10 @@ class Recipe:
name="gemma/2B_qlora_single_device",
file_path="gemma/2B_qlora_single_device.yaml",
),
Config(
name="llama3/8B_dora_single_device",
file_path="llama3/8B_dora_single_device.yaml",
),
],
supports_distributed=False,
),
2 changes: 2 additions & 0 deletions torchtune/models/llama3/__init__.py
@@ -7,6 +7,7 @@
from ._component_builders import llama3, lora_llama3

from ._model_builders import (  # noqa
    dora_llama3_8b,
    llama3_70b,
    llama3_8b,
    llama3_tokenizer,
@@ -26,4 +27,5 @@
"lora_llama3_70b",
"qlora_llama3_8b",
"scale_hidden_dim_for_mlp",
"dora_llama3_8b",
]
17 changes: 16 additions & 1 deletion torchtune/models/llama3/_component_builders.py
@@ -150,6 +150,8 @@ def lora_llama3(
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    # DoRA args
    use_dora: bool = False,
    # Quantization args
    quantize_base: bool = False,
) -> TransformerDecoder:
@@ -183,6 +185,7 @@ def lora_llama3(
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        use_dora (bool): Whether to use DoRA. Default is ``False``.
        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear projection is not
            supported for quantization currently.
@@ -204,6 +207,7 @@
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )

@@ -214,6 +218,7 @@
            hidden_dim=hidden_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            use_dora=use_dora,
            quantize_base=quantize_base,
        )
    else:
@@ -230,7 +235,7 @@

    # TODO: quantize_base is not applied to final output_proj currently.
    output_proj = (
-       LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha)
+       LoRALinear(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, use_dora=use_dora)
        if apply_lora_to_output
        else nn.Linear(embed_dim, vocab_size, bias=False)
    )
@@ -270,6 +275,7 @@ def lora_llama3_self_attention(
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
    use_dora: bool = False,
) -> CausalSelfAttention:
    """
    Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA
@@ -294,6 +300,7 @@
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Default is ``False``.
        use_dora (bool): Whether to use DoRA. Default is ``False``.

    Returns:
        CausalSelfAttention: instantiation of self-attention module with LoRA
@@ -316,6 +323,7 @@
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
            use_dora=use_dora,
        )
        if "q_proj" in lora_modules
        else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
@@ -327,6 +335,7 @@
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
            use_dora=use_dora,
        )
        if "k_proj" in lora_modules
        else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
@@ -338,6 +347,7 @@
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
            use_dora=use_dora,
        )
        if "v_proj" in lora_modules
        else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
@@ -349,6 +359,7 @@
            rank=lora_rank,
            alpha=lora_alpha,
            quantize_base=quantize_base,
            use_dora=use_dora,
        )
        if "output_proj" in lora_modules
        else nn.Linear(embed_dim, embed_dim, bias=False)
@@ -377,6 +388,7 @@ def lora_llama3_mlp(
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> FeedForward:
    gate_proj = LoRALinear(
@@ -386,6 +398,7 @@
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
        use_dora=use_dora,
    )
    down_proj = LoRALinear(
        in_dim=hidden_dim,
@@ -394,6 +407,7 @@
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
        use_dora=use_dora,
    )
    up_proj = LoRALinear(
        in_dim=dim,
@@ -402,6 +416,7 @@
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
        use_dora=use_dora,
    )
    return FeedForward(
        gate_proj=gate_proj,
12 changes: 12 additions & 0 deletions torchtune/models/llama3/_model_builders.py
@@ -77,6 +77,7 @@ def lora_llama3_8b(
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    use_dora: bool = False,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
@@ -96,6 +97,7 @@
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        use_dora (bool): Whether to use DoRA. Default is ``False``.
        quantize_base (bool): Whether to quantize base model weights

    Returns:
@@ -118,6 +120,7 @@
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.05,
        use_dora=use_dora,
        quantize_base=quantize_base,
    )

@@ -180,3 +183,12 @@ def lora_llama3_70b(
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama3_8b` for full API arguments.
"""

dora_llama3_8b = partial(lora_llama3_8b, use_dora=True)

dora_llama3_8b.__doc__ = """
Builder for creating a Llama3 8B model with DoRA enabled. Weights of linear layers that DoRA
is applied to are decomposed into magnitude and direction components per the DoRA paper:
https://arxiv.org/abs/2402.09353. In addition to the LoRA adapter weights, DoRA adds a
trainable magnitude parameter for each adapted linear layer.
Please see `lora_llama3_8b` for full API arguments.
"""