blora LLaMA support vllm-project#1
l1cacheDell committed Nov 15, 2023
1 parent 9004314 commit 424df61
Showing 11 changed files with 470 additions and 7 deletions.
48 changes: 48 additions & 0 deletions examples/llama_test_lora.py
@@ -0,0 +1,48 @@
from vllm import LLM, SamplingParams
import time

if __name__ == "__main__":
    prompt = "Hello and welcome, "
    prompts = [prompt]
    # The second assignment of each variable overrides the Baichuan paths
    # with the LLaMA-7B weights and adapters actually used below.
    path = "./baichuan2-13b"
    path = "/vllm_workspace/weights/llama_7b_hf"
    lora_path = "./baichuan2-13b-20231013174626"
    lora_path = "/vllm_workspace/weights/alpaca-lora-7b"
    lora_path_2 = "./baichuan2-13b-20231013192059"
    lora_path_2 = "/vllm_workspace/weights/bactrian-x-llama-7b-lora"

    # Load the base model together with both LoRA adapters.
    llm = LLM(model=path,
              trust_remote_code=True,
              lora_paths=[lora_path, lora_path_2],
              adapter_names=["adapter_1", "adapter_2"])

    print(llm.llm_engine.workers[0].model)

    # One request per adapter, routed via SamplingParams.lora_id.
    sampling_params = SamplingParams(temperature=0,
                                     top_p=1,
                                     best_of=2,
                                     top_k=-1,
                                     max_tokens=100,
                                     use_beam_search=True,
                                     lora_id="adapter_1")
    llm._add_request(prompt=prompt,
                     prompt_token_ids=None,
                     sampling_params=sampling_params)

    sampling_params = SamplingParams(temperature=0,
                                     top_p=1,
                                     best_of=2,
                                     top_k=-1,
                                     max_tokens=100,
                                     use_beam_search=True,
                                     lora_id="adapter_2")
    llm._add_request(prompt=prompt,
                     prompt_token_ids=None,
                     sampling_params=sampling_params)

    start = time.time()
    outputs = llm._run_engine(use_tqdm=True)
    end = time.time()
    print(f"cost: {end - start}")

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
29 changes: 28 additions & 1 deletion vllm/engine/arg_utils.py
100644 → 100755
@@ -1,7 +1,7 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, List

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -32,6 +32,11 @@ class EngineArgs:
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None

    # MODIFY
    lora_paths: Optional[List[str]] = None
    adapter_names: Optional[List[str]] = None
    # END

    def __post_init__(self):
        if self.tokenizer is None:
@@ -171,6 +176,28 @@ def add_cli_args(
                            choices=['awq', 'squeezellm', None],
                            default=None,
                            help='Method used to quantize the weights')

        # MODIFY
        parser.add_argument(
            '--lora-paths',
            metavar='path',
            type=str,
            default=None,
            nargs='+',
            help='the paths of the LoRA adapters to load: '
                 '[lora_path1 lora_path2 ...]')

        parser.add_argument(
            '--adapter-names',
            metavar='adapter_name',
            type=str,
            default=None,
            nargs='+',
            help='the adapter names of the LoRA adapters to load; each name '
                 'must be unique and correspond one-to-one with --lora-paths: '
                 '[name1 name2 ...]')
        # END

        return parser

    @classmethod
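For orientation, a minimal sketch of how the new flags could be exercised once this change is applied. The api_server invocation and weight paths are illustrative only; the programmatic part relies just on the add_cli_args / from_cli_args round trip that already lives in arg_utils.py.

# Hypothetical CLI invocation (entrypoint and paths are placeholders):
#   python -m vllm.entrypoints.api_server \
#       --model /vllm_workspace/weights/llama_7b_hf \
#       --lora-paths /vllm_workspace/weights/alpaca-lora-7b \
#                    /vllm_workspace/weights/bactrian-x-llama-7b-lora \
#       --adapter-names adapter_1 adapter_2
#
# The same arguments parsed back into EngineArgs:
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "/vllm_workspace/weights/llama_7b_hf",
    "--lora-paths", "/vllm_workspace/weights/alpaca-lora-7b",
    "/vllm_workspace/weights/bactrian-x-llama-7b-lora",
    "--adapter-names", "adapter_1", "adapter_2",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.lora_paths, engine_args.adapter_names)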
17 changes: 16 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -490,6 +490,20 @@ def from_engine_args(cls,
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config, engine_args.engine_use_ray)

        # =====================
        # MODIFY HERE
        lora_paths: list = engine_args.lora_paths
        adapter_names: list = engine_args.adapter_names
        lora_configs = None
        if lora_paths is not None and adapter_names is not None:
            assert len(lora_paths) == len(adapter_names), (len(lora_paths),
                                                           len(adapter_names))
            lora_configs = []
            for lora_path, adapter_name in zip(lora_paths, adapter_names):
                lora_configs.append((lora_path, adapter_name))
        # =====================

        # Create the async LLM engine.
        engine = cls(parallel_config.worker_use_ray,
                     engine_args.engine_use_ray,
@@ -499,5 +513,6 @@
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats,
                     max_log_len=engine_args.max_log_len,
                     start_engine_loop=start_engine_loop)
                     start_engine_loop=start_engine_loop,
                     lora_configs=lora_configs)  # MODIFY HERE
        return engine
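A matching sketch for the async path, assuming AsyncEngineArgs (which subclasses EngineArgs) inherits the new lora_paths and adapter_names fields without further changes; the paths are the same illustrative ones as above.

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="/vllm_workspace/weights/llama_7b_hf",
    lora_paths=["/vllm_workspace/weights/alpaca-lora-7b",
                "/vllm_workspace/weights/bactrian-x-llama-7b-lora"],
    adapter_names=["adapter_1", "adapter_2"],
)
engine = AsyncLLMEngine.from_engine_args(engine_args)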
21 changes: 20 additions & 1 deletion vllm/engine/llm_engine.py
100644 → 100755
@@ -68,6 +68,7 @@ def __init__(
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
        lora_configs: List[Tuple[str, str]] = None  # MODIFY
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
@@ -93,6 +94,7 @@ def __init__(
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self.lora_configs = lora_configs  # MODIFY
        self._verify_args()

        self.tokenizer = get_tokenizer(
@@ -137,6 +139,7 @@ def _init_workers(self, distributed_init_method: str):
            self.scheduler_config,
            0,
            distributed_init_method,
            self.lora_configs,  # MODIFY
        )
        self.workers.append(worker)
        self._run_workers(
@@ -169,6 +172,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        lora_configs = copy.deepcopy(self.lora_configs)  # MODIFY
        self._run_workers("init_worker",
                          get_all_outputs=True,
                          worker_init_fn=lambda: Worker(
@@ -177,6 +181,7 @@
                              scheduler_config,
                              None,
                              None,
                              lora_configs,  # MODIFY
                          ))
        self._run_workers(
            "init_model",
@@ -227,11 +232,25 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config)

        # MODIFY
        lora_paths: list = engine_args.lora_paths
        adapter_names: list = engine_args.adapter_names
        lora_configs = None
        if lora_paths is not None and adapter_names is not None:
            assert len(lora_paths) == len(adapter_names), (len(lora_paths),
                                                           len(adapter_names))
            lora_configs = []
            for lora_path, adapter_name in zip(lora_paths, adapter_names):
                lora_configs.append((lora_path, adapter_name))
        # END

        # Create the LLM engine.
        engine = cls(*engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_stats=not engine_args.disable_log_stats)
                     log_stats=not engine_args.disable_log_stats,
                     lora_configs=lora_configs)  # MODIFY
        return engine

    def add_request(
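The pairing block above (and its duplicate in async_llm_engine.py) simply zips the two CLI lists into (lora_path, adapter_name) tuples. A tiny worked example of the value that ends up in lora_configs, using the illustrative paths from the example script:

lora_paths = ["/vllm_workspace/weights/alpaca-lora-7b",
              "/vllm_workspace/weights/bactrian-x-llama-7b-lora"]
adapter_names = ["adapter_1", "adapter_2"]
lora_configs = [(p, n) for p, n in zip(lora_paths, adapter_names)]
# [('/vllm_workspace/weights/alpaca-lora-7b', 'adapter_1'),
#  ('/vllm_workspace/weights/bactrian-x-llama-7b-lora', 'adapter_2')]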
4 changes: 4 additions & 0 deletions vllm/entrypoints/llm.py
100644 → 100755
@@ -71,6 +71,8 @@ def __init__(
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        lora_paths: List[str] = None,
        adapter_names: List[str] = None,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
@@ -88,6 +90,8 @@ def __init__(
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            lora_paths=lora_paths,  # MODIFY
            adapter_names=adapter_names,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
134 changes: 134 additions & 0 deletions vllm/model_executor/lora_utils.py
@@ -0,0 +1,134 @@
from vllm.model_executor.parallel_utils.layers import (BLoraColumnParallelLinear,
                                                        BLoraRowParallelLinear,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
from peft.tuners.lora import LoraLayer
from peft import LoraConfig
import re
import torch

WEIGHTS_NAME = "adapter_model.bin"
PREFIX = "base_model.model."
PARAMETER_PREFIX = "lora_"


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name


def _create_new_module(lora_config, adapter_name, target):
    lora_alpha = lora_config.lora_alpha
    r = lora_config.r
    lora_dropout = lora_config.lora_dropout
    if isinstance(target, ColumnParallelLinear):
        new_module = BLoraColumnParallelLinear(
            input_size=target.input_size,
            output_size=target.output_size_per_partition,
            adapter_name=adapter_name,
            bias=target.bias,
            gather_output=target.gather_output,
            skip_bias_add=target.skip_bias_add,
            quant_config=target.quant_config,
            lora_alpha=lora_alpha,
            r=r,
            lora_dropout=lora_dropout)
        return new_module
    if isinstance(target, RowParallelLinear):
        new_module = BLoraRowParallelLinear(
            input_size=target.input_size_per_partition,
            output_size=target.output_size,
            adapter_name=adapter_name,
            bias=target.bias,
            input_is_parallel=target.input_is_parallel,
            reduce_results=target.reduce_results,
            skip_bias_add=target.skip_bias_add,
            quant_config=target.quant_config,
            lora_alpha=lora_alpha,
            r=r,
            lora_dropout=lora_dropout)
        return new_module


def _replace_module(parent, child_name, new_module, child):
    setattr(parent, child_name, new_module)
    new_module.weight = child.weight
    if getattr(child, "state", None) is not None:
        new_module.state = child.state
        new_module.to(child.weight.device)
    # dispatch to correct device
    for name, module in new_module.named_modules():
        if "lora_" in name:
            module.to(child.weight.device)


def _create_and_replace(lora_config, adapter_name, target, target_name,
                        parent):
    if (isinstance(target, (ColumnParallelLinear, RowParallelLinear))
            and not isinstance(target, LoraLayer)):
        new_module = _create_new_module(lora_config, adapter_name, target)
        _replace_module(parent, target_name, new_module, target)
    elif isinstance(target, LoraLayer):
        target.update_layer(adapter_name, lora_config.r,
                            lora_config.lora_alpha, lora_config.lora_dropout,
                            lora_config.init_lora_weights)


def add_lora_adapter(model: torch.nn.Module,
                     lora_path: str,
                     adapter_name: str):
    lora_config = LoraConfig.from_pretrained(lora_path,
                                             revision=None,
                                             use_auth_token=None)
    key_list = [key for key, _ in model.named_modules()]

    # iterate the modules of LLaMA to insert the LoRA adapter

    # TODO: we should re-construct the logic from here to fit LLaMA LoRA

    for key in key_list:
        # find target module
        target_module_found = any(
            re.match(f".*\\.{target_key}$", key)
            for target_key in lora_config.target_modules) or any(
                target_key == key for target_key in lora_config.target_modules)
        if not target_module_found:
            continue
        parent, target, target_name = _get_submodules(model, key)
        print(f"parent: {parent}, ")

        # create and replace
        _create_and_replace(lora_config, adapter_name, target, target_name,
                            parent)

    adapters_weights = torch.load(f"{lora_path}/{WEIGHTS_NAME}")

    # strip the PEFT "base_model.model." prefix from the checkpoint keys
    processed_adapter_state_dict = {}
    for key, value in adapters_weights.items():
        if key.startswith(PREFIX):
            new_key = key[len(PREFIX):]
        else:
            new_key = key
        processed_adapter_state_dict[new_key] = value

    # namespace every LoRA parameter with the adapter name so that several
    # adapters can coexist inside the same module
    state_dict = {}
    for k, v in processed_adapter_state_dict.items():
        if PARAMETER_PREFIX in k:
            suffix = k.split(PARAMETER_PREFIX)[1]
            if "." in suffix:
                to_replace = ".".join(suffix.split(".")[1:])
                k = k.replace(to_replace, f"{adapter_name}.{to_replace}")
            else:
                k = f"{k}.{adapter_name}"
        state_dict[k] = v

    # print("====== LORA ======")
    # for name in state_dict.keys():
    #     print(name)

    # print("====== MODEL ======")
    # for name in model.state_dict().keys():
    #     print(name)

    model.load_lora_weights_parallel(state_dict)
    model.cuda()
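To make the renaming loop at the end of add_lora_adapter concrete, here is a small illustration with a hypothetical PEFT-style key; the key name is not taken from this diff, only the transformation is.

PARAMETER_PREFIX = "lora_"
adapter_name = "adapter_1"

key = "model.layers.0.self_attn.q_proj.lora_A.weight"  # hypothetical key
suffix = key.split(PARAMETER_PREFIX)[1]                # "A.weight"
to_replace = ".".join(suffix.split(".")[1:])           # "weight"
print(key.replace(to_replace, f"{adapter_name}.{to_replace}"))
# model.layers.0.self_attn.q_proj.lora_A.adapter_1.weight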
22 changes: 20 additions & 2 deletions vllm/model_executor/model_loader.py
100644 → 100755
@@ -1,6 +1,6 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Type
from typing import Type, List, Tuple

import torch
import torch.nn as nn
@@ -11,6 +11,8 @@
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)

from vllm.model_executor.lora_utils import add_lora_adapter  # MODIFY

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
@@ -64,7 +66,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig) -> nn.Module:
def get_model(model_config: ModelConfig,
lora_configs: List[Tuple[str, str]] = None) -> nn.Module: # MODIFY
model_class = _get_model_architecture(model_config.hf_config)

# Get the quantization config.
@@ -108,4 +111,19 @@ def get_model(model_config: ModelConfig) -> nn.Module:
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format, model_config.revision)
            model = model.cuda()

    # print("====== MODEL ======")
    # for name in model.state_dict().keys():
    #     print(name)

    # MODIFY
    # load lora adapter
    if lora_configs is not None:
        for lora_config in lora_configs:
            lora_path = lora_config[0]
            adapter_name = lora_config[1]
            add_lora_adapter(model=model,
                             lora_path=lora_path,
                             adapter_name=adapter_name)
    # END
    return model.eval()
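The Worker-side wiring that passes lora_configs into get_model lives in one of the changed files not shown above. Purely as a sketch of the expected call shape, a hypothetical helper (not part of this diff) could look like this:

from typing import List, Tuple

from vllm.config import ModelConfig
from vllm.model_executor.model_loader import get_model


def load_base_model_with_adapters(model_config: ModelConfig,
                                  lora_configs: List[Tuple[str, str]]):
    # get_model loads the base weights first, then injects every
    # (lora_path, adapter_name) pair via add_lora_adapter.
    return get_model(model_config, lora_configs=lora_configs)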