[Misc] LoRA + Chunked Prefill #9057

Merged (16 commits) on Dec 11, 2024
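
This PR relaxes the startup check so chunked prefill can be enabled together with LoRA (see the vllm/config.py change below) and turns on enable_chunked_prefill=True across the LoRA test suite. A minimal usage sketch, assuming a placeholder base model, adapter path, and prompt (none of these values come from the PR itself):

```python
# Sketch only: the model name, adapter path, and prompt are placeholders.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",       # placeholder base model
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    enable_chunked_prefill=True,      # previously rejected when LoRA was enabled
)

outputs = llm.generate(
    ["Write a SQL query that counts all users."],       # placeholder prompt
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```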
9 changes: 6 additions & 3 deletions tests/lora/test_chatglm3_tp.py
@@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
-trust_remote_code=True)
+trust_remote_code=True,
+enable_chunked_prefill=True)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
-fully_sharded_loras=False)
+fully_sharded_loras=False,
+enable_chunked_prefill=True)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
-fully_sharded_loras=True)
+fully_sharded_loras=True,
+enable_chunked_prefill=True)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
3 changes: 2 additions & 1 deletion tests/lora/test_gemma.py
@@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
-max_loras=4)
+max_loras=4,
+enable_chunked_prefill=True)

expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
6 changes: 5 additions & 1 deletion tests/lora/test_llama_tp.py
@@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
-tensor_parallel_size=1)
+tensor_parallel_size=1,
+enable_chunked_prefill=True)
generate_and_test(llm, sql_lora_files)


@@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
+enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

@@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
+enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

@@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
+enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)
3 changes: 2 additions & 1 deletion tests/lora/test_long_context.py
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
tensor_parallel_size=4,
# FIXME enable async output processor
disable_async_output_proc=True,
distributed_executor_backend="mp")
distributed_executor_backend="mp",
enable_chunked_prefill=True)
yield llm
del llm

3 changes: 2 additions & 1 deletion tests/lora/test_minicpmv.py
@@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
-gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
+gpu_memory_utilization=0.97, # This model is pretty big for CI gpus
+enable_chunked_prefill=True,
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
2 changes: 2 additions & 0 deletions tests/lora/test_minicpmv_tp.py
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
+enable_chunked_prefill=True,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
+enable_chunked_prefill=True,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
1 change: 1 addition & 0 deletions tests/lora/test_mixtral.py
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
+enable_chunked_prefill=True,
)

expected_lora_output = [
3 changes: 2 additions & 1 deletion tests/lora/test_phi.py
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
max_model_len=1024,
enable_lora=True,
max_loras=2,
-enforce_eager=True)
+enforce_eager=True,
+enable_chunked_prefill=True)

expected_lora_output = [
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501
9 changes: 6 additions & 3 deletions tests/lora/test_quant_model.py
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
-trust_remote_code=True)
+trust_remote_code=True,
+enable_chunked_prefill=True)

if model.quantization is None:
expected_no_lora_output = [
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
tensor_parallel_size=1,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
-trust_remote_code=True)
+trust_remote_code=True,
+enable_chunked_prefill=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

del llm_tp1
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
max_loras=4,
tensor_parallel_size=2,
gpu_memory_utilization=0.2, #avoid OOM
-quantization=model.quantization)
+quantization=model.quantization,
+enable_chunked_prefill=True)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

del llm_tp2
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -1707,7 +1707,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
logger.warning("LoRA with chunked prefill is still experimental "
"and may be unstable.")


@dataclass
15 changes: 12 additions & 3 deletions vllm/core/scheduler.py
@@ -166,9 +166,18 @@ def is_empty(self) -> bool:
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self):
-self.scheduled_seq_groups = sorted(
-    self.scheduled_seq_groups,
-    key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
+assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
+
+def key_fn(group: ScheduledSequenceGroup):
+    key = (group.seq_group.lora_int_id, group.seq_group.request_id)
+    if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
+        # Sort sequence groups so that all prefills come before all
+        # decodes as required by chunked prefill.
+        return (not group.seq_group.is_prefill(), *key)
+    return key
+
+self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
+                                   key=key_fn)

@property
def lora_requests(self) -> Set[LoRARequest]:
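
The effect of the new key_fn is that, when a batch mixes prefills and decodes, all prefill groups are ordered before all decode groups, and within each half the existing (lora_int_id, request_id) ordering is preserved. A standalone sketch of that ordering; Group below is an illustrative stand-in, not vLLM's ScheduledSequenceGroup:

```python
# Toy illustration of the prefills-first, then-by-LoRA-id ordering.
from collections import namedtuple

Group = namedtuple("Group", ["is_prefill", "lora_int_id", "request_id"])

groups = [
    Group(is_prefill=False, lora_int_id=2, request_id="d1"),
    Group(is_prefill=True, lora_int_id=2, request_id="p2"),
    Group(is_prefill=False, lora_int_id=1, request_id="d2"),
    Group(is_prefill=True, lora_int_id=1, request_id="p1"),
]

def key_fn(g: Group):
    # False sorts before True, so prefill groups come first.
    return (not g.is_prefill, g.lora_int_id, g.request_id)

for g in sorted(groups, key=key_fn):
    print(g)
# Prefills first (p1, then p2), then decodes (d2, then d1).
```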
12 changes: 7 additions & 5 deletions vllm/worker/model_runner.py
@@ -622,11 +622,13 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
-inter_data.lora_prompt_mapping.append(
-    [lora_id] *
-    (query_len if seq_group_metadata.sampling_params
-     and seq_group_metadata.sampling_params.prompt_logprobs is not None
-     else 1))
+sampling_params = seq_group_metadata.sampling_params
+if sampling_params and sampling_params.prompt_logprobs is not None:
+    inter_data.lora_prompt_mapping.append([lora_id] * query_len)
+elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
+    inter_data.lora_prompt_mapping.append([lora_id])
+else:
+    inter_data.lora_prompt_mapping.append([])

def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
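
The model_runner.py change sizes lora_prompt_mapping per sequence group: the full query length when prompt logprobs are requested, a single entry when the group will sample a token (or chunked prefill is disabled), and no entries for intermediate prefill chunks that produce no samples. A standalone sketch of that branch; the helper below is illustrative and not part of vLLM:

```python
# Hypothetical helper mirroring the branch above, for illustration only.
from typing import List, Optional

def lora_prompt_mapping(lora_id: int,
                        query_len: int,
                        prompt_logprobs: Optional[int],
                        chunked_prefill_enabled: bool,
                        do_sample: bool) -> List[int]:
    if prompt_logprobs is not None:
        # Prompt logprobs need a LoRA entry for every prompt token.
        return [lora_id] * query_len
    if not chunked_prefill_enabled or do_sample:
        # Exactly one token is sampled for this sequence group.
        return [lora_id]
    # Intermediate prefill chunk under chunked prefill: nothing is sampled.
    return []

assert lora_prompt_mapping(3, 8, prompt_logprobs=5,
                           chunked_prefill_enabled=True, do_sample=False) == [3] * 8
assert lora_prompt_mapping(3, 8, prompt_logprobs=None,
                           chunked_prefill_enabled=True, do_sample=True) == [3]
assert lora_prompt_mapping(3, 8, prompt_logprobs=None,
                           chunked_prefill_enabled=True, do_sample=False) == []
```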