From 48d961a548e31205dee56a35d57b37e7b8b47e3b Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 20 Nov 2024 01:18:45 -0800
Subject: [PATCH] Allow skipping warmup in bench_offline_throughput.py

---
 python/sglang/bench_offline_throughput.py    | 23 ++++++++++++-------
 .../srt/managers/tp_worker_overlap_thread.py |  6 ++---
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 7717c16f033..cb502fa027a 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -57,6 +57,7 @@ class BenchArgs:
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
+    skip_warmup: bool = False
     do_not_exit: bool = False
 
     @staticmethod
@@ -152,6 +153,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "additional generate params like sampling params.",
         )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--skip-warmup",
+            action="store_true",
+            help="Skip the warmup batches.",
+        )
         parser.add_argument(
             "--do-not-exit",
             action="store_true",
@@ -261,14 +267,15 @@ def throughput_test(
     )
 
     # Warm up
-    logging.info("\nWarmup...")
-    throughput_test_once(
-        backend_name=bench_args.backend,
-        backend=backend,
-        reqs=warmup_requests,
-        ignore_eos=not bench_args.disable_ignore_eos,
-        extra_request_body=extra_request_body,
-    )
+    if not bench_args.skip_warmup:
+        logging.info("\nWarmup...")
+        throughput_test_once(
+            backend_name=bench_args.backend,
+            backend=backend,
+            reqs=warmup_requests,
+            ignore_eos=not bench_args.disable_ignore_eos,
+            extra_request_body=extra_request_body,
+        )
 
     logging.info("\nBenchmark...")
     result = throughput_test_once(
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 5435e4bf970..ab37ceed261 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -156,9 +156,6 @@ def resolve_batch_result(self, bid: int):
         return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
-
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -169,6 +166,9 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
             linear_penalties=sampling_info.linear_penalties,
         )
 
+        # Sync the current CUDA stream here to avoid a CUDA illegal memory access error.
+        torch.cuda.current_stream().synchronize()
+
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
 
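
A minimal usage sketch for the new flag, assuming the BenchArgs fields all carry
defaults and that throughput_test(server_args, bench_args) is the module's entry
point, as the hunks above suggest; the model path is a placeholder:

    from sglang.bench_offline_throughput import BenchArgs, throughput_test
    from sglang.srt.server_args import ServerArgs

    # Placeholder model path; substitute any local checkpoint or hub model.
    server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
    # skip_warmup=True takes the new branch: the warmup batch is skipped and
    # the script proceeds directly to the measured benchmark pass.
    bench_args = BenchArgs(skip_warmup=True)
    throughput_test(server_args, bench_args)

The command-line equivalent is passing --skip-warmup to
python3 -m sglang.bench_offline_throughput alongside the usual arguments.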
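
The second file's change only moves the stream synchronization: it now happens
after the sampling-info snapshot is issued and immediately before the batch is
pushed to the overlap thread's queue. A toy sketch of that ordering, with
made-up tensor names standing in for the sampling tensors (not sglang code):

    import torch

    if torch.cuda.is_available():
        # Stands in for sampling_info.linear_penalties, which the scheduler
        # updates in place when it prepares the next batch.
        penalties = torch.zeros(8, device="cuda")

        # Issue the device-side copy that the overlap thread will consume.
        snapshot = penalties.clone()

        # Fence before handing off: once synchronize() returns, the copy is
        # complete, so a later in-place update cannot race it (the
        # illegal-memory-access scenario the patch comment refers to).
        torch.cuda.current_stream().synchronize()

        # Now the in-place mutation for the next batch is safe.
        penalties.add_(1.0)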