LoftQ: edit README.md and example files #1276

Merged (5 commits) on Dec 17, 2023
examples/loftq_finetuning/README.md (96 additions, 25 deletions)

@@ -2,58 +2,115 @@

## Introduction

LoftQ finds a quantized LoRA initialization, i.e., a quantized backbone Q and LoRA adapters A and B, given a pre-trained weight W, so that Q plus the low-rank product AB approximates W more closely than direct quantization alone.
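
For intuition, the initialization alternates between quantizing the current residual and re-fitting the adapters. Below is a minimal, self-contained sketch of that loop (Algorithm 1 in the [LoftQ paper](https://arxiv.org/abs/2310.08659)), with a toy round-to-nearest quantizer standing in for the real NF4 quantizer; the function names here are illustrative, not the `peft` API.

```python
import torch

def toy_quantize(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    # Symmetric round-to-nearest quantization; a stand-in for NormalFloat quantization.
    levels = 2 ** (num_bits - 1) - 1
    scale = w.abs().max() / levels
    return torch.round(w / scale).clamp(-levels, levels) * scale

def loftq_sketch(w: torch.Tensor, rank: int = 16, num_bits: int = 4, num_iter: int = 5):
    # Start with empty adapters, then alternate:
    #   (1) quantize the part of W the adapters do not yet explain,
    #   (2) re-fit A, B to the remaining quantization error via truncated SVD.
    a = torch.zeros(w.shape[0], rank)
    b = torch.zeros(rank, w.shape[1])
    for _ in range(num_iter):
        q = toy_quantize(w - a @ b, num_bits)
        u, s, vh = torch.linalg.svd(w - q, full_matrices=False)
        a = u[:, :rank] * s[:rank]
        b = vh[:rank, :]
    return q, a, b  # Q + A @ B approximates W

w = torch.randn(256, 512)
q, a, b = loftq_sketch(w)
print(f"relative error of Q + AB: {(w - (q + a @ b)).norm() / w.norm():.4f}")
```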

## Quick Start
Steps:

1. Apply LoftQ to a full-precision pre-trained weight and save the result.
2. Load the LoftQ initialization and train.

For step 1, we provide off-the-shelf LoftQ initializations (see the [supported model list](#appendix-off-the-shelf-model-list))
on the [Hugging Face Hub (LoftQ)](https://huggingface.co/LoftQ).
If you want to do it yourself, jump to [LoftQ DIY](#loftq-diy).

For step 2, below is an example of loading a 4-bit Mistral-7B model with rank-64 LoRA adapters from the Hugging Face Hub.
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "LoftQ/Mistral-7B-v0.1-4bit-64rank"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # you may change it with different models
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 is recommended
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_ID,
    subfolder="loftq_init",
    is_trainable=True,
)

# Do training with peft_model ...
```
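
The training itself is standard PEFT/LoRA fine-tuning. Continuing from the snippet above (it reuses `MODEL_ID` and `peft_model`), here is one possible, hedged way to run it with the `transformers` `Trainer`; the dataset, sequence length, and hyper-parameters are placeholders, not part of this example.

```python
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token  # Mistral/LLaMA tokenizers have no pad token by default

# Tiny illustrative corpus; replace with your own dataset.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
raw = raw.filter(lambda ex: len(ex["text"].strip()) > 0)
tokenized = raw.map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=raw.column_names,
)

trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(
        output_dir="exp_results/loftq_demo",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=3e-4,
        num_train_epochs=1,
        bf16=True,
        logging_steps=10,
    ),
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```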

## LoftQ DIY

### Apply LoftQ and save
We provide [quantize_save_load.py](quantize_save_load.py) as an example of applying LoftQ with
different bit widths (`--bits`), ranks (`--rank`), and numbers of alternating steps (`--iter`, a hyper-parameter of LoftQ; see Algorithm 1 in the [LoftQ paper](https://arxiv.org/abs/2310.08659)). Currently, this example supports
`llama-2`, `falcon`, `mistral`, `bart`, `t5`, `deberta`, `bert`, and `roberta`.

Below is an example of obtaining a 4-bit LLaMA-2-7b backbone with rank-16 LoRA adapters using 5 alternating steps.
```sh
SAVE_DIR="model_zoo/loftq/"
python quantize_save_load.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --token HF_TOKEN \
    --bits 4 \
    --iter 5 \
    --rank 16 \
    --save_dir $SAVE_DIR
```

- `--model_name_or_path`: the high-precision model id on the Hugging Face Hub.
- `--token` (`HF_TOKEN`): your Hugging Face token, required for gated models such as the [LLAMA models](https://huggingface.co/meta-llama).

Inside the script, `quantize_and_save()` quantizes the backbone, initializes the LoRA adapters, and creates the model directory under `$SAVE_DIR`.
Specifically, the model directory is named as

`MODEL_DIR = SAVE_DIR + f"{args.model_name_or_path.split('/')[-1]}-{args.bits}bit-{args.rank}rank"`

In this example, `MODEL_DIR="model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"`. The backbone is stored in `$MODEL_DIR`
(saved as fp16/fp32 and meant to be loaded as [NormalFloat4](https://arxiv.org/abs/2305.14314), as in the loading snippets),
and the LoRA adapters are in the sub-folder `$MODEL_DIR/loftq_init`.
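
For reference, here is a rough sketch of the same step written directly against the PEFT API. The target modules, `lora_alpha`, and save paths below are illustrative assumptions; the example script adds model-type dispatch and backbone-saving logic on top of this.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

MODEL_NAME = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=5)  # corresponds to --bits / --iter
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                                                # corresponds to --rank
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed LLaMA-style module names
    init_lora_weights="loftq",
    loftq_config=loftq_config,
)
peft_model = get_peft_model(model, lora_config)          # runs the LoftQ initialization

# Save the LoftQ-initialized LoRA adapters; the example script additionally
# unwraps and saves the LoftQ-adjusted backbone next to them.
peft_model.save_pretrained("model_zoo/loftq/Llama-2-7b-hf-4bit-16rank/loftq_init")
tokenizer.save_pretrained("model_zoo/loftq/Llama-2-7b-hf-4bit-16rank")
```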

### Load and train
Similar to loading from the Hugging Face Hub, we only need to change `MODEL_ID` to `MODEL_DIR`.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_DIR = "model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    subfolder="loftq_init",
    is_trainable=True,
)
# Do training with peft_model ...
```
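
After training, only the LoRA adapters need to be saved; they can later be re-attached to the same quantized backbone with `PeftModel.from_pretrained`, exactly as above. A minimal sketch (the output path is illustrative):

```python
# Saves only the trained LoRA adapter weights and their config, not the backbone.
peft_model.save_pretrained("exp_results/llama-2-7b-loftq-adapters")
```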

## LoftQ Fine-tuning

We also provide an example of fine-tuning a LoftQ-initialized model on GSM8K.
It loads the quantized backbone and LoRA adapters from the [LoftQ Hugging Face Hub](https://huggingface.co/LoftQ).

```sh
python train_gsm8k_llama.py \
    --model_name_or_path LoftQ/Llama-2-13b-hf-4bit-64rank \
    --output_dir exp_results/gsm8k/llama-2-13b/bit4-rank64/lr1e-4 \
    --learning_rate 1e-4 \
    --weight_decay 0.1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 100 \
    --seed 202 \
    --dataset_name gsm8k \
    --dataset_config main \
    ...                    # remaining flags unchanged (collapsed in the diff view)
    --with_tracking \
    --report_to tensorboard
```


## Appendix: Off-the-shelf Model List
| Model Name | Bits | Ranks |
| ----------- | ---- | ----- |
| LLAMA-2-7b | 4 | 64 |
| LLAMA-2-13b | 4 | 64 |
| LLAMA-2-70b | 4 | 64 |
| Mistral | 4 | 64 |
| Mistral | 4 | 32 |
| BART-large | 4 | 8 |
| BART-large | 4 | 16 |
| BART-large | 4 | 32 |
| BART-large | 2 | 8 |
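
The exact Hub repository ids for these entries are not listed in the table. Judging from the ids used elsewhere in this README (e.g. `LoftQ/Llama-2-7b-hf-4bit-64rank`, `LoftQ/Mistral-7B-v0.1-4bit-64rank`), they appear to follow a `LoftQ/<base-model-name>-<bits>bit-<rank>rank` pattern, but please verify against https://huggingface.co/LoftQ before use:

```python
def loftq_hub_id(base_model_name: str, bits: int, rank: int) -> str:
    # Assumed naming pattern, inferred from the examples above; not an official API.
    return f"LoftQ/{base_model_name}-{bits}bit-{rank}rank"

print(loftq_hub_id("Llama-2-7b-hf", 4, 64))     # LoftQ/Llama-2-7b-hf-4bit-64rank
print(loftq_hub_id("Mistral-7B-v0.1", 4, 32))   # LoftQ/Mistral-7B-v0.1-4bit-32rank
```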
examples/loftq_finetuning/quantize_save_load.py (1 addition, 48 deletions)

@@ -23,10 +23,9 @@
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model


class Shell(nn.Module):
@@ -184,54 +183,8 @@ def quantize_and_save():
    return base_model_dir, lora_model_dir


Review thread on the removal of `load_loftq`:
- Contributor: Why has this been removed?
- Contributor: I think this function was supposed to confirm that everything works fine after the LoftQ weight initialization step. Hence, it is not a required step.
- Contributor: Perfect!

def load_loftq(base_model_path, lora_adapter_path):
    if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            load_in_4bit=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            low_cpu_mem_usage=True,
            load_in_4bit=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    else:
        raise NotImplementedError("Other models not supported yet.")

    lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)

    # Do training or inference below
    print_model(lora_model, "lora_model")
    print_model(model, "base_model")


if __name__ == "__main__":
    base_dir, lora_dir = quantize_and_save()
    load_loftq(base_dir, lora_dir)

# example command:
# python quantize_save_load.py \
#     ...
examples/loftq_finetuning/train_gsm8k_llama.py (2 additions, 21 deletions)

@@ -697,7 +697,8 @@ def preprocess_function_test(examples):
if args.with_tracking:
    total_loss += loss.detach().float()
accelerator.backward(loss)
accelerator.print(f"Epoch: {epoch} | Step: {step} | Loss: {loss}")
if completed_steps % 50 == 0:
    accelerator.print(f"Epoch: {epoch} | Step: {completed_steps} | Loss: {loss}")
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -843,23 +844,3 @@ def compute_accuracy(pred: list, gold: list):

if __name__ == "__main__":
    main()

# example command

# python train_gsm8k_llama.py \
# --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-backbone \
# --adapter_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-adapters \
# --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
# --learning_rate 1e-4 \
# --seed 202 \
# --dataset_name gsm8k \
# --dataset_config main \
# --pad_to_max_length \
# --max_source_length 128 \
# --max_target_length 256 \
# --num_train_epochs 5 \
# --per_device_train_batch_size 4 \
# --per_device_eval_batch_size 4 \
# --gradient_accumulation_steps 4 \
# --with_tracking \
# --report_to tensorboard
src/peft/utils/loftq_utils.py (1 addition, 1 deletion)

@@ -198,7 +198,7 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} "
f"| Num Iter: {num_iter} | Num Bits: {num_bits}"
)
if not is_bnb_4bit_available():
if not is_bnb_4bit_available() or num_bits in [2, 8]:
quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
compute_device = device
else:
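
The added `or num_bits in [2, 8]` condition presumably exists because the bitsandbytes-backed path only implements 4-bit NormalFloat quantization, so 2-bit and 8-bit settings have to fall back to the simulated `NFQuantizer` even when bitsandbytes is available (an inference from the surrounding code, not something stated in the PR).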