Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add profile in offline benchmark & update doc #2123

Merged
merged 14 commits into from
Nov 27, 2024
18 changes: 18 additions & 0 deletions docs/references/benchmark_and_profiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,21 @@ with nvtx.annotate("description", color="color"):
## Other tips

1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.

## Profile with PyTorch Profiler
- To profile a server
```bash
# set the directory where profiler trace files will be written
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
# start the server
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct

# in another terminal, send requests with profiling enabled; traces are saved under SGLANG_TORCH_PROFILER_DIR
python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile
```

Traces can be visualized using https://ui.perfetto.dev/.

- To profile an offline benchmark run
```bash
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --profile-dir=/root/sglang/profile_log --mem-frac=0.8
```
42 changes: 41 additions & 1 deletion python/sglang/bench_offline_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import logging
import random
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from sglang.api import Engine
from sglang.bench_serving import (
Expand Down Expand Up @@ -52,6 +54,8 @@ class BenchArgs:
seed: int = 1
skip_warmup: bool = False
do_not_exit: bool = False
profile: bool = False
profile_dir: str = ""

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -156,6 +160,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler",
)
parser.add_argument(
"--profile-dir",
type=str,
default=None,
help=(
"path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand All @@ -169,6 +187,8 @@ def throughput_test_once(
reqs: List[Tuple[str, int, int]],
ignore_eos: bool,
extra_request_body: Dict,
profile: bool,
profile_dir: str,
):
measurement_results = {
"backend": backend_name,
Expand All @@ -194,7 +214,23 @@ def throughput_test_once(
]

st = time.perf_counter()
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
if profile:
if not profile_dir:
profile_dir = (
Path(".") / "sglang_benchmark_result" / f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...")
with torch.profiler.profile(
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(str(profile_dir)),
) as p:
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
else:
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
latency = time.perf_counter() - st

if backend_name == "runtime":
Expand Down Expand Up @@ -268,6 +304,8 @@ def throughput_test(
reqs=warmup_requests,
ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
profile=False,
profile_dir=bench_args.profile_dir,
)

logging.info("\nBenchmark...")
Expand All @@ -277,6 +315,8 @@ def throughput_test(
reqs=input_requests,
ignore_eos=not bench_args.disable_ignore_eos,
extra_request_body=extra_request_body,
profile=bench_args.profile,
profile_dir=bench_args.profile_dir,
)

if bench_args.result_filename:
Expand Down
Loading