Commit

lora + chunked prefill

aurickq committed Oct 23, 2024
1 parent fc6c274 commit ae18da2
Showing 12 changed files with 57 additions and 24 deletions.
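Most of the diff threads an enable_chunked_prefill flag through the existing LoRA tests, now that the two features may be combined. As a rough sketch of what these tests exercise (the model name, adapter path, and prompt below are placeholders chosen for illustration, not taken from this commit):

import vllm
from vllm.lora.request import LoRARequest

# Placeholder model and adapter; the tests use fixtures such as sql_lora_files.
llm = vllm.LLM("meta-llama/Llama-2-7b-hf",
               enable_lora=True,
               max_loras=4,
               max_lora_rank=64,
               enable_chunked_prefill=True)  # previously rejected with a ValueError

outputs = llm.generate(["Write a SQL query that counts all singers."],
                       vllm.SamplingParams(temperature=0, max_tokens=64),
                       lora_request=LoRARequest("sql-adapter", 1, "/path/to/lora-adapter"))
print(outputs[0].outputs[0].text)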
8 changes: 6 additions & 2 deletions tests/lora/test_chatglm3.py
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -37,13 +39,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


def test_chatglm3_lora(chatglm3_lora_files):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_chatglm3_lora(chatglm3_lora_files, enable_chunked_prefill):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=enable_chunked_prefill)

expected_lora_output = [
"SELECT count(*) FROM singer",
6 changes: 4 additions & 2 deletions tests/lora/test_gemma.py
@@ -32,11 +32,13 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:


@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_gemma_lora(gemma_lora_files, enable_chunked_prefill):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4)
max_loras=4,
enable_chunked_prefill=enable_chunked_prefill)

expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
25 changes: 18 additions & 7 deletions tests/lora/test_llama.py
@@ -38,15 +38,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:


@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_llama_lora(sql_lora_files, tp_size, enable_chunked_prefill,
num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size)
tensor_parallel_size=tp_size,
enable_chunked_prefill=enable_chunked_prefill)

expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
@@ -88,7 +91,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
tensor_parallel_size=1,
enable_chunked_prefill=True)
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)

del llm_tp1
@@ -98,7 +102,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2)
tensor_parallel_size=2,
enable_chunked_prefill=True)
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)

del llm_tp2
@@ -110,7 +115,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4)
tensor_parallel_size=4,
enable_chunked_prefill=True)
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)

del llm_tp4
@@ -125,13 +131,18 @@ def test_llama_lora_warmup(sql_lora_files):

@ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora():
llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
enable_chunked_prefill=True)
num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
return num_gpu_blocks_lora_warmup

@ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
llm = vllm.LLM(MODEL_PATH,
max_num_seqs=16,
enable_chunked_prefill=True)
num_gpu_blocks_no_lora_warmup = (
llm.llm_engine.cache_config.num_gpu_blocks)
return num_gpu_blocks_no_lora_warmup
3 changes: 2 additions & 1 deletion tests/lora/test_long_context.py
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
tensor_parallel_size=4,
# FIXME enable async output processor
disable_async_output_proc=True,
distributed_executor_backend="mp")
distributed_executor_backend="mp",
enable_chunked_prefill=True)
yield llm
del llm

3 changes: 2 additions & 1 deletion tests/lora/test_minicpmv.py
@@ -61,7 +61,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
gpu_memory_utilization=0.97, # This model is pretty big for CI gpus
enable_chunked_prefill=True,
)

output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
2 changes: 2 additions & 0 deletions tests/lora/test_minicpmv_tp.py
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enable_chunked_prefill=True,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enable_chunked_prefill=True,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
1 change: 1 addition & 0 deletions tests/lora/test_mixtral.py
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
enable_chunked_prefill=True,
)

expected_lora_output = [
3 changes: 2 additions & 1 deletion tests/lora/test_phi.py
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
max_model_len=1024,
enable_lora=True,
max_loras=2,
enforce_eager=True)
enforce_eager=True,
enable_chunked_prefill=True)

expected_lora_output = [
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501
9 changes: 6 additions & 3 deletions tests/lora/test_quant_model.py
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=True)

if model.quantization is None:
expected_no_lora_output = [
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
tensor_parallel_size=1,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

del llm_tp1
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
max_loras=4,
tensor_parallel_size=2,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization)
quantization=model.quantization,
enable_chunked_prefill=True)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

del llm_tp2
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -1603,7 +1603,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
logger.warning("LoRA with chunked prefill is "
"experimental and may be unstable.")


@dataclass
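The compatibility check in vllm/config.py used to reject this combination outright; it now only logs a warning. A simplified, standalone sketch of the new behavior (plain logging and a function name chosen for illustration, not vLLM's real API):

import logging

logger = logging.getLogger("vllm.config")

def check_lora_with_chunked_prefill(chunked_prefill_enabled: bool) -> None:
    # Stand-in for the updated LoRAConfig.verify_with_scheduler_config branch:
    # the combination is no longer an error, only flagged as experimental.
    if chunked_prefill_enabled:
        logger.warning("LoRA with chunked prefill is "
                       "experimental and may be unstable.")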
6 changes: 5 additions & 1 deletion vllm/core/scheduler.py
@@ -151,9 +151,13 @@ def is_empty(self) -> bool:
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self):
# Sort sequence groups so that (1) all prefills come before all decodes
# (required by chunked prefill), and (2) all LoRAs are grouped together
# for improved performance.
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
key=lambda g: (not g.seq_group.is_prefill(), g.seq_group.
lora_int_id, g.seq_group.request_id))

@property
def lora_requests(self) -> Set[LoRARequest]:
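The new comment describes the ordering; the sketch below uses stand-in objects (not vLLM's real ScheduledSequenceGroup, which wraps a seq_group attribute) to show the effect of the same sort key:

from dataclasses import dataclass

@dataclass
class FakeSeqGroup:
    # Only the fields the sort key reads.
    request_id: str
    lora_int_id: int
    prefill: bool

    def is_prefill(self) -> bool:
        return self.prefill

groups = [
    FakeSeqGroup("r3", lora_int_id=2, prefill=False),
    FakeSeqGroup("r1", lora_int_id=1, prefill=True),
    FakeSeqGroup("r4", lora_int_id=1, prefill=False),
    FakeSeqGroup("r2", lora_int_id=2, prefill=True),
]

# Prefills first (False sorts before True), then grouped by LoRA id,
# then by request id as a tie-breaker.
ordered = sorted(groups,
                 key=lambda g: (not g.is_prefill(), g.lora_int_id, g.request_id))
print([g.request_id for g in ordered])  # ['r1', 'r2', 'r4', 'r3']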
12 changes: 7 additions & 5 deletions vllm/worker/model_runner.py
@@ -597,11 +597,13 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))
sampling_params = seq_group_metadata.sampling_params
if sampling_params and sampling_params.prompt_logprobs is not None:
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
elif seq_group_metadata.do_sample:
inter_data.lora_prompt_mapping.append([lora_id])
else:
inter_data.lora_prompt_mapping.append([])

def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
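The rewritten branch decides how many lora_prompt_mapping entries a sequence contributes; with chunked prefill, an intermediate prefill chunk samples nothing and so contributes none. A simplified stand-in for that logic (illustrative function, not vLLM's real method signature):

from typing import List

def lora_prompt_mapping_entries(lora_id: int,
                                query_len: int,
                                wants_prompt_logprobs: bool,
                                do_sample: bool) -> List[int]:
    # Mirrors the three branches in _compute_lora_input.
    if wants_prompt_logprobs:
        return [lora_id] * query_len  # one entry per prompt token
    if do_sample:
        return [lora_id]              # one entry for the token being sampled
    return []                         # partial prefill chunk: nothing sampled

# A chunk that neither samples nor needs prompt logprobs maps to no entries.
print(lora_prompt_mapping_entries(3, query_len=16,
                                  wants_prompt_logprobs=False, do_sample=False))  # []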
