From 1b6de8352b878348974b3f117cbb68ed18daa609 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 17 Sep 2024 15:34:27 +0800 Subject: [PATCH] [Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495) --- benchmarks/backend_request_func.py | 6 +- benchmarks/benchmark_serving.py | 239 +++++++++++++++++++++-------- 2 files changed, 177 insertions(+), 68 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 3243bb94f787c..3def4a6d67acf 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -25,6 +25,7 @@ class RequestFuncInput: best_of: int = 1 use_beam_search: bool = False logprobs: Optional[int] = None + multi_modal_content: Optional[dict] = None @dataclass @@ -312,12 +313,15 @@ async def async_request_openai_chat_completions( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { "model": request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9ba3f649810b7..3ace910a6cac6 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -24,6 +24,8 @@ """ import argparse import asyncio +import base64 +import io import json import os import random @@ -31,11 +33,13 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) +from datasets import load_dataset +from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -84,7 +88,7 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[str, int, int, None]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Load the dataset. @@ -119,7 +123,7 @@ def sample_sharegpt_requests( if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompt, prompt_len, output_len, None)) return filtered_dataset @@ -131,7 +135,7 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: +) -> List[Tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." @@ -189,7 +193,65 @@ def sample_sonnet_requests( message, add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt_formatted).input_ids) sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) + (prompt, prompt_formatted, prompt_len, output_len, None)) + + return sampled_requests + + +def sample_hf_requests( + dataset_path: str, + dataset_subset: str, + dataset_split: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "conversations" in dataset.features, ( + "HF Dataset must have 'conversations' column.") + filtered_dataset = dataset.shuffle().filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in filtered_dataset: + if len(sampled_requests) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = data["conversations"][0]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion = data["conversations"][1]["value"] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + if "image" in data and isinstance(data["image"], Image): + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + else: + mm_content = None + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) return sampled_requests @@ -223,8 +285,8 @@ def sample_random_requests( [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) return input_requests @@ -343,7 +405,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -353,6 +420,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -373,6 +441,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -385,7 +454,7 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request + prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput( model=model_id, prompt=prompt, @@ -395,6 +464,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=mm_content, ) tasks.append( asyncio.create_task( @@ -575,6 +645,16 @@ def main(args: argparse.Namespace): for prompt, prompt_formatted, prompt_len, output_len in input_requests] + elif args.dataset_name == "hf": + input_requests = sample_hf_requests( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.hf_output_len, + ) + elif args.dataset_name == "random": input_requests = sample_random_requests( prefix_len=args.random_prefix_len, @@ -685,13 +765,14 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--model", type=str, @@ -718,26 +799,6 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) - parser.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) parser.add_argument( "--logprobs", type=int, @@ -748,42 +809,6 @@ def main(args: argparse.Namespace): "logprob is returned for each token; or (2) if beam search " "is enabled 1 logprob per token is computed"), ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", - ) - parser.add_argument( - "--random-prefix-len", - type=int, - default=0, - help="Number of fixed prefix tokens before random " - " context. The length range of context in a random " - " request is [random-prefix-len, " - " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, @@ -857,5 +882,85 @@ def main(args: argparse.Namespace): "Use \"--percentile-metrics\" to select metrics.", ) + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + args = parser.parse_args() main(args)