[Lora] Support long context lora #4787

Merged: 29 commits, May 18, 2024
Changes from 6 commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ exclude = [

[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"

[tool.isort]
use_parentheses = true
57 changes: 57 additions & 0 deletions tests/lora/conftest.py
@@ -21,6 +21,41 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

LONG_LORA_INFOS = [
Collaborator Author: Will be moved to the HF Hub.

    {
        "lora_id": 1,
        "context_length": "16k",
        "local_path": "/mnt/local_storage/long_context_checkpoint_16k",
        "lora": "s3://endpoints-finetune-mirrors/dev-test/long-context-test/model_1/lora/",
        "merged": "s3://endpoints-finetune-mirrors/dev-test/long-context-test/model_1/merged/"
    },
    {
        "lora_id": 2,
        "context_length": "16k",
        "local_path": "/mnt/local_storage/long_context_checkpoint_16k_2/",
        "lora": "s3://endpoints-finetune-mirrors/dev-test/long-context-test/model_2/lora/",
        "merged": None  # This model has not been merged
    },
    {
        "lora_id": 3,
        "context_length": "32k",
        "local_path": "/mnt/local_storage/long_context_checkpoint_32k",
        "lora": "s3://endpoints-finetune-mirrors/dev-test/long-context-test/model_3/lora/",
        "merged": "s3://endpoints-finetune-mirrors/dev-test/long-context-test/model_3/merged/"
    }
]


def cleanup():
destroy_model_parallel()
@@ -154,6 +189,28 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


# SANG-TODO Download long lora files.
@pytest.fixture(scope="session")
def long_context_infos():
import subprocess
infos = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
local_lora_path = lora_checkpoint_info['local_path'] + "/lora"
print(
f"Downloading {lora_checkpoint_info['lora']} to {local_lora_path} "
)
subprocess.run([
"aws", "s3", "sync", "--quiet", lora_checkpoint_info["lora"],
local_lora_path
])
infos[lora_id] = {
"context_length": lora_checkpoint_info["context_length"],
"lora": local_lora_path,
}
return infos
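Once these checkpoints move to the HF Hub (per the comment above), the fixture could drop the aws s3 sync call and reuse the snapshot_download pattern already used by tinyllama_lora_files. A minimal sketch, assuming a hypothetical "hf_repo" key in LONG_LORA_INFOS; the actual repo IDs are not part of this PR:

    @pytest.fixture(scope="session")
    def long_context_infos():
        infos = {}
        for lora_checkpoint_info in LONG_LORA_INFOS:
            lora_id = lora_checkpoint_info["lora_id"]
            # "hf_repo" is a placeholder; snapshot_download caches the adapter
            # locally and returns the local path.
            local_lora_path = snapshot_download(
                repo_id=lora_checkpoint_info["hf_repo"])
            infos[lora_id] = {
                "context_length": lora_checkpoint_info["context_length"],
                "lora": local_lora_path,
            }
        return infos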


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
Empty file added: tests/lora/data/__init__.py
97 changes: 97 additions & 0 deletions tests/lora/data/long_context_test_data.py

Large diffs are not rendered by default.

243 changes: 243 additions & 0 deletions tests/lora/test_long_context.py
@@ -0,0 +1,243 @@
import numpy as np
import pytest
import vllm
import ast

from vllm import SamplingParams

import torch
from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest
# from vllm.anyscale.tokenization import InputTooLongError

from .data.long_context_test_data import prompts_and_responses

context_len_to_scaling_factor = {
"16k": 4,
"32k": 8,
}

# We use the same sampling params for all requests
sampling_params = SamplingParams(
temperature=0,
max_tokens=100,
)


def _create_lora_request(lora_id, long_context_infos):
context_len = long_context_infos[lora_id]["context_length"]
scaling_factor = context_len_to_scaling_factor[context_len]
return LoRARequest(context_len, lora_id,
long_context_infos[lora_id]["lora"],
4096 * scaling_factor)


def evaluate_json_response(model_response, golden_response):
"""Evaluates the model response against the golden response.

Returns a score between 0 and 1, where 1 is a perfect match and 0 is no match.
The score quantifies how well the model is able to extract the golden JSON from the long context.
"""
try:
model_response = ast.literal_eval(model_response)
except:
raise ValueError(
f"Model response is not a valid JSON. Expected {golden_response}, got {model_response}"
)

# Normally, we would flatten the dictionary and compare the values, but in this case, we know that the dictionary is only 2 levels deep
positive_values = 0
total_values = 0
# We look at all the attributes of the person whose biography we are extracting and compare them to the golden response
for person_attribute, person_attribute_value in golden_response.items():
if person_attribute in model_response:
if type(person_attribute_value) == dict:
for sub_attribute, sub_attribute_value in person_attribute_value.items(
):
total_values += 1
if sub_attribute in model_response[
person_attribute] and model_response[
person_attribute][
sub_attribute] == sub_attribute_value:
positive_values += 1
else:
total_values += 1
if model_response[person_attribute] == person_attribute_value:
positive_values += 1
else:
# We count a missing sub-dict as a single missed value.
total_values += 1

# Return a score between 0 and 1
return positive_values / total_values
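For illustration (not part of the test file), a response that matches two of the three golden values scores 2/3:

    golden = {"name": "Ada", "address": {"city": "London", "zip": "E1"}}
    model = '{"name": "Ada", "address": {"city": "London", "zip": "W1"}}'
    evaluate_json_response(model, golden)  # 2/3: name and city match, zip does not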


def generate(
llm,
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
):
prompts, sampling_param, lora_request = inputs
outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
return outputs[0].outputs[0].text.strip()


def batched_generate(
llm,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
for input in inputs:
prompt, sampling_param, lora_req = input
llm._add_request(prompt, sampling_param, lora_request=lora_req)
outputs = llm._run_engine()
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]


class TestLongContext:

def _get_lora_llm(self, long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]

lora_llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
)
return lora_llm

def test_batched_rope_kernel(self, long_context_infos):
Collaborator Author: Currently this test has an illegal memory access.

"""We test the batched kernel by comparing the results of batched and non-batched generation."""
lora_llm = self._get_lora_llm(long_context_infos)

# Create non batched results first to compare against batched results
non_batched_results = []

for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
sampling_params,
_create_lora_request(lora_id, long_context_infos))
lora_output = generate(lora_llm, lora_prompt)
non_batched_results.append(lora_output)

# Create batched results
# Each element of the batch must be (prompt, prompt_sampling_params, prompt_lora_request)
batched_prompts = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"],
sampling_params,
_create_lora_request(lora_id, long_context_infos))
])

batched_results = batched_generate(lora_llm, batched_prompts)

# Results should be the same
for non_batched, batched in zip(non_batched_results, batched_results):
assert non_batched == batched, f"Non batched and batched results should be the same:\n{batched}\n{non_batched}"


# def test_self_consistency(self, long_context_infos):
# """We test consistency of the batched kernel by permuting batched inputs and comparing the results to the non-permuted batched results."""
# lora_llm = self._get_lora_llm(long_context_infos)
# num_loras = len(long_context_infos)

# # Create results in order of long_context_infos
# batched_prompts = []
# for lora_id, info in long_context_infos.items():
# context_len = info["context_length"]
# batched_prompts.extend([
# (prompts_and_responses[context_len][0]["prompt"],
# sampling_params,
# _create_lora_request(lora_id, long_context_infos))
# ])

# batched_results = batched_generate(lora_llm, batched_prompts)

# permutation = np.random.default_rng(seed=42).permutation(num_loras)

# # Create results in random order of permutation
# batched_prompts = []
# for i in permutation:
# lora_id, info = list(long_context_infos.items())[i]
# context_len = info["context_length"]
# batched_prompts.extend([
# (prompts_and_responses[context_len][0]["prompt"],
# sampling_params,
# _create_lora_request(lora_id, long_context_infos))
# ])

# permutated_batched_results = batched_generate(lora_llm,
# batched_prompts)

# # Results should be the same
# for i in range(num_loras):
# assert batched_results[i] == permutated_batched_results[permutation[
# i]], f"Results should be the same:\n{batched_results[i]}\n{permutated_batched_results[permutation[i]]}"

# def test_quality(self, long_context_infos):
# """We test the quality of the answers given by the LoRA model by comparing the generated text to the merged model's outputs.

# This is effectively a mini-benchmark over four prompts.
# If this test fails, this indicates that the quality of the LoRA model is suboptimal compared to the merged model.
# For example, if the model does not output valid dictionaries, this test will fail.

# If needed for testing, the merged versions of the models are available as part of the `conftest`.
# The test is expected to run for about 1 minute on a p4de.24xlarge instance.
# """
# lora_llm = self._get_lora_llm(long_context_infos)

# scores = []
# for lora_id, info in long_context_infos.items():
# context_len = info["context_length"]
# for prompt_and_response in prompts_and_responses[context_len]:
# lora_prompt = (prompt_and_response["prompt"], sampling_params,
# _create_lora_request(lora_id,
# long_context_infos))
# response = generate(lora_llm, [lora_prompt])
# golden_answer = prompt_and_response["golden_answer"]
# score = evaluate_json_response(response, golden_answer)
# scores.append(score)
# assert score > 0.3, f"Quality of the answer is not good enough. Expected {golden_answer}, got {response}"
# assert np.mean(scores) > 0.5

# def test_max_len(self, long_context_infos):
# """Test that we raise an InputTooLongError when the input of a given LoRA model exceeds the maximum length."""
# lora_llm = self._get_lora_llm(long_context_infos)

# # Since each LoRA model has a different maximum length, we need to test each one separately
# for lora_id, info in long_context_infos.items():
# context_len = info["context_length"]
# lora_request = _create_lora_request(lora_id, long_context_infos)
# # Good prompt should be fine
# good_prompt = prompts_and_responses[context_len][0]["prompt"]
# generate(lora_llm, [(good_prompt, sampling_params, lora_request)])
# # Bad prompt should raise an error
# bad_prompt = good_prompt * 2
# with pytest.raises(InputTooLongError):
# generate(lora_llm,
# [(bad_prompt, sampling_params, lora_request)])

# # Also test batched
# batched_prompts = []
# for lora_id_with_bad_inputs in long_context_infos.keys():
# for lora_id, info in long_context_infos.items():
# context_len = info["context_length"]
# batched_prompts.extend([
# (prompts_and_responses[context_len][0]["prompt"] *
# (2 if lora_id == lora_id_with_bad_inputs else 1),
# sampling_params,
# _create_lora_request(lora_id, long_context_infos))
# ])
# # Turn good prompt into bad prompt inside of batched prompts

# with pytest.raises(InputTooLongError):
# batched_generate(lora_llm, batched_prompts)
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -1,7 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union, Tuple

import torch
from transformers import PretrainedConfig
@@ -990,6 +990,7 @@ class LoRAConfig:
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None

def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
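For reference, this new field is what callers ultimately set when building the engine, and each long-context LoRA request then advertises its own maximum length (base context length times its scaling factor). A minimal usage sketch mirroring the test above; the model name, adapter path, and factors are illustrative:

    from vllm import LLM
    from vllm.lora.request import LoRARequest

    # 16k and 32k adapters on a 4k base model: 16384 / 4096 = 4, 32768 / 4096 = 8.
    llm = LLM(
        "meta-llama/Llama-2-13b-chat-hf",
        enable_lora=True,
        long_lora_scaling_factors=(4.0, 8.0),
    )

    # The last argument is long_lora_max_len = 4096 * scaling_factor.
    lora_request = LoRARequest("16k", 1, "/path/to/16k/lora", 4096 * 4)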
9 changes: 6 additions & 3 deletions vllm/core/scheduler.py
@@ -650,11 +650,14 @@ def _schedule_prefills(
             num_prompt_tokens = waiting_seqs[0].get_len()
             assert num_new_tokens == num_prompt_tokens

-            if num_new_tokens > self.prompt_limit:
+            prompt_limit = (seq_group.lora_request.long_lora_max_len
+                            if seq_group.lora_request
+                            and seq_group.lora_request.long_lora_max_len else
+                            self.prompt_limit)
Collaborator: Let's put it in a method/function.

Collaborator Author: Removed self.prompt_limit, by the way, because in this case it doesn't make much sense to have it.

+            if num_new_tokens > prompt_limit:
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
-                    " and exceeds limit of %d", num_new_tokens,
-                    self.prompt_limit)
+                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
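A sketch of the helper the reviewer asks for; the method name _get_prompt_limit is hypothetical and not part of this diff:

    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        # Long-context LoRA requests carry their own, larger maximum length;
        # otherwise fall back to the scheduler-wide prompt limit.
        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
            return seq_group.lora_request.long_lora_max_len
        return self.prompt_limit

    # Call site inside _schedule_prefills:
    #     prompt_limit = self._get_prompt_limit(seq_group)
    #     if num_new_tokens > prompt_limit:
    #         ...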
9 changes: 8 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1,7 +1,7 @@
import argparse
import dataclasses
from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -63,6 +63,7 @@ class EngineArgs:
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
@@ -397,6 +398,11 @@ def add_cli_args(
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
# Q: Do we need it? We can't pass tuple to cli args?
# parser.add_argument('--long-lora-scaling-factors',
# type=Tuple[float],
# default=EngineArgs.long_lora_scaling_factors,
# help='Scaling factors of long LoRAs')
parser.add_argument(
'--max-cpu-loras',
type=int,
@@ -589,6 +595,7 @@ def create_engine_config(self, ) -> EngineConfig:
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
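On the commented-out question above (can a tuple be passed as a CLI arg?): argparse cannot use Tuple[float] directly as a type, but a comma-separated string parsed by a small converter works. A sketch of one option, not part of this diff, that would slot into add_cli_args:

    def nullable_floats(value: str) -> Optional[Tuple[float, ...]]:
        # "4.0,8.0" -> (4.0, 8.0); an empty string means "not set".
        if not value:
            return None
        return tuple(float(item) for item in value.split(","))

    parser.add_argument('--long-lora-scaling-factors',
                        type=nullable_floats,
                        default=EngineArgs.long_lora_scaling_factors,
                        help='Comma-separated scaling factors of long LoRAs, '
                        'e.g. "4.0,8.0".')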