Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Asynchronous Output Processor #7049

Merged
merged 21 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
Expand All @@ -110,6 +111,7 @@ def run_vllm(
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)

# Add the requests to the engine.
Expand Down Expand Up @@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format)
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
Expand Down Expand Up @@ -418,6 +421,11 @@ def main(args: argparse.Namespace):
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
Expand Down
6 changes: 6 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
Expand All @@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks log probs match between chunked-prefill and
Expand Down Expand Up @@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
Expand All @@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
Expand Down
1 change: 0 additions & 1 deletion tests/basic_correctness/test_preemption.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]

with vllm_runner(
model,
dtype=dtype,
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_chunked_prefill_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
max_model_len = 2
max_model_len = 8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We needed to add a detection into the scheduler of when the sequence reaches max_model_len, since this is the only case of stopping that needs to be detected immediately (and cannot have the extra +1 decode run). As a result, this test started to fail because it was limited to tiny length, where the actual length used is larger.

max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
Expand Down
2 changes: 1 addition & 1 deletion tests/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
Expand Down
2 changes: 2 additions & 0 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
str(TP_SIZE),
"--distributed-executor-backend",
DIST_BACKEND,
# disable async output processor to test PP
"--disable-async-output-proc",
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
]

# compare without pipeline parallelism
Expand Down
2 changes: 2 additions & 0 deletions tests/distributed/test_pp_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
str(PP_SIZE),
"--distributed-executor-backend",
"mp",
# disable output proc callback to test PP
"--disable-async-output-proc",
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
]
os.environ["VLLM_ATTENTION_BACKEND"] = ATTN_BACKEND

Expand Down
155 changes: 103 additions & 52 deletions tests/engine/test_stop_strings.py
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -7,106 +7,157 @@
MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200

IS_ASYNC = False


@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
with vllm_runner(MODEL) as vllm_model:
yield vllm_model


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False,
use_async_output_proc: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)

output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None

if use_async_output_proc:
llm_engine.step()

while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs

# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason

assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason


def _set_async_mode(llm_engine, is_async):
llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
expected_reason=".",
use_async_output_proc=is_async)

_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
expected_reason=".",
use_async_output_proc=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
def _stop_multi_tokens(llm_engine, is_async):
_test_stopping(
vllm_model.model.llm_engine,
llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
expected_reason="group of peo",
use_async_output_proc=is_async)

_test_stopping(
vllm_model.model.llm_engine,
llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
expected_reason="group of peo",
use_async_output_proc=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
def _stop_partial_token(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
expected_reason="gani",
use_async_output_proc=is_async)

_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
expected_reason="gani",
use_async_output_proc=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
def _stop_token_id(llm_engine, is_async):
# token id 13013 => " organization"

_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
expected_reason=13013,
use_async_output_proc=is_async)

_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
expected_reason=13013,
use_async_output_proc=is_async)


def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_basic(vllm_model.model.llm_engine, is_async=True)

output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_basic(vllm_model.model.llm_engine, is_async=False)

# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason

assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)

_set_async_mode(vllm_model.model.llm_engine, False)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)

_set_async_mode(vllm_model.model.llm_engine, False)
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_token_id(vllm_model.model.llm_engine, is_async=True)

_set_async_mode(vllm_model.model.llm_engine, False)
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
3 changes: 3 additions & 0 deletions tests/multi_step/test_correctness_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager")

Expand Down
Loading
Loading