diff --git a/benchmarking/benchmark_vllm_offline_llama31_70b.sh b/benchmarking/benchmark_vllm_offline_llama31_70b.sh
new file mode 100644
index 0000000..36454dc
--- /dev/null
+++ b/benchmarking/benchmark_vllm_offline_llama31_70b.sh
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# for Time To First Token (TTFT) and throughput (TPS, TPOT, ITL)
+python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 128 --output_seq_len 128 --greedy_sampling
+python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 512 --output_seq_len 512 --greedy_sampling
+python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 1024 --output_seq_len 1024 --greedy_sampling
+python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 2048 --output_seq_len 2048 --greedy_sampling
+
diff --git a/tests/benchmark_vllm_offline_inference.py b/tests/benchmark_vllm_offline_inference.py
new file mode 100644
index 0000000..73ef98f
--- /dev/null
+++ b/tests/benchmark_vllm_offline_inference.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+import argparse
+import time
+import uvloop
+from tqdm import tqdm
+
+from vllm import LLM, ModelRegistry, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args,
+)
+from vllm.inputs.data import TokensPrompt
+from vllm.utils import merge_async_iterators
+
+# importing logging utils
+from utils.logging_utils import RawStatLogger
+
+# Import and register model from tt-metal
+from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
+
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--prompts_json",
+        type=str,
+        default="/home/user/vllm/tt_metal/prompts.json",
+        help="Path to JSON file containing prompts",
+    )
+    parser.add_argument(
+        "--input_seq_len",
+        type=int,
+        default=128,
+        help="Length of dummy prompts for performance measurement",
+    )
+    parser.add_argument(
+        "--output_seq_len", type=int, default=128, help="Length of outputs"
+    )
+    parser.add_argument(
+        "--greedy_sampling",
+        action="store_true",
+        help="Use greedy decoding instead of top-k/p",
+    )
+    parser.add_argument(
+        "--max_seqs_in_batch",
+        type=int,
+        default=32,
+        help="Maximum batch size for inference",
+    )
+    parser.add_argument("--async_engine", action="store_true", help="Use async engine")
+    return parser.parse_args()
+
+
+def run_inference(
+    prompts_json,
+    max_tokens,
+    max_seqs_in_batch,
+    prompt_len,
+    greedy_sampling,  # Option to use greedy decoding instead of top-k/p
+    async_engine,
+):
+    # LLM args
+    engine_kw_args = {
+        "model": "meta-llama/Llama-3.1-70B-Instruct",
+        "block_size": 64,
+        "max_num_seqs": max_seqs_in_batch,
+        "max_model_len": 131072,
+        "disable_log_stats": False,
+        "max_num_batched_tokens": 131072,
+        "log_global_stats": True,
+        "num_scheduler_steps": 10,
+        "disable_async_output_proc": True,
+    }
+
+    # Generation args
+    ignore_eos = True
+
+    print("Generating prompts with output length", max_tokens)
+    if greedy_sampling:
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens, ignore_eos=ignore_eos, temperature=0.0
+        )
+    else:
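+        # Stochastic decoding: sample from the 10 most likely tokens, further
+        # restricted to the smallest set with cumulative probability >= 0.9.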
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens,
+            ignore_eos=ignore_eos,
+            top_k=10,
+            top_p=0.9,
+            temperature=1.0,
+        )
+
+    # Prepare inputs
+    assert prompt_len is not None, "prompt_len is required to generate dummy prompts"
+    print("Measuring performance with dummy prompts of length", prompt_len)
+    prompt_token_ids = [[0] * prompt_len] * max_seqs_in_batch  # dummy prompts
+    sampling_params = (
+        sampling_params[:max_seqs_in_batch]
+        if isinstance(sampling_params, list)
+        else sampling_params
+    )
+
+    # check prompt lengths fit in model context
+    max_model_len = engine_kw_args["max_model_len"]
+    assert_str = f"prompt length ({prompt_len}) + num generated tokens ({sampling_params.max_tokens}) will exceed max_model_len ({max_model_len})"
+    assert prompt_len + sampling_params.max_tokens <= max_model_len, assert_str
+
+    # Create and run LLM
+    if not async_engine:
+        llm = LLM(**engine_kw_args)
+        # Add raw stats logging to the llm engine
+        llm.llm_engine.stat_loggers["raw_logging"] = RawStatLogger(
+            engine_kw_args["num_scheduler_steps"],
+            batch_size=engine_kw_args["max_num_seqs"],
+        )
+        run_inference_perf(llm, prompt_token_ids, sampling_params)
+    else:
+        print("Using async engine")
+        engine_args = AsyncEngineArgs(**engine_kw_args)
+
+        async def _run_inference_async():
+            async with build_async_engine_client_from_engine_args(engine_args) as llm:
+                await run_inference_perf_async(llm, prompt_token_ids, sampling_params)
+
+        uvloop.run(_run_inference_async())
+
+
+def run_inference_perf(
+    llm: LLM,
+    prompt_token_ids,
+    sampling_params,
+    N_warmup=1,
+    N_inference=4,
+):
+    """Run llm N_inference times and measure the average time taken per inference run."""
+    for i in tqdm(range(N_inference), desc="Inference runs"):
+        if i == N_warmup:
+            start_time = time.perf_counter()
+        generate_tokens(
+            llm, None, sampling_params, prompt_token_ids, print_output=False
+        )
+    avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
+    print(f"Average time taken per inference run: {avg_time:.2f} s")
+
+
+async def run_inference_perf_async(
+    llm: LLM,
+    prompt_token_ids,
+    sampling_params,
+    N_warmup=1,
+    N_inference=4,
+):
+    for i in tqdm(range(N_inference), desc="Inference runs"):
+        if i == N_warmup:
+            start_time = time.perf_counter()
+        await generate_tokens_async(
+            llm, None, sampling_params, prompt_token_ids, print_output=False
+        )
+    avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
+    print(f"Average time taken per inference run: {avg_time:.2f} s")
+
+
+def generate_tokens(
+    llm: LLM, prompts, sampling_params, prompt_token_ids=None, print_output=True
+):
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params, prompt_token_ids)
+    # Print the outputs.
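+    # Each RequestOutput pairs the original prompt with its generated
+    # completions; only the first completion per request is used here.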
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        if print_output:
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+async def generate_tokens_async(
+    llm: MQLLMEngineClient,
+    prompts,
+    sampling_params,
+    prompt_token_ids=None,
+    print_output=True,
+):
+    # Use tokenized prompts if provided
+    if prompt_token_ids is not None:
+        prompts = []
+        for single_prompt_token_ids in prompt_token_ids:
+            prompts.append(TokensPrompt(prompt_token_ids=single_prompt_token_ids))
+
+    if not isinstance(sampling_params, list):
+        sampling_params = [sampling_params] * len(prompts)
+
+    generators = []
+    for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+        generator = llm.generate(prompt, sp, request_id=f"test{i}")
+        generators.append(generator)
+    all_gens = merge_async_iterators(*generators)
+    async for i, res in all_gens:
+        prompt = res.prompt
+        generated_text = res.outputs[0].text
+        if print_output and res.finished:
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_inference(
+        args.prompts_json,
+        prompt_len=args.input_seq_len,
+        max_tokens=args.output_seq_len,
+        greedy_sampling=args.greedy_sampling,
+        max_seqs_in_batch=args.max_seqs_in_batch,
+        async_engine=args.async_engine,
+    )
diff --git a/tests/mock_benchmark_vllm_offline_inference.py b/tests/mock_benchmark_vllm_offline_inference.py
new file mode 100644
index 0000000..2cb3c10
--- /dev/null
+++ b/tests/mock_benchmark_vllm_offline_inference.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+from unittest.mock import patch
+
+from vllm import ModelRegistry
+
+# import mocking utils + classes to mock
+from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer
+from vllm.worker.tt_worker import TTCacheEngine, TTWorker
+from benchmark_vllm_offline_inference import run_inference, parse_args
+
+ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
+
+
+@patch.object(TTWorker, "init_device", new=lambda x: None)
+@patch.object(TTWorker, "_init_cache_engine", new=new_init_cache_enginer)
+@patch.object(TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache)
+def mock_run_inference(*args, **kwargs):
+    run_inference(*args, **kwargs)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    mock_run_inference(
+        args.prompts_json,
+        prompt_len=args.input_seq_len,
+        max_tokens=args.output_seq_len,
+        greedy_sampling=args.greedy_sampling,
+        max_seqs_in_batch=args.max_seqs_in_batch,
+        async_engine=args.async_engine,
+    )
diff --git a/tests/mock_vllm_offline_inference_tt.py b/tests/mock_vllm_offline_inference_tt.py
index f874d2d..71c3de9 100644
--- a/tests/mock_vllm_offline_inference_tt.py
+++ b/tests/mock_vllm_offline_inference_tt.py
@@ -1,15 +1,17 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
 import argparse
 import json
 import time
 import uvloop
 from unittest.mock import patch
 
-# TODO: import logging_init_wrapper from vllm-tt-metal-llama3-70b/src/logging_utils.py after refactoring
 from mock_vllm_model import (
     MockModel,
     new_allocate_kv_cache,
     new_init_cache_enginer,
-    logging_init_wrapper,
 )
 from tqdm import tqdm
 from vllm import LLM, ModelRegistry, SamplingParams
@@ -23,6 +25,11 @@
 from vllm.worker.tt_worker import TTCacheEngine, TTWorker
 from vllm.engine.llm_engine import LLMEngine
 
+from utils.logging_utils import logging_init_wrapper
+
+
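+# Register the mock model in place of the real TT Llama implementation so this
+# flow can run without Tenstorrent hardware.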
 ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
 
 
@@ -46,7 +53,7 @@ def run_inference(
 ):
     # LLM args
     engine_kw_args = {
-        "model": "meta-llama/Meta-Llama-3.1-70B",
+        "model": "meta-llama/Llama-3.1-70B-Instruct",
         "block_size": 64,
         "max_num_seqs": max_seqs_in_batch,
         "max_model_len": 131072,