diff --git a/format.sh b/format.sh index 8c54b56302d5b..5edc868f9f70c 100755 --- a/format.sh +++ b/format.sh @@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml +mypy vllm/prompt_adapter --config-file pyproject.toml mypy tests --config-file pyproject.toml diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index b50784a205af7..853fd9fb3ce7a 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -92,11 +92,10 @@ def batched_generate( for input in inputs: prompt, sampling_param, lora_req = input # Add requests to the engine and run the engine - llm._validate_and_add_requests( - prompt, - sampling_param, - lora_request=lora_req, - ) + llm._validate_and_add_requests(prompt, + sampling_param, + lora_request=lora_req, + prompt_adapter_request=None) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 2133bce14957b..7bff9e1fbcdcc 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] is None - assert manager.add_lora(model_lora2) - assert manager.activate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] is None - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 @@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.add_lora(model_lora2) - assert manager.deactivate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.deactivate_adapter(3) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.activate_lora(1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.deactivate_lora(2) + assert manager.deactivate_adapter(2) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.pin_lora(3) - assert manager.pin_lora(1) + assert manager.pin_adapter(3) + assert manager.pin_adapter(1) with pytest.raises(RuntimeError): - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 with pytest.raises(RuntimeError): - assert manager.activate_lora(2) + assert manager.activate_adapter(2) - assert manager.deactivate_lora(3) - assert manager.pin_lora(2) + assert manager.deactivate_adapter(3) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.remove_lora(3) + assert manager.remove_adapter(3) with pytest.raises(ValueError): - assert manager.pin_lora(3) + assert manager.pin_adapter(3) def test_lru_lora_model_manager(dist_init, dummy_model): @@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity - assert manager.add_lora(model_lora1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(1) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(1) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 # Add over capacity - assert manager.add_lora(model_lora3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(3) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(3) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return false since it's in already - assert not manager.add_lora(model_lora3) - assert not manager.activate_lora(3) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora3) + assert not manager.activate_adapter(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {3, 2} + assert set(manager.list_adapters()) == {3, 2} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {2} + assert set(manager.list_adapters()) == {2} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 2 - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {4} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) - assert not manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert not manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) # pinning - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) - assert set(manager.list_loras()) == {3, 4} + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) + assert set(manager.list_adapters()) == {3, 4} with pytest.raises(ValueError): - assert manager.pin_lora(1) - assert manager.pin_lora(3) + assert manager.pin_adapter(1) + assert manager.pin_adapter(3) # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {4} + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.add_lora(model_lora1) - assert manager.pin_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.pin_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {1} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {1} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] is None with pytest.raises(RuntimeError): - assert manager.remove_oldest_lora() + assert manager.remove_oldest_adapter() - assert set(manager.list_loras()) == {1} + assert set(manager.list_adapters()) == {1} -def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = LRUCacheWorkerLoRAManager( + worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, ], mapping) -def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): # Should remove every LoRA not specified in the request. lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = WorkerLoRAManager( + worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + assert worker_adapter_manager.list_adapters() == {1, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager.list_adapters() == {1, 2, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None - assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + assert worker_adapter_manager.list_adapters() == {1} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + assert worker_adapter_manager.list_adapters() == {6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up): assert isinstance(model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA) - assert manager.add_lora(model_lora) - assert manager.add_lora(model_lora1) + assert manager.add_adapter(model_lora) + assert manager.add_adapter(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py new file mode 100644 index 0000000000000..6528b3009b8c0 --- /dev/null +++ b/tests/prompt_adapter/test_bloom.py @@ -0,0 +1,45 @@ +import pytest + +import vllm +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' + + +def do_sample(llm, pa_name: str, pa_id: int): + + prompts = [ + "Tweet text : @nationalgridus I have no water and the bill is \ + current and paid. Can you do something about this? Label : ", + "Tweet text : @nationalgridus Looks good thanks! Label : " + ] + sampling_params = vllm.SamplingParams(temperature=0.0, + max_tokens=3, + stop_token_ids=[3]) + + outputs = llm.generate(prompts, + sampling_params, + prompt_adapter_request=PromptAdapterRequest( + pa_name, pa_id, PA_PATH, 8) if pa_id else None) + + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_twitter_prompt_adapter(enforce_eager: bool): + llm = vllm.LLM(MODEL_PATH, + enforce_eager=enforce_eager, + enable_prompt_adapter=True, + max_prompt_adapter_token=8) + + expected_output = ['complaint', 'no complaint'] + + assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py new file mode 100644 index 0000000000000..39a79becdfbb3 --- /dev/null +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -0,0 +1,53 @@ +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' +pa_path2 = 'swapnilbp/angry_tweet_ptune' + + +def do_sample(engine): + + prompts = [ + ("Tweet text: I have complaints! Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), + ("Tweet text: I have complaints! Label: ", + SamplingParams(temperature=0.0, max_tokens=3), None), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("complain", 3, pa_path, 8)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_multi_prompt_adapters(): + engine_args = EngineArgs(model=MODEL_PATH, + max_prompt_adapters=3, + enable_prompt_adapter=True, + max_prompt_adapter_token=8) + engine = LLMEngine.from_engine_args(engine_args) + expected_output = { + ' quot;I', 'hate speech', 'no complaint', 'not hate speech' + } + assert do_sample(engine) == expected_output diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py new file mode 100644 index 0000000000000..2a5f23f7f92ec --- /dev/null +++ b/tests/prompt_adapter/test_pa_lora.py @@ -0,0 +1,61 @@ +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") +lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +def do_sample(engine): + + prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 + + # first prompt with a prompt adapter and second without adapter + prompts = [ + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, + 8), LoRARequest("sql_test", 1, lora_path)), + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), None, + LoRARequest("sql_test", 1, lora_path)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request, lora_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request, + lora_request=lora_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_lora_prompt_adapter(): + engine_args = EngineArgs(model=MODEL_PATH, + enable_prompt_adapter=True, + enable_lora=True, + max_num_seqs=60, + max_prompt_adapter_token=8) + engine = LLMEngine.from_engine_args(engine_args) + result = do_sample(engine) + + expected_output = { + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 + } + assert result == expected_output diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 8ad8e9cb81ff8..fb3415b5db153 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -13,6 +13,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext @@ -92,6 +93,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalDataDict] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> List[RequestOutput]: if prompts is None: diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e1775790c0a03..b5742c4338616 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 0000000000000..3ed60678b52f5 --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000000000..6939b1405f3e1 --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Hashable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: Hashable, value: T): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self): + ... + + @property + @abstractmethod + def capacity(self): + ... + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + ... + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + ... + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + ... + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + ... diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000000000..69775ab7d4548 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from dataclasses import dataclass + + +@dataclass +class AdapterRequest: + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self): + ... + + def __post_init__(self): + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000000000..6c5411f7d3d5c --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,90 @@ +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id, None) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000000000..acf18993af6d7 --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + ... + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + ... + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def list_adapters(self) -> Set[int]: + ... diff --git a/vllm/config.py b/vllm/config.py index 1ea2888796808..68ca81a2ec4fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1285,6 +1285,39 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): raise ValueError("LoRA is not supported with chunked prefill yet.") +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_prompt_adapter_token: int + max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None + + def __post_init__(self): + library_name = 'peft' + try: + __import__(library_name) + except ImportError as e: + raise ImportError( + f"'{library_name}' is not installed for prompt adapter support." + f"Please install it using 'pip install {library_name}'." + ) from e + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + @dataclass class MultiModalConfig: """Configs the input data format and how models should run for @@ -1518,6 +1551,7 @@ class EngineConfig: speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] + prompt_adapter_config: Optional[PromptAdapterConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1529,6 +1563,9 @@ def __post_init__(self): self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def to_dict(self): """Return the configs as a dictionary, for use in **kwargs. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9e626b2883975..6bda18cd4f061 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -139,6 +140,8 @@ def __post_init__(self): if self.num_loras > 0: self._sort_by_lora_ids() + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in @@ -157,6 +160,14 @@ def lora_requests(self) -> Set[LoRARequest]: if g.seq_group.lora_request is not None } + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + @dataclass class SchedulerRunningOutputs: @@ -1024,6 +1035,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index afa6892d49eb8..b972573c0258e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,8 +7,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -66,6 +66,9 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None @@ -449,6 +452,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -726,6 +740,11 @@ def create_engine_config(self, ) -> EngineConfig: model_loader_extra_config=self.model_loader_extra_config, ) + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) @@ -751,6 +770,7 @@ def create_engine_config(self, ) -> EngineConfig: load_config=load_config, decoding_config=decoding_config, observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 33e40c7b3624a..9b4ef48b0e47e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -18,6 +18,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -264,6 +265,7 @@ async def process_model_inputs_async( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -279,6 +281,12 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = [ + 0 + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -286,13 +294,14 @@ async def process_model_inputs_async( return self.input_processor(llm_inputs) async def add_request_async( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -301,7 +310,10 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request) + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -309,6 +321,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ -627,6 +640,7 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -669,7 +683,7 @@ async def add_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return stream @@ -680,6 +694,7 @@ async def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -695,6 +710,8 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. Yields: The output `RequestOutput` objects from the LLMEngine @@ -749,6 +766,7 @@ async def generate( sampling_params, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -837,6 +855,7 @@ async def _process_request( *, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -849,6 +868,7 @@ async def _process_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c31..b476594fc73f6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) @@ -27,6 +28,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, @@ -93,6 +95,8 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. """ @@ -161,6 +165,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -222,6 +227,7 @@ def __init__( self.speculative_config = speculative_config self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats @@ -250,6 +256,7 @@ def __init__( multimodal_config=multimodal_config, speculative_config=speculative_config, load_config=load_config, + prompt_adapter_config=prompt_adapter_config, ) if not self.model_config.embedding_mode: @@ -282,6 +289,8 @@ def __init__( # Feature flags "enable_lora": bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": @@ -376,7 +385,6 @@ def from_engine_args( engine_config = engine_args.create_engine_config() distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor @@ -409,7 +417,6 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor - # Create the LLM engine. engine = cls( **engine_config.to_dict(), @@ -470,6 +477,9 @@ def _verify_args(self) -> None: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: @@ -487,6 +497,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Dict[str, str]] = None, ) -> None: # Create the sequences. @@ -495,7 +506,7 @@ def _add_processed_request( eos_token_id = self._get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + lora_request, prompt_adapter_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -506,7 +517,7 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -514,7 +525,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -535,6 +546,7 @@ def process_model_inputs( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -549,6 +561,11 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = \ + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -563,6 +580,7 @@ def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -612,9 +630,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs(request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = self.process_model_inputs( + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -622,6 +642,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ -633,6 +654,7 @@ def _create_sequence_group_with_sampling( arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -658,7 +680,7 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return seq_group @@ -669,16 +691,19 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1082,6 +1107,16 @@ def list_loras(self) -> Set[int]: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e3e506d496844..57e81a6317725 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext @@ -255,6 +256,7 @@ def generate( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -271,6 +273,8 @@ def generate( prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `RequestOutput` objects containing the @@ -304,7 +308,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -397,6 +401,7 @@ def encode( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -412,6 +417,8 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `EmbeddingRequestOutput` objects containing the @@ -445,6 +452,7 @@ def encode( inputs=inputs, params=pooling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -504,6 +512,7 @@ def _validate_and_add_requests( params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -526,19 +535,23 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - ) + prompt_adapter_request=prompt_adapter_request) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], + LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - inputs, - params, - lora_request=lora_request) + self.llm_engine.add_request( + request_id, + inputs, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d3ed1ec7a15c5..6cba356c47063 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest): @app.get("/v1/models") async def show_available_models(): - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) @@ -236,7 +236,8 @@ async def authentication(request: Request, call_next): args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules, + args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 59ad73bf097c8..81c474ecc808a 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,8 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) from vllm.utils import FlexibleArgumentParser @@ -23,6 +24,16 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, lora_list) +class PromptAdapterParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + adapter_list = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + def make_arg_parser(): parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") @@ -65,6 +76,14 @@ def make_arg_parser(): action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") parser.add_argument("--chat-template", type=nullable_str, default=None, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 415bdbbd7c455..010d6f2ebb909 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -258,7 +258,7 @@ async def create_chat_completion( prompt=prompt, add_special_tokens=request.add_special_tokens) sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) + _, lora_request = self._maybe_get_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9c719d634ac7d..b53b058b52af3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -22,7 +22,8 @@ TokenizeResponse, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]]): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=prompt_adapters) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -101,7 +104,12 @@ async def create_completion(self, request: CompletionRequest, generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) + adapter_type, adapter_request = self._maybe_get_adapter(request) + lora_request, prompt_adapter_request = None, None + if adapter_type == 'LoRA': + lora_request, prompt_adapter_request = adapter_request, None + elif adapter_type == 'PromptAdapter': + lora_request, prompt_adapter_request = None, adapter_request decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend @@ -147,6 +155,7 @@ async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d281c51f02bc..58e6571d310e6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,12 +16,19 @@ ModelPermission, TokenizeRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + @dataclass class LoRAModulePath: name: str @@ -30,9 +37,14 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): super().__init__() self.engine = engine @@ -49,9 +61,8 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, self.served_model_names = served_model_names - if lora_modules is None: - self.lora_requests = [] - else: + self.lora_requests = [] + if lora_modules is not None: self.lora_requests = [ LoRARequest( lora_name=lora.name, @@ -60,6 +71,20 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, ) for i, lora in enumerate(lora_modules, start=1) ] + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with open(f"./{prompt_adapter.local_path}" + f"/adapter_config.json") as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -75,7 +100,14 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.served_model_names[0], + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) return ModelList(data=model_cards) def create_error_response( @@ -109,20 +141,29 @@ async def _check_model( return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.prompt_adapter_requests + ]: + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora( + def _maybe_get_adapter( self, request: Union[CompletionRequest, ChatCompletionRequest, EmbeddingRequest] - ) -> Optional[LoRARequest]: + ) -> Tuple[Optional[str], Optional[Union[LoRARequest, + PromptAdapterRequest]]]: if request.model in self.served_model_names: - return None + return None, None for lora in self.lora_requests: if request.model == lora.lora_name: - return lora + return 'LoRA', lora + for prompt_adapter in self.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return 'PromptAdapter', prompt_adapter # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 3b5621f70b92d..d3b60e3ff4260 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -7,6 +7,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -48,6 +49,7 @@ def _init_worker(self): lora_config=self.lora_config, multimodal_config=self.multimodal_config, kv_cache_dtype=self.cache_config.cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=True, ) self.driver_worker.init_device() @@ -90,6 +92,19 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index fc18dec0bca25..6f9e554459161 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,8 +4,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -28,6 +30,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -38,6 +41,7 @@ def __init__( self.device_config = device_config self.multimodal_config = multimodal_config self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config self._init_executor() @@ -95,6 +99,23 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise NotImplementedError + @abstractmethod + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + + @abstractmethod + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError + @abstractmethod def check_health(self) -> None: """Checks if the executor is healthy. If not, it should raise an @@ -122,12 +143,14 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], ) -> None: self.pp_locks: Optional[List[asyncio.Lock]] = None super().__init__(model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, - lora_config, multimodal_config, speculative_config) + lora_config, multimodal_config, speculative_config, + prompt_adapter_config) @abstractmethod async def execute_model_async( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7d3183a428a31..6ffc28d21be29 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,6 +3,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -45,6 +46,7 @@ def _get_worker_kwargs( lora_config=self.lora_config, multimodal_config=self.multimodal_config, speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), ) @@ -107,6 +109,25 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index f02d4978371a3..33f9321b5ff36 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -44,6 +45,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -58,6 +60,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config placement_group = self.parallel_config.placement_group diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 29b246332ad55..f6550cce9ab1a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -4,7 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -27,6 +28,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -43,6 +45,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config self.speculative_config = None # Instantiate the worker and load the model to GPU. diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0a63f9ef012bc..40de134c0a5ee 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( @dataclass -class LoRAMapping: - # Per every token in input_ids: - index_mapping: Tuple[int, ...] - # Per sampled token: - prompt_mapping: Tuple[int, ...] - - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) +class LoRAMapping(AdapterMapping): + pass class BaseLayerWithLoRA(nn.Module): diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 689835def83dd..e1ede7d4d710a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,12 +4,17 @@ import os import re from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch from torch import nn +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -19,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import LRUCache, is_pin_memory_available +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -153,7 +158,7 @@ def get_lora_id(): return _GLOBAL_LORA_ID -class LoRAModel: +class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" def __init__( @@ -388,7 +393,7 @@ def from_local_checkpoint( ) -class LoRAModelManager: +class LoRAModelManager(AdapterModelManager): """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -440,8 +445,7 @@ def __init__( # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - - self.model = model + super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) @@ -453,11 +457,11 @@ def __init__( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} - self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' @property def capacity(self) -> int: @@ -467,15 +471,16 @@ def capacity(self) -> int: def lora_slots(self) -> int: return self.lora_config.max_loras - def __len__(self) -> int: - return len(self._registered_loras) + @property + def adapter_slots(self) -> int: + return self.lora_slots - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: """Move LoRA into a GPU buffer to be used in the forward pass.""" - if lora_id in self._active_loras: + if lora_id in self._active_adapters: return False first_free_slot = next( ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) @@ -483,8 +488,8 @@ def activate_lora( if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot - self._active_loras[lora_id] = None - lora_model = self._registered_loras[lora_id] + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] logger.debug("Activating LoRA. int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id @@ -498,21 +503,13 @@ def activate_lora( module.reset_lora(index) return True - def _deactivate_lora(self, lora_id: int): + def _deactivate_adapter(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) self.lora_index_to_id[index] = None except ValueError: pass - def deactivate_lora(self, lora_id: int) -> bool: - """Remove a LoRA from a GPU buffer.""" - if lora_id in self._active_loras: - self._deactivate_lora(lora_id) - self._active_loras.pop(lora_id) - return True - return False - def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -528,40 +525,19 @@ def _set_long_lora_context(self, lora: LoRAModel): if offsets: self.long_lora_context.offsets_by_lora_id[lora.id] = offsets - def _add_lora(self, lora: LoRAModel): + def _add_adapter(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager CPU cache.""" - logger.debug( - "Adding lora. Model id: %d, " - "int id: %d, " - "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - if len(self._registered_loras) >= self.capacity: - raise RuntimeError("No free LoRA slots.") - self._add_lora(lora) - return True - return False - - def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager CPU cache.""" - # TODO: should we check active lora? - self.deactivate_lora(lora_id) - if self.long_lora_context: - self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) - return bool(self._registered_loras.pop(lora_id, None)) - - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager." "Use LRUCacheLoRAModelManager for pinning") # type: ignore # TODO see if this can be vectorized - def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -583,23 +559,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: - if self._last_mapping != lora_mapping: - self._set_lora_mapping(lora_mapping) - self._last_mapping = lora_mapping - - def list_loras(self) -> Dict[int, LoRAModel]: - """List all registered LoRAModels.""" - return dict(self._registered_loras) - - def get_lora(self, lora_id: int) -> Optional[LoRAModel]: - return self._registered_loras.get(lora_id, None) - - def remove_all_loras(self): + def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" - self._registered_loras.clear() + self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots - self._active_loras.clear() + self._active_adapters.clear() def _create_lora_modules(self): for module_name, module in self.model.named_modules( @@ -743,18 +707,39 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) -class LoRALRUCache(LRUCache[LoRAModel]): + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]): - super().__init__(capacity) - self.deactivate_lora_fn = deactivate_lora_fn - - def _on_remove(self, key: int, value: LoRAModel): - logger.debug("Removing LoRA. int id: %d", key) - self.deactivate_lora_fn(key) - return super()._on_remove(key, value) + super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): @@ -770,49 +755,49 @@ def __init__( ): super().__init__(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config) - self._registered_loras: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_lora) - self._active_loras: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_lora) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) - def list_loras(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" - return dict(self._registered_loras.cache) + return dict(self._registered_adapters.cache) - def add_lora(self, lora: LoRAModel) -> bool: + def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - self._add_lora(lora) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) was_added = True else: # We always touch to update the LRU cache order - self._registered_loras.touch(lora.id) + self._registered_adapters.touch(lora.id) was_added = False return was_added - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: - if lora_id not in self._active_loras and len( - self._active_loras) >= self.lora_slots: - self._active_loras.remove_oldest() - result = super().activate_lora(lora_id) + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order - self._active_loras.touch(lora_id) + self._active_adapters.touch(lora_id) return result - def remove_oldest_lora(self) -> bool: - if len(self._registered_loras) > 0: - self._registered_loras.remove_oldest() + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() return True return False - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" self._pin_lora_in_cpu_cache(lora_id) self._pin_lora_in_gpu_cache(lora_id) @@ -820,17 +805,17 @@ def pin_lora(self, lora_id: int) -> bool: def _pin_lora_in_cpu_cache(self, lora_id: int): try: - self._registered_loras.pin(lora_id) + self._registered_adapters.pin(lora_id) except ValueError as err: raise ValueError("Pinning failed. " f"LoRA {lora_id} is not registered.") from err def _pin_lora_in_gpu_cache(self, lora_id: int): - if lora_id not in self._active_loras: + if lora_id not in self._active_adapters: # move lora to gpu if not already active - self.activate_lora(lora_id) + self.activate_adapter(lora_id) - self._active_loras.pin(lora_id) + self._active_adapters.pin(lora_id) def create_lora_manager( diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 662774ffe09ae..2d10d037760e2 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Optional +from vllm.adapter_commons.request import AdapterRequest + @dataclass -class LoRARequest: +class LoRARequest(AdapterRequest): """ Request for a LoRA adapter. - Note that this class should be be used internally. For online + Note that this class should be used internally. For online serving, it is recommended to not allow users to use this class but instead provide another layer of abstraction to prevent users from accessing unauthorized LoRA adapters. @@ -20,15 +22,16 @@ class LoRARequest: lora_int_id: int lora_local_path: str long_lora_max_len: Optional[int] = None + __hash__ = AdapterRequest.__hash__ - def __post_init__(self): - if self.lora_int_id < 1: - raise ValueError( - f"lora_int_id must be > 0, got {self.lora_int_id}") + @property + def adapter_id(self): + return self.lora_int_id - def __eq__(self, value: object) -> bool: - return isinstance( - value, LoRARequest) and self.lora_int_id == value.lora_int_id + @property + def name(self): + return self.lora_name - def __hash__(self) -> int: - return self.lora_int_id + @property + def local_path(self): + return self.lora_local_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ca4903c23bcaa..3d0ef4252b024 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,12 +1,15 @@ -from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest @@ -14,79 +17,13 @@ logger = init_logger(__name__) -class AbstractWorkerLoRAManager(ABC): - """Abstract class for managing LoRA models on the worker side.""" - - def __init__(self, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - max_position_embeddings: Optional[int] = None): - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self.max_position_embeddings = max_position_embeddings - self.vocab_size = vocab_size - self.device = device - self.lora_config = lora_config - - # If False, do not cache. If None, cache is empty. - self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False - - @contextmanager - def dummy_lora_cache(self): - """Use this context manager to reuse the dummy lora model - to avoid creating it repeatedly.""" - self._cached_dummy_lora = None - yield - self._cached_dummy_lora = False - - @property - @abstractmethod - def is_enabled(self) -> bool: - ... - - @abstractmethod - def create_lora_manager( - self, - model: torch.nn.Module, - ) -> Any: - ... - - @abstractmethod - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - ... - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - ... - - @abstractmethod - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - ... - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - ... - - @abstractmethod - def remove_all_loras(self): - ... - - @abstractmethod - def list_loras(self) -> Set[int]: - ... - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(AbstractWorkerManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: Type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -103,16 +40,23 @@ def __init__( self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) # Lazily initialized by create_lora_manager. - self._lora_manager: LoRAModelManager - super().__init__( - max_num_seqs, - max_num_batched_tokens, - vocab_size, - lora_config, - device, - max_position_embeddings=max_position_embeddings, - ) + self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False @property def is_enabled(self) -> bool: @@ -128,41 +72,14 @@ def create_lora_manager( max_num_batched_tokens=self.max_num_batched_tokens, vocab_size=self.vocab_size, lora_config=self.lora_config, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self._apply_loras(lora_requests) - self._lora_manager.set_lora_mapping(lora_mapping) - - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: - loras_that_exist = self.list_loras() - loras_map = { - lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request - } - if len(loras_map) > self._lora_manager.lora_slots: - raise RuntimeError( - f"Number of requested LoRAs ({len(loras_map)}) is greater " - "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") - - new_loras = set(loras_map) - loras_to_add = new_loras - loras_that_exist - loras_to_remove = loras_that_exist - new_loras - - for lora_id in loras_to_remove: - self.remove_lora(lora_id) - - for lora_id in loras_to_add: - self.add_lora(loras_map[lora_id]) - - def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._lora_manager.model + model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping expected_lora_modules: List[str] = [] @@ -198,37 +115,45 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - if lora_request.lora_int_id in self.list_loras(): + if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): dummy_lora = self._cached_dummy_lora.clone( lora_request.lora_int_id) else: - dummy_lora = self._lora_manager.create_dummy_lora( + dummy_lora = self._adapter_manager.create_dummy_lora( lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora - return self._lora_manager.add_lora(dummy_lora) + return self._adapter_manager.add_adapter(dummy_lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id in self.list_loras(): - return False - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) - self._lora_manager.activate_lora(lora.id) - return loaded + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) - def remove_lora(self, lora_id: int) -> bool: - return self._lora_manager.remove_lora(lora_id) + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) - def pin_lora(self, lora_id: int) -> bool: - return self._lora_manager.pin_lora(lora_id) + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) - def remove_all_loras(self): - self._lora_manager.remove_all_loras() + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() - def list_loras(self) -> Set[int]: - return set(self._lora_manager.list_loras()) + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _lora_manager_cls: Type[ - LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, @@ -247,40 +171,41 @@ def create_lora_manager( ) -> Any: lora_manager = create_lora_manager( model, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.lora_slots: + if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots}).") for lora in loras_map.values(): - self.add_lora(lora) + self.add_adapter(lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id not in self.list_loras(): + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): # Remove before we load the new lora to save memory - if len(self._lora_manager) + 1 > self._lora_manager.capacity: - assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) - self._lora_manager.remove_oldest_lora() - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + lora = self._load_adapter(lora_request) + loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora( + loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None - self._lora_manager.activate_lora(lora_request.lora_int_id) + self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 0000000000000..27a61e692e1b7 --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +@dataclass +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) + + self.indices_gpu: torch.Tensor + self.embedding_indices_gpu: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embeddings_tensors[index] = 0 + + def set_prompt_adapter( + self, + index: int, + adapter_model: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + self.adapter_lengths[index] = length + + def set_mapping( + self, + prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, + ): + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + if self.embedding_indices_gpu.ndim > 1: + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 0000000000000..93eb3bde646ac --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,355 @@ +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Type + +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.layers import ( + VocabParallelEmbeddingWithPromptAdapter) # yapf: disable +from vllm.prompt_adapter.layers import PromptAdapterMapping + +logger = logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + + +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) + + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. + + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. + """ + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } + pa_indices = ([ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + pa_indices = torch.tensor(pa_indices) + return pa_indices, pa_embedding_mapping + + +class PromptAdapterModel(AdapterModel): + + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + num_virtual_tokens: int, + config: PromptAdapterConfig, + device: str = "cuda", + ) -> "PromptAdapterModel": + from peft.utils import load_peft_weights + + if num_virtual_tokens > config.max_prompt_adapter_token: + raise ValueError( + f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' + f'max_prompt_adapter_token({config.max_prompt_adapter_token})') + + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to( + config.prompt_adapter_dtype) + + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + prompt_adapter_config: the PromptAdapter config, + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. + self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([]) + + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) + logger.debug("Activating prompt_adapter. int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for _, v in self.modules.items(): + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) + return True + + def _deactivate_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for _, v in self.modules.items(): + v.reset_prompt_adapter(index) + except ValueError: + pass + + def _add_adapter(self, prompt_adapter: PromptAdapterModel): + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) + for k, v in self.modules.items(): + v.set_mapping(base_indices, base_embedding_indices) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) + break + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager." + "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + + def remove_all_adapters(self): + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], bool]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError( + "Pinning failed. " + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 0000000000000..c0c98cf72bbae --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +from vllm.adapter_commons.request import AdapterRequest + + +@dataclass +class PromptAdapterRequest(AdapterRequest): + """ + Request for a Prompt adapter. + """ + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + + def __hash__(self): + return super().__hash__() + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 0000000000000..ddc1ef893c6f2 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,176 @@ +import logging +from typing import Any, Optional, Set, Type + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest + +logger = logging.getLogger(__name__) + + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._adapter_manager: PromptAdapterModelManager + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._adapter_manager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _load_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + num_virtual_tokens=prompt_adapter_request. + prompt_adapter_num_virtual_tokens, + config=self.prompt_adapter_config, + device=str(self.device), + )) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return True + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_adapter(prompt_adapter) + + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): + # Remove before we load the new prompt_adapter to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + prompt_adapter_request.prompt_adapter_id) is not None + self._adapter_manager.activate_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa0921..a3f998b94d795 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -238,21 +239,25 @@ class Sequence: block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. + prompt_adapter_request: Prompt Adapter request. + """ def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, + self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -287,6 +292,11 @@ def multi_modal_data(self) -> "MultiModalDataDict": def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + def get_output_text_to_return(self, buffer_length: int): # We return the full output text if the sequence is finished. truncate = buffer_length and not self.is_finished() @@ -414,6 +424,7 @@ class SequenceGroup: encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request. """ def __init__( @@ -427,6 +438,7 @@ def __init__( pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -441,6 +453,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers @@ -466,6 +479,16 @@ def multi_modal_data(self) -> "MultiModalDataDict": def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. @@ -624,6 +647,7 @@ class SequenceGroupMetadata: (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. + prompt_adapter_request: Prompt Adapter request. """ def __init__( @@ -642,6 +666,7 @@ def __init__( multi_modal_data: Optional["MultiModalDataDict"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -650,6 +675,7 @@ def __init__( self.block_tables = block_tables self.pooling_params = pooling_params self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @@ -674,6 +700,16 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ + if self.prompt_adapter_request else 0 + @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 6a2cfc819d8d2..90bba96ee8acb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -4,7 +4,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -48,6 +48,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, multimodal_config: Optional[MultiModalConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, ): if return_hidden_states: @@ -66,6 +67,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, multimodal_config=multimodal_config, + prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, ) @@ -136,6 +138,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b4277ae827c02..db0e178e45f4e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model @@ -81,6 +81,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -94,6 +95,7 @@ def __init__( self.cache_config = cache_config self.lora_config = lora_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config self.load_config = load_config self.is_driver_worker = is_driver_worker diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 657505739e236..3c22c73267b7f 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,7 +7,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -133,6 +133,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -145,6 +146,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: @@ -167,6 +169,7 @@ def __init__( lora_config=self.lora_config, multimodal_config=self.multimodal_config, kv_cache_dtype=kv_cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a3b31a1c0ac8a..a333e6634a41f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -5,7 +5,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams @@ -40,6 +40,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, ): super().__init__(model_config, @@ -51,6 +52,7 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config) @torch.inference_mode() @@ -71,6 +73,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbedf3..205b4f58f7a83 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,7 +25,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY @@ -40,6 +40,10 @@ supports_vision) from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, MultiModalInputs) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_mapping: Optional["LoRAMapping"] = None lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None @@ -97,6 +103,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -133,6 +141,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -172,6 +182,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, return_hidden_states: bool = False, ): @@ -183,6 +194,7 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.return_hidden_states = return_hidden_states @@ -232,6 +244,7 @@ def __init__( self.model: nn.Module # Set after load_model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None self.flashinfer_decode_workspace_buffer = None self.flashinfer_decode_wrapper = None @@ -240,16 +253,14 @@ def __init__( def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -274,6 +285,15 @@ def load_model(self) -> None: ) self.model = self.lora_manager.create_lora_manager(self.model) + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated @@ -354,6 +374,9 @@ def _prepare_model_input_tensors( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() seq_lens: List[int] = [] prefill_seq_lens: List[int] = [] @@ -504,6 +527,7 @@ def _prepare_model_input_tensors( input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id if is_prompt: assert len(seq_ids) == 1 @@ -534,6 +558,21 @@ def _prepare_model_input_tensors( mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) + if prompt_adapter_id > 0 and is_prompt: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + is_profile_run = _is_block_tables_empty( seq_group_metadata.block_tables) if is_profile_run: @@ -618,12 +657,11 @@ def _prepare_model_input_tensors( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) - + prompt_adapter_index_mapping.append(0) if self.attn_backend.get_name() == "flashinfer": last_paged_kv_indptr = paged_kv_indptr[-1] paged_kv_indptr.append(last_paged_kv_indptr) paged_kv_last_page_len.append(0) - batch_size = graph_batch_size num_decode_tokens = batch_size @@ -759,6 +797,14 @@ def _prepare_model_input_tensors( else: lora_mapping = None + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, device=self.device) request_ids_to_seq_ids = { @@ -776,7 +822,10 @@ def _prepare_model_input_tensors( lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests, + ) @torch.inference_mode() def profile_run(self) -> None: @@ -878,33 +927,67 @@ def profile_run(self) -> None: def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_loras(lora_requests, lora_mapping) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) + return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) + return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_lora(lora_id) + return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() @torch.inference_mode() def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: @@ -1063,6 +1146,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: ) self.set_active_loras(set(), lora_mapping) + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters( + set(), prompt_adapter_mapping) + graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name()) @@ -1189,6 +1280,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + if self.attn_backend.get_name() == "flashinfer": assert model_input.attn_metadata is not None assert model_input.input_tokens is not None diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 58707269bd68c..857cd86beff92 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -16,6 +17,7 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner @@ -45,6 +47,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: @@ -59,6 +62,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if parallel_config and is_driver_worker: assert rank % parallel_config.tensor_parallel_size == 0, \ @@ -92,6 +96,7 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config, **speculative_args, ) @@ -296,6 +301,19 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_runner.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.remove_lora(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.model_runner.list_prompt_adapters() + @property def max_model_len(self) -> int: return self.model_config.max_model_len diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 03b9cce5ae792..e03f24fdfc41a 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -88,6 +88,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -98,6 +99,7 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.cache_config = cache_config + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 94dfcfec37757..6a822c2ba3e7a 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -10,7 +10,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -47,6 +48,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: assert device_config.device_type == "xpu" @@ -63,6 +65,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0."