forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (vllm-project#4696)
- Loading branch information
Showing 2 changed files with 77 additions and 2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import cProfile | ||
import pstats | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
# A very long prompt (roughly 15k tokens): one short sentence repeated
# 1000 times and joined with single spaces.
_SENTENCE = "You are an expert in large language models, aren't you?"
LONG_PROMPT = ' '.join([_SENTENCE] * 1000)
|
||
|
||
def main(args):
    """Benchmark the block-hashing function of automatic prefix caching.

    Runs generation on one very long prompt under ``cProfile`` and reports
    how much of the total runtime is spent in ``hash_of_block``.

    Args:
        args: parsed CLI namespace with ``model``, ``tensor_parallel_size``,
            ``use_v2_block_manager`` and ``output_len`` attributes.
    """
    # NOTE(review): prefix caching is forced on here; the
    # --enable-prefix-caching flag parsed in __main__ is currently ignored.
    # Kept as-is because the benchmark is meaningless without caching.
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    # Warm-up iterations populate the prefix cache and exclude one-time
    # initialization cost from the profiled runs.
    print("------warm up------")
    for _ in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for _ in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # Analyze the runtime of the hashing function.
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        # func is (filename, lineno, funcname); the stats value is
        # (primitive_calls, total_calls, tottime, cumtime, callers).
        if 'hash_of_block' in func[2]:
            # Fix: accumulate (+=) instead of overwriting, so multiple
            # matching profile entries are all counted.
            total_time += stats.stats[func][3]
            total_calls += stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    # Fix: missing space after the comma in the original message, and
    # report the call count that was previously computed but unused.
    print(f"Hashing took {total_time:.2f} seconds ({total_calls} calls), "
          f"{percentage:.2f}% of the total runtime.")
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: build the argument parser, parse, and run.
    cli = argparse.ArgumentParser(
        description='Benchmark the performance of hashing function in'
        'automatic prefix caching.')
    cli.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    cli.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    cli.add_argument('--output-len', type=int, default=10)
    cli.add_argument('--enable-prefix-caching',
                     action='store_true',
                     help='enable prefix caching')
    cli.add_argument('--use-v2-block-manager',
                     action='store_true',
                     help='Use BlockSpaceMangerV2')
    main(cli.parse_args())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters