Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a test case to test retract #1662

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,9 +590,11 @@ def retract_decode(self):

retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
Expand All @@ -601,6 +603,7 @@ def retract_decode(self):
), "No space left for only one request"
break

first_iter = False
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"


class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""
Expand Down Expand Up @@ -611,10 +614,11 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
return new_batch

def update_running_batch(self):
global test_retract
batch = self.running_batch

# Check if decode out of memory
if not batch.check_decode_mem():
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
old_ratio = self.new_token_ratio

retracted_reqs, new_token_ratio = batch.retract_decode()
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"test_large_max_new_tokens.py",
"test_openai_server.py",
"test_pytorch_sampling_backend.py",
"test_retract_decode.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
"test_srt_engine.py",
Expand Down
41 changes: 41 additions & 0 deletions test/srt/test_retract_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestRetractDecode(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)

def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)

metrics = run_eval(args)
assert metrics["score"] >= 0.65


if __name__ == "__main__":
unittest.main()
Loading