Skip to content

Commit

Permalink
Enable Random Prefix Caching in Serving Profiling Tool (benchmark_ser…
Browse files Browse the repository at this point in the history
…ving.py) (vllm-project#8241)
  • Loading branch information
wschin authored and dtrifiro committed Sep 12, 2024
1 parent 15b9658 commit b17d45c
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,16 @@ def sample_sonnet_requests(


def sample_random_requests(
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
prefix_len: int,
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
prefix_token_ids = np.random.randint(0,
tokenizer.vocab_size,
size=prefix_len).tolist()

input_lens = np.random.randint(
int(input_len * range_ratio),
Expand All @@ -211,10 +219,12 @@ def sample_random_requests(
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
prompt = tokenizer.decode(prefix_token_ids +
[(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])

input_requests.append(
(prompt, int(input_lens[i]), int(output_lens[i])))
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))

return input_requests

Expand Down Expand Up @@ -567,6 +577,7 @@ def main(args: argparse.Namespace):

elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
Expand Down Expand Up @@ -765,6 +776,14 @@ def main(args: argparse.Namespace):
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
parser.add_argument(
"--request-rate",
type=float,
Expand Down

0 comments on commit b17d45c

Please sign in to comment.