[WIP] Addition of Dora #936
@@ -0,0 +1,86 @@
# Config for single device DoRA with lora_finetune_single_device.py
# using a Llama3 8b Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-8b-Instruct --output-dir /tmp/Meta-Llama-3-8b-Instruct --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama3/8b_dora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config llama3/8b_dora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

# Model Arguments
model:
  _component_: torchtune.models.llama3.dora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16
  use_dora: True

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: 12345678
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False

# Logging
output_dir: /tmp/dora_finetune_output/
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
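For reviewers of the config, here is a minimal conceptual sketch of what `use_dora: True` changes at the linear-layer level, in plain PyTorch. This is not the torchtune implementation; the names `weight`, `lora_a`, `lora_b`, `magnitude`, and `scaling` are illustrative assumptions. The idea is that DoRA re-parameterizes the adapted weight into a learned per-output magnitude times a normalized direction built from the frozen weight plus the scaled low-rank update.

```python
import torch
import torch.nn.functional as F

def dora_style_forward(x, weight, lora_a, lora_b, magnitude, scaling):
    # Hypothetical sketch, not torchtune's code.
    # Direction: frozen weight plus the scaled low-rank LoRA update.
    merged = weight + scaling * (lora_b @ lora_a)
    # Per-output-row norm of the merged weight, shape (out_dim, 1).
    norm = merged.norm(p=2, dim=1, keepdim=True)
    # DoRA scales the normalized direction by a learned magnitude vector.
    return F.linear(x, magnitude * merged / norm)
```

Under this reading, `lora_rank: 8` and `lora_alpha: 16` in the config control the rank and the usual `scaling = alpha / rank` factor of the low-rank update, while `use_dora: True` adds the learned magnitude on top of it.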
@@ -60,6 +60,19 @@ def lora_linear(self, in_dim, out_dim) -> LoRALinear:
        fixed_init_model(lora_linear)
        return lora_linear

    @pytest.fixture
    def dora_linear(self, in_dim, out_dim) -> LoRALinear:
        lora_linear = LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=RANK,
            alpha=ALPHA,
            use_bias=False,
            use_dora=True,
        )
        fixed_init_model(lora_linear)
        return lora_linear

    @pytest.fixture
    def qlora_linear(self, in_dim, out_dim) -> LoRALinear:
        with utils.set_default_dtype(torch.bfloat16):

@@ -97,6 +110,12 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None:
        assert actual.shape == (BSZ, SEQ_LEN, out_dim)
        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

    def test_dora_forward(self, inputs, dora_linear, out_dim) -> None:
        expected = torch.tensor(EXPECTED_VAL)
Review comment: Does this mean the expected value of DoRA is the same as LoRA? Why is this? Intuitively, if the result is exactly the same, I'm not sure I understand how DoRA results in different training than LoRA. Pretty sure I'm missing something basic here, but it would be good to clarify.

Reply: Here is the explanation: the m vector is initialized as self.m == weight norm, and the LoRA B matrix is zero at initialization, so W + BA == W and the ratio self.m / ||W + BA|| is exactly 1. The DoRA forward therefore matches the LoRA forward at initialization; the two only diverge once training updates m, A, and B (see the short numeric check after this diff). Better explanation from the author: huggingface/peft#1474 (comment)
        actual = dora_linear(inputs)
        assert actual.shape == (BSZ, SEQ_LEN, out_dim)
        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

    def test_lora_weight_nf4_when_quantized(self, qlora_linear):
        assert isinstance(qlora_linear.weight, NF4Tensor)
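To make the initialization argument in the review thread concrete, here is a small self-contained numeric check in plain PyTorch (not the torchtune code; all names are illustrative): with the LoRA B matrix zeroed and the magnitude vector initialized to the frozen weight's row norm, the DoRA rescaling factor is exactly 1, so DoRA and LoRA produce identical outputs at initialization.

```python
import torch

torch.manual_seed(0)
out_dim, in_dim, rank = 4, 6, 2
weight = torch.randn(out_dim, in_dim)          # frozen base weight W
lora_a = torch.randn(rank, in_dim)             # LoRA A (randomly initialized)
lora_b = torch.zeros(out_dim, rank)            # LoRA B is zero at init

merged = weight + lora_b @ lora_a              # equals W while B == 0
magnitude = weight.norm(p=2, dim=1, keepdim=True)          # DoRA init: m = ||W||
scale = magnitude / merged.norm(p=2, dim=1, keepdim=True)

# The per-row rescaling is exactly 1, so DoRA's forward matches LoRA's at init.
print(torch.allclose(scale, torch.ones_like(scale)))       # True
```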
Review comment: No rush on this, but before merging make sure to change the paths to /tmp/, the metric logger to DiskLogger, etc.

Reply: I have updated the logging to the DiskLogger; I will update this path to /tmp/ before the merge.