Offline inference benchmarking (#44)
* adding tests for offline inference

* remove comment
tstescoTT authored Nov 29, 2024
1 parent f76bd44 commit 2a3261d
Showing 4 changed files with 270 additions and 3 deletions.
10 changes: 10 additions & 0 deletions benchmarking/benchmark_vllm_offline_llama31_70b.sh
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# Benchmark Time To First Token (TTFT) and throughput (TPS, TPOT, ITL)
python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 128 --output_seq_len 128 --greedy_sampling
python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 512 --output_seq_len 512 --greedy_sampling
python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 1024 --output_seq_len 1024 --greedy_sampling
python ~/tests/benchmark_vllm_offline_inference.py --max_seqs_in_batch 32 --input_seq_len 2048 --output_seq_len 2048 --greedy_sampling
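
The same sweep can also be driven from Python by importing run_inference from the benchmark script below; a minimal sketch, assuming ~/tests is on PYTHONPATH and the default prompts JSON path exists (each call builds a fresh engine, mirroring the per-invocation behavior of this script):

# Illustrative Python sweep mirroring the shell commands above (not part of this commit).
# Assumes ~/tests is on PYTHONPATH so benchmark_vllm_offline_inference is importable.
from benchmark_vllm_offline_inference import run_inference

for seq_len in (128, 512, 1024, 2048):
    run_inference(
        "/home/user/vllm/tt_metal/prompts.json",  # default --prompts_json path
        prompt_len=seq_len,
        max_tokens=seq_len,
        greedy_sampling=True,
        max_seqs_in_batch=32,
        async_engine=False,
    )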

219 changes: 219 additions & 0 deletions tests/benchmark_vllm_offline_inference.py
@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import argparse
import time
import uvloop
from tqdm import tqdm

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.utils import merge_async_iterators

# Import raw stats logging utilities
from utils.logging_utils import RawStatLogger

# Import and register model from tt-metal
from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration

ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompts_json",
type=str,
default="/home/user/vllm/tt_metal/prompts.json",
help="Path to JSON file containing prompts",
)
parser.add_argument(
"--input_seq_len",
type=int,
default=128,
help="Length of dummy prompts for performance measurement",
)
parser.add_argument(
"--output_seq_len", type=int, default=128, help="Length of outputs"
)
parser.add_argument(
"--greedy_sampling",
action="store_true",
help="Use greedy decoding instead of top-k/p",
)
parser.add_argument(
"--max_seqs_in_batch",
type=int,
default=32,
help="Maximum batch size for inference",
)
parser.add_argument("--async_engine", action="store_true", help="Use async engine")
return parser.parse_args()


def run_inference(
prompts_json,
max_tokens,
max_seqs_in_batch,
prompt_len,
greedy_sampling, # Option to use greedy decoding instead of top-k/p
async_engine,
):
# LLM args
engine_kw_args = {
"model": "meta-llama/Llama-3.1-70B-Instruct",
"block_size": 64,
"max_num_seqs": max_seqs_in_batch,
"max_model_len": 131072,
"disable_log_stats": False,
"max_num_batched_tokens": 131072,
"log_global_stats": True,
"num_scheduler_steps": 10,
"disable_async_output_proc": True,
}

# Generation args
ignore_eos = True

print("Generating prompts with output length", max_tokens)
if greedy_sampling:
sampling_params = SamplingParams(
max_tokens=max_tokens, ignore_eos=ignore_eos, temperature=0.0
)
else:
sampling_params = SamplingParams(
max_tokens=max_tokens,
ignore_eos=ignore_eos,
top_k=10,
top_p=0.9,
temperature=1.0,
)

# Prepare inputs
assert prompt_len is not None, "prompt_len is required to generate dummy prompts"
print("Measuring performance with dummy prompts of length", prompt_len)
prompt_token_ids = [[0] * prompt_len] * max_seqs_in_batch # dummy prompts
sampling_params = (
sampling_params[:max_seqs_in_batch]
if isinstance(sampling_params, list)
else sampling_params
)

# Check that the prompt length plus generated tokens fits within the model context
max_model_len = engine_kw_args["max_model_len"]
assert_str = f"prompt length ({prompt_len}) + num generated tokens ({sampling_params.max_tokens}) will exceed max_model_len ({max_model_len})"
assert prompt_len + sampling_params.max_tokens <= max_model_len, assert_str

# Create and run LLM
if not async_engine:
llm = LLM(**engine_kw_args)
# Add raw stats logging to the llm engine
llm.llm_engine.stat_loggers["raw_logging"] = RawStatLogger(
engine_kw_args["num_scheduler_steps"],
batch_size=engine_kw_args["max_num_seqs"],
)
run_inference_perf(llm, prompt_token_ids, sampling_params)
else:
print("Using async engine")
engine_args = AsyncEngineArgs(**engine_kw_args)

async def _run_inference_async():
async with build_async_engine_client_from_engine_args(engine_args) as llm:
await run_inference_perf_async(llm, prompt_token_ids, sampling_params)

uvloop.run(_run_inference_async())


def run_inference_perf(
llm: LLM,
prompt_token_ids,
sampling_params,
N_warmup=1,
N_inference=4,
):
"""Run llm N_inference times and measure the average time taken per inference run."""
for i in tqdm(range(N_inference), desc="Inference runs"):
if i == N_warmup:
start_time = time.perf_counter()
generate_tokens(
llm, None, sampling_params, prompt_token_ids, print_output=False
)
avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
print(f"Average time taken per inference run: {avg_time:.2f} s")


async def run_inference_perf_async(
llm: LLM,
prompt_token_ids,
sampling_params,
N_warmup=1,
N_inference=4,
):
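    """Async variant: run llm N_inference times and measure the average time taken per run."""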
for i in tqdm(range(N_inference), desc="Inference runs"):
if i == N_warmup:
start_time = time.perf_counter()
await generate_tokens_async(
llm, None, sampling_params, prompt_token_ids, print_output=False
)
avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
print(f"Average time taken per inference run: {avg_time:.2f} s")


def generate_tokens(
llm: LLM, prompts, sampling_params, prompt_token_ids=None, print_output=True
):
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params, prompt_token_ids)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
if print_output:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


async def generate_tokens_async(
llm: MQLLMEngineClient,
prompts,
sampling_params,
prompt_token_ids=None,
print_output=True,
):
# Use tokenized prompts if provided
if prompt_token_ids is not None:
prompts = []
for single_prompt_token_ids in prompt_token_ids:
prompts.append(TokensPrompt(prompt_token_ids=single_prompt_token_ids))

if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)

generators = []
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
prompt = res.prompt
generated_text = res.outputs[0].text
if print_output and res.finished:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
args = parse_args()
run_inference(
args.prompts_json,
prompt_len=args.input_seq_len,
max_tokens=args.output_seq_len,
greedy_sampling=args.greedy_sampling,
max_seqs_in_batch=args.max_seqs_in_batch,
async_engine=args.async_engine,
)
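
run_inference_perf above prints only the average wall time per batched run; aggregate TPS and a rough TPOT can be estimated afterwards from the known batch size and output length. A post-processing sketch with illustrative names (it ignores prefill/TTFT, so TPOT is slightly overestimated):

def derive_throughput(avg_time_s: float, batch_size: int, output_seq_len: int):
    # Illustrative post-processing of the printed average time; not part of this commit.
    total_output_tokens = batch_size * output_seq_len
    tps = total_output_tokens / avg_time_s            # aggregate output tokens per second
    tpot_ms = (avg_time_s / output_seq_len) * 1000.0  # ms per output token, per sequence
    return tps, tpot_ms

# Example: a 12.3 s average run at batch 32 with 128 output tokens per sequence
tps, tpot_ms = derive_throughput(12.3, batch_size=32, output_seq_len=128)
print(f"TPS: {tps:.1f} tok/s, TPOT: {tpot_ms:.1f} ms")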
33 changes: 33 additions & 0 deletions tests/mock_benchmark_vllm_offline_inference.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

from unittest.mock import patch

from vllm import ModelRegistry

# Import mocking utilities and the classes to be patched
from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_enginer
from vllm.worker.tt_worker import TTCacheEngine, TTWorker
from benchmark_vllm_offline_inference import run_inference, parse_args

ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)


@patch.object(TTWorker, "init_device", new=lambda x: None)
@patch.object(TTWorker, "_init_cache_engine", new=new_init_cache_enginer)
@patch.object(TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache)
def mock_run_inference(*args, **kwargs):
run_inference(*args, **kwargs)


if __name__ == "__main__":
args = parse_args()
mock_run_inference(
args.prompts_json,
prompt_len=args.input_seq_len,
max_tokens=args.output_seq_len,
greedy_sampling=args.greedy_sampling,
max_seqs_in_batch=args.max_seqs_in_batch,
async_engine=args.async_engine,
)
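
The patched wrapper can also be invoked programmatically rather than via the CLI; a hypothetical snippet, assuming tests/ is on PYTHONPATH (argument values are illustrative):

# Hypothetical programmatic use of the mocked benchmark; no TT device is required
# because mock_run_inference applies the TTWorker/TTCacheEngine patches.
from mock_benchmark_vllm_offline_inference import mock_run_inference

mock_run_inference(
    "/home/user/vllm/tt_metal/prompts.json",
    prompt_len=128,
    max_tokens=128,
    greedy_sampling=True,
    max_seqs_in_batch=32,
    async_engine=False,
)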
11 changes: 8 additions & 3 deletions tests/mock_vllm_offline_inference_tt.py
@@ -1,15 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import argparse
import json
import time
import uvloop
from unittest.mock import patch

# TODO: import logging_init_wrapper from vllm-tt-metal-llama3-70b/src/logging_utils.py after refactoring
from mock_vllm_model import (
MockModel,
new_allocate_kv_cache,
new_init_cache_enginer,
logging_init_wrapper,
)
from tqdm import tqdm
from vllm import LLM, ModelRegistry, SamplingParams
@@ -23,6 +25,9 @@
from vllm.worker.tt_worker import TTCacheEngine, TTWorker
from vllm.engine.llm_engine import LLMEngine

from utils.logging_utils import logging_init_wrapper


ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)


@@ -46,7 +51,7 @@ def run_inference(
):
# LLM args
engine_kw_args = {
"model": "meta-llama/Meta-Llama-3.1-70B",
"model": "meta-llama/Llama-3.1-70B-Instruct",
"block_size": 64,
"max_num_seqs": max_seqs_in_batch,
"max_model_len": 131072,
