From c7ff25792f01028d1d1dba28c4ff37b329f238d9 Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Fri, 23 Aug 2024 17:09:34 +0800
Subject: [PATCH] Custom benchmark with parameters (#88)

* Custom benchmark with parameters

* Mention arguments for benchmark.py

* Tweak

---
 README.md               |  3 +++
 examples/benchmark.py   | 48 +++++++++++++++++++++--------------------
 src/openai/streaming.rs |  2 --
 3 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index f51b896..3cb96fb 100644
--- a/README.md
+++ b/README.md
@@ -134,6 +134,9 @@ After the `candle-vllm` service is running, run the Python script and enjoy effi
 
 ## Batched requests
 
+``` shell
+python3 examples/benchmark.py --batch 16 --max_tokens 1024
+```
 Refer to `examples/benchmark.py`
 
 ``` python
diff --git a/examples/benchmark.py b/examples/benchmark.py
index 1595b02..d5df23b 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -3,15 +3,26 @@ from openai import Stream
 from openai.types.chat import ChatCompletionChunk
 from typing import List
 
-# Run: cargo run --release -- --port 2000 --model-id --repeat-last-n 64
+import argparse
+# Run candle-vllm service: cargo run --release -- --port 2000 --model-id --repeat-last-n 64
 # MODEL_ID is the huggingface model id or local weight path
 # MODEL_TYPE is one of ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]
-
+# Then run this file: python3 examples/benchmark.py --batch 16
 openai.api_key = "EMPTY"
 
 openai.base_url = "http://localhost:2000/v1/"
 
+# You may add your custom prompts here
+PROMPT_CANDIDATES = ["Explain how to best learn Rust.",
+                     "Please talk about deep learning.",
+                     "Do you know the capital city of China? Talk the details of you known.",
+                     "Who is the best female actor in the world? Explain why.",
+                     "Let me know how to deal with depression?",
+                     "How to make money in short time?",
+                     "What is the future trend of large language model?",
+                     "The famous tech companies in the world."]
+
 
 async def chat_completion(model, max_tokens, prompt):
     completion = openai.chat.completions.create(
         model=model,
@@ -34,26 +45,12 @@ async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]):
         result += r
     return (response_idx, result)
 
-async def benchmark():
-    model = "mistral7b"
-    max_tokens = 1024
-    # 16 requests
-    prompts = ["Explain how to best learn Rust.",
-               "Please talk about deep learning.",
-               "Do you know the capital city of China? Talk the details of you known.",
-               "Who is the best female actor in the world? Explain why.",
-               "Let me know how to deal with depression?",
-               "How to make money in short time?",
-               "What is the future trend of large language model?",
-               "The famous tech companies in the world.",
-               "Explain how to best learn Rust.",
-               "Please talk about deep learning.",
-               "Do you know the capital city of China? Talk the details of you known.",
-               "Who is the best female actor in the world? Explain why.",
-               "Let me know how to deal with depression?",
-               "How to make money in short time?",
-               "What is the future trend of large language model?",
-               "The famous tech companies in the world."]
+async def benchmark(batch, max_tokens=1024):
+    model = "any"  # the actual model is determined by the server side
+    # build the batch from the candidate prompts
+    prompts = []
+    for i in range(batch):
+        prompts.append(PROMPT_CANDIDATES[i % len(PROMPT_CANDIDATES)])
 
     # avoid generating very short answers
     for i in range(len(prompts)):
@@ -86,4 +83,9 @@ async def benchmark():
 
         print("\n\n Response {}: \n\n {}".format(idx, output))
 
-asyncio.run(benchmark())
\ No newline at end of file
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Benchmark candle-vllm with the 'batch' and 'max_tokens' parameters.")
+    parser.add_argument('--batch', default=16, type=int)
+    parser.add_argument('--max_tokens', default=1024, type=int)
+    args = parser.parse_args()
+    asyncio.run(benchmark(args.batch, args.max_tokens))
\ No newline at end of file
diff --git a/src/openai/streaming.rs b/src/openai/streaming.rs
index 455dab1..a88d019 100644
--- a/src/openai/streaming.rs
+++ b/src/openai/streaming.rs
@@ -50,11 +50,9 @@ impl Stream for Streamer {
                     Poll::Ready(Some(Ok(Event::default().data("[DONE]"))))
                 }
             },
-
             Err(e) => {
                 if self.status == StreamingStatus::Started
                     && e == flume::TryRecvError::Disconnected {
-                    //no TryRecvError::Disconnected returned even if the client closed the stream or disconnected
                     self.status = StreamingStatus::Interrupted;
                     Poll::Ready(None)
                 } else {