[Core] Optimize SPMD architecture with delta + serialization optimiza…
rkooo567 authored Aug 19, 2024
1 parent 200a2ff commit ff7ec82
Showing 36 changed files with 727 additions and 351 deletions.
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
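
msgspec is added because the SPMD worker path now ships scheduler state to workers as msgpack rather than pickled Python objects. As a minimal sketch of the library's core encode/decode pattern (the Job struct and its fields are illustrative only, not vLLM types):

import msgspec
from typing import List


class Job(msgspec.Struct):
    # Illustrative struct; the commit serializes ExecuteModelRequest instead.
    request_id: str
    token_ids: List[int]


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(Job)

job = Job(request_id="req-0", token_ids=[1, 2, 3])
assert decoder.decode(encoder.encode(job)) == job
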
18 changes: 18 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -8,6 +8,7 @@
import pytest
from prometheus_client import REGISTRY

import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)
@@ -24,6 +25,13 @@
"tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
# When the SPMD worker is used, set worker_use_ray=True
# to test that the delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""By default, recompute preemption is enabled"""

@@ -89,6 +100,7 @@
model,
dtype=dtype,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
@@ -144,6 +157,7 @@
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
@@ -204,6 +219,7 @@
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
Expand All @@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
33 changes: 33 additions & 0 deletions tests/core/test_serialization.py
@@ -0,0 +1,33 @@
import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest

from ..spec_decode.utils import create_batch


def test_msgspec_serialization():
num_lookahead_slots = 4
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=4)

encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
req = decoder.decode(encoder.encode(execute_model_req))
expected = execute_model_req.seq_group_metadata_list
actual = req.seq_group_metadata_list
assert (len(expected) == len(actual))
expected = expected[0]
actual = actual[0]

assert expected.block_tables == actual.block_tables
assert expected.is_prompt == actual.is_prompt
assert expected.request_id == actual.request_id
assert (expected.seq_data[0].prompt_token_ids ==
actual.seq_data[0].prompt_token_ids)
assert (expected.seq_data[0].output_token_ids ==
actual.seq_data[0].output_token_ids)
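
The test above exercises msgspec's enc_hook / dec_hook extension points, which the commit uses (via vllm/executor/msgspec_utils.py) so that array-backed token IDs survive msgpack serialization. A self-contained sketch of that hook pattern, with a hypothetical DummyRequest struct standing in for ExecuteModelRequest and an assumed "l" array typecode:

from array import array
from typing import Any, Type

import msgspec

TOKEN_ID_ARRAY_TYPE = "l"  # assumption: signed-long typecode for token IDs


class DummyRequest(msgspec.Struct):
    # Hypothetical stand-in for ExecuteModelRequest.
    token_ids: array


def encode_hook(obj: Any) -> Any:
    # msgspec calls this for types it cannot serialize natively;
    # flatten the typed array into raw bytes.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"Cannot encode objects of type {type(obj)}")


def decode_hook(type: Type, obj: Any) -> Any:
    # Called when decoding into a non-native annotated type;
    # rebuild the typed array from the bytes produced by encode_hook.
    if type is array:
        arr = array(TOKEN_ID_ARRAY_TYPE)
        arr.frombytes(obj)
        return arr
    raise NotImplementedError(f"Cannot decode objects of type {type}")


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(DummyRequest, dec_hook=decode_hook)

request = DummyRequest(token_ids=array(TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
assert decoder.decode(encoder.encode(request)).token_ids == request.token_ids
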
3 changes: 2 additions & 1 deletion tests/distributed/test_basic_distributed_correctness.py
@@ -22,7 +22,8 @@
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, test_suite", [
"model, distributed_executor_backend, attention_backend, "
"test_suite", [
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
7 changes: 7 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -6,6 +6,8 @@
```
"""

import os

import pytest

from vllm.utils import cuda_device_count_stateless
@@ -30,6 +32,11 @@ def test_models(
model: str,
distributed_executor_backend: str,
) -> None:
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa
assert distributed_executor_backend == "ray"
# Test the Ray accelerated DAG (ADAG) execution path.
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"

dtype = "half"
max_tokens = 5
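
The two environment variables above gate the Ray SPMD worker and Ray compiled-DAG execution paths. Outside of the test they would typically be set before the engine is constructed; a hypothetical usage sketch (the engine arguments here are assumptions, not part of this diff):

import os

from vllm import LLM

# Assumed usage: enable the SPMD worker and compiled DAG before building the engine.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="ray")
outputs = llm.generate("San Francisco is a")
print(outputs[0].outputs[0].text)
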
25 changes: 19 additions & 6 deletions tests/samplers/test_sampler.py
@@ -1,5 +1,6 @@
import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch

@@ -10,7 +11,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available


@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -201,7 +205,8 @@ def create_sampling_params(min_tokens,

def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
array(VLLM_TOKEN_ID_ARRAY_TYPE,
random.choices(range(0, VOCAB_SIZE), k=num_input)))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
@@ -650,7 +659,11 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0:
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
[1, 2, 3]))
},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
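
The recurring change in this file (and in the test files below) is that SequenceData no longer takes a plain Python list of token IDs; it now takes a typed array built with VLLM_TOKEN_ID_ARRAY_TYPE, which serializes more compactly through msgspec. A small sketch of the updated construction pattern, with illustrative token IDs:

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Before this commit: SequenceData([1, 2, 3])
# After this commit: token IDs live in a compact typed array.
prompt_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])
seq_data = SequenceData(prompt_token_ids)
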
9 changes: 6 additions & 3 deletions tests/spec_decode/utils.py
@@ -1,3 +1,4 @@
from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -9,7 +10,8 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data={
i:
SequenceData(
prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
cont_token_ids[:]),
),
},
sampling_params=SamplingParams(temperature=0.0, ),
8 changes: 6 additions & 2 deletions tests/test_logits_processor.py
@@ -1,4 +1,5 @@
import random
from array import array
from typing import Tuple
from unittest.mock import patch

@@ -8,7 +9,8 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available


@@ -69,7 +71,9 @@ def pick_ith(token_ids, logits):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
7 changes: 5 additions & 2 deletions tests/test_sequence.py
@@ -1,6 +1,9 @@
from array import array

import pytest

from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)

from .core.utils import create_dummy_prompt
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):


def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
16 changes: 11 additions & 5 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,10 +1,12 @@
from array import array
from typing import List

import pytest
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner

@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
(Diff for the remaining changed files is not shown.)