Skip to content

Commit

Permalink
[Benchmark] Support sample from HF datasets and image input for bench…
Browse files Browse the repository at this point in the history
…mark_serving (#8495)
  • Loading branch information
Isotr0py authored Sep 17, 2024
1 parent cbdb252 commit 1b6de83
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 68 deletions.
6 changes: 5 additions & 1 deletion benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class RequestFuncInput:
best_of: int = 1
use_beam_search: bool = False
logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None


@dataclass
Expand Down Expand Up @@ -312,12 +313,15 @@ async def async_request_openai_chat_completions(

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": request_func_input.model,
"messages": [
{
"role": "user",
"content": request_func_input.prompt,
"content": content
},
],
"temperature": 0.0,
Expand Down
239 changes: 172 additions & 67 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@
"""
import argparse
import asyncio
import base64
import io
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from datasets import load_dataset
from PIL.Image import Image
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -84,7 +88,7 @@ def sample_sharegpt_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
) -> List[Tuple[str, int, int, None]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
Expand Down Expand Up @@ -119,7 +123,7 @@ def sample_sharegpt_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append((prompt, prompt_len, output_len, None))

return filtered_dataset

Expand All @@ -131,7 +135,7 @@ def sample_sonnet_requests(
output_len: int,
prefix_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
) -> List[Tuple[str, str, int, int, None]]:
assert (
input_len > prefix_len
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
Expand Down Expand Up @@ -189,7 +193,65 @@ def sample_sonnet_requests(
message, add_generation_prompt=True, tokenize=False)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
sampled_requests.append(
(prompt, prompt_formatted, prompt_len, output_len))
(prompt, prompt_formatted, prompt_len, output_len, None))

return sampled_requests


def sample_hf_requests(
dataset_path: str,
dataset_subset: str,
dataset_split: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
assert "conversations" in dataset.features, (
"HF Dataset must have 'conversations' column.")
filtered_dataset = dataset.shuffle().filter(
lambda x: len(x["conversations"]) >= 2)
sampled_requests: List[Tuple[str, int, int, Dict[str,
Collection[str]]]] = []
for data in filtered_dataset:
if len(sampled_requests) == num_requests:
break

# Tokenize the prompts and completions.
prompt = data["conversations"][0]["value"]
prompt_token_ids = tokenizer(prompt).input_ids
completion = data["conversations"][1]["value"]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue

if "image" in data and isinstance(data["image"], Image):
image: Image = data["image"]
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
else:
mm_content = None

sampled_requests.append((prompt, prompt_len, output_len, mm_content))

return sampled_requests

Expand Down Expand Up @@ -223,8 +285,8 @@ def sample_random_requests(
[(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])

input_requests.append(
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
input_requests.append((prompt, int(prefix_len + input_lens[i]),
int(output_lens[i]), None))

return input_requests

Expand Down Expand Up @@ -343,7 +405,12 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")

print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0])
if backend != "openai-chat" and test_mm_content is not None:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise ValueError(
"Multi-modal content is only supported on 'openai-chat' backend.")
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
Expand All @@ -353,6 +420,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
Expand All @@ -373,6 +441,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
Expand All @@ -385,7 +454,7 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
Expand All @@ -395,6 +464,7 @@ async def benchmark(
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=mm_content,
)
tasks.append(
asyncio.create_task(
Expand Down Expand Up @@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]

elif args.dataset_name == "hf":
input_requests = sample_hf_requests(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
dataset_split=args.hf_split,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.hf_output_len,
)

elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
Expand Down Expand Up @@ -685,13 +765,14 @@ def main(args: argparse.Namespace):
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "sonnet", "random"],
choices=["sharegpt", "sonnet", "random", "hf"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
parser.add_argument(
"--model",
type=str,
Expand All @@ -718,26 +799,6 @@ def main(args: argparse.Namespace):
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--logprobs",
type=int,
Expand All @@ -748,42 +809,6 @@ def main(args: argparse.Namespace):
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
)
parser.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
parser.add_argument(
"--request-rate",
type=float,
Expand Down Expand Up @@ -857,5 +882,85 @@ def main(args: argparse.Namespace):
"Use \"--percentile-metrics\" to select metrics.",
)

# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
sonnet_group.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
sonnet_group.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)

sharegpt_group = parser.add_argument_group("sharegpt dataset options")
sharegpt_group.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")

random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
random_group.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
random_group.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")

hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)

args = parser.parse_args()
main(args)

0 comments on commit 1b6de83

Please sign in to comment.