Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LoftQ: Add LoftQ method integrated into LoRA. Add example code for LoftQ usage. #1150

Merged
merged 12 commits into from
Nov 29, 2023
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Supported methods:
7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659)

## Getting started

Expand Down
68 changes: 68 additions & 0 deletions examples/loftq_finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# LoftQ: LoRA-fine-tuning-aware Quantization

## Introduction

LoftQ provides better initialization for LoRA adaptors A and B,
yxli2123 marked this conversation as resolved.
Show resolved Hide resolved
and the Quantization of pre-trained weights W.

## Quantization
We recommend to save the quantized backbone model as fp16/fp32
and load it as [NormalFloat4](https://arxiv.org/abs/2305.14314).

We provide a simple example to show how to quantize llama-2-7b model and save/load it.

```sh
python quantize_save_load.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--token HF_TOKEN \
--bits 4 --iter 5 --rank 16 \
--save_dir model_zoo/loftq/
```

- `HF_TOKEN` is the token used to access to [LLAMA models](https://huggingface.co/meta-llama).
- `quantize_and_save()` function will quantize the backbone and initialize LoRA adapters.
It creates 2 folders under `$save_dir`. The quantized backbone is at `Llama-2-7b-hf-4bit-16rank-backbone`,
and the LoRA adapters are at `Llama-2-7b-hf-4bit-16rank-adapters`.

## Fine-tuning

Here is an example to load the quantized backbone and LoRA adapters:

```python
import os

from transformers import AutoModelForCausalLM
from peft import PeftModel


base_model = AutoModelForCausalLM.from_pretrained(os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-backbone"),
load_in_4bit=True,
)
peft_model = PeftModel.from_pretrained(base_model,
os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-adapters",
is_trainable=True),
)
yxli2123 marked this conversation as resolved.
Show resolved Hide resolved
```

We also provide an example to fine-tune LoftQ on GSM8K.
We load the quantized backbone and LoRA adapters from the [LoftQ Huggingface hub](https://huggingface.co/LoftQ).

```sh
python train_gsm8k_llama.py \
--model_name_or_path LoftQ/Llama-2-7b-hf-6bit-64rank-backbone \
--adapter_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank-adapters \
--output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
--learning_rate 3e-4 \
--seed 202 \
--dataset_name gsm8k \
--dataset_config main \
--pad_to_max_length \
--max_source_length 128 \
--max_target_length 256 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--with_tracking \
--report_to tensorboard
```
198 changes: 198 additions & 0 deletions examples/loftq_finetuning/quantize_save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import argparse
import os

import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model


class Shell(nn.Module):
def __init__(self, weight, bias=None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = nn.Parameter(bias, requires_grad=False)


def unwarap_model(model, sub_module_name=".base_layer"):
sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
sub_module_name_set = set(sub_module_name_list)
for name in sub_module_name_set:
# get the parent of the submodule
name_parent = ".".join(name.split(".")[:-1])
name_child = name.split(".")[-1]
sub_module = model.get_submodule(name_parent)
print(sub_module)

# replace with shell
child = getattr(sub_module, name_child)
weight = getattr(child.base_layer, "weight", None)
bias = getattr(child.base_layer, "bias", None)
shell = Shell(weight, bias)

setattr(sub_module, name_child, shell)

print("You have unwrapped the model. Use it on your own risk.")


def print_model(model, name):
print("=" * 10 + name + "=" * 10)
print(model)
for name, param in model.named_parameters():
if torch.is_tensor(param):
if param.dtype in [torch.float32, torch.float16]:
print(
name,
param.shape,
param.device,
param.dtype,
param.requires_grad,
param.mean().item(),
param.max().item(),
)
else:
print(name, param.shape, param.device, param.dtype, param.requires_grad)


def arg_parse():
parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.")
parser.add_argument(
"--model_name_or_path",
type=str,
default=None,
required=True,
help="The name or path of the fp32/16 model.",
)
parser.add_argument(
"--token",
type=str,
default=None,
help="The access token to download model from HuggingFace Hub.",
)
parser.add_argument(
"--bits",
type=int,
default=4,
help="The quantized bits",
)
parser.add_argument(
"--iter",
type=int,
default=1,
help="The alternating steps in LoftQ",
)
parser.add_argument(
"--rank",
type=int,
default=16,
help="The rank of the LoRA adapter",
)
parser.add_argument(
"--save_dir",
type=str,
default="./model_zoo/loftq/",
help="The rank of the LoRA adapter",
)
args = parser.parse_args()
return args


def quantize_and_save():
args = arg_parse()

# Download weights and configure LoRA
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token)
if "llama" in args.model_name_or_path.lower():
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
task_type = TaskType.CAUSAL_LM
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]

elif "bart" in args.model_name_or_path.lower():
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
task_type = TaskType.SEQ_2_SEQ_LM
target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"]

elif "deberta" in args.model_name_or_path.lower():
model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token)
model = model.cuda()
task_type = TaskType.SEQ_CLS
target_modules = ["query_proj", "key_proj", "value_proj", "dense"] # embeddings not supported by peft
else:
raise NotImplementedError("Other models not supported yet.")

# Config of LoftQ
loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter, loftq_fake=True)

lora_config = LoraConfig(
task_type=task_type,
inference_mode=True,
r=args.rank,
lora_alpha=args.rank,
lora_dropout=0.1,
target_modules=target_modules,
init_lora_weights="loftq",
loftq_config=loftq_config,
)

# Obtain LoftQ model
lora_model = get_peft_model(model, lora_config)
base_model = lora_model.get_base_model()

# Save LoftQ model
model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank"
base_model_dir = os.path.join(args.save_dir, model_name + "-backbone")
lora_model_dir = os.path.join(args.save_dir, model_name + "-adapters")

# save lora adapters first
lora_model.base_model.peft_config[
"default"
].base_model_name_or_path = base_model_dir # This can be a local path or Hub model id
lora_model.base_model.peft_config["default"].init_lora_weights = True # Don't apply LoftQ when loading again

lora_model.save_pretrained(lora_model_dir)
print_model(lora_model, "lora_model")

# remove lora adapters and save the backbone
unwarap_model(base_model)
base_model.save_pretrained(base_model_dir)
tokenizer.save_pretrained(base_model_dir)

print_model(base_model, "base_model")

return base_model_dir, lora_model_dir


def load_loftq(base_model_path, lora_adapter_path):
if "llama" in base_model_path.lower():
model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True)
elif "bart" in base_model_path.lower():
model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True)
elif "deberta" in base_model_path.lower():
model = AutoModelForSequenceClassification.from_pretrained(base_model_path, load_in_4bit=True)
else:
raise NotImplementedError("Other models not supported yet.")

lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)

# Do training or inference below
print_model(lora_model, "lora_model")
print_model(model, "base_model")


if __name__ == "__main__":
base_dir, lora_dir = quantize_and_save()
load_loftq(base_dir, lora_dir)

# example command:
# python quantize_save_load.py \
# --model_name_or_path meta-llama/Llama-2-7b-hf \
# --token XXX \
# --bits 4 --iter 5 --rank 16 \
# --save_dir ./model_zoo/loftq/
Loading
Loading