Fix memory leak when doing chunked prefill #1787

Merged
10 commits merged on Oct 25, 2024
14 changes: 13 additions & 1 deletion python/sglang/global_config.py
@@ -15,7 +15,7 @@ def __init__(self):

# Runtime constants: New generation token ratio estimation
self.init_new_token_ratio = 0.7
self.base_min_new_token_ratio = 0.1
self.min_new_token_ratio = 0.1
self.new_token_ratio_decay = 0.001

# Runtime constants: others
@@ -32,5 +32,17 @@ def __init__(self):
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True

def adjust_new_token_ratio(self, schedule_conservativeness=1):
assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
min_new_token_ratio = min(
global_config.min_new_token_ratio * schedule_conservativeness,
1.0,
)
init_new_token_ratio = max(
global_config.init_new_token_ratio, min_new_token_ratio
)

return min_new_token_ratio, init_new_token_ratio


global_config = GlobalConfig()
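
For reference, a self-contained sketch of the helper added above, using the constant values from this file; the calls at the bottom are worked examples only, not part of the diff:

INIT_NEW_TOKEN_RATIO = 0.7  # value from global_config above
MIN_NEW_TOKEN_RATIO = 0.1   # value from global_config above

def adjust_new_token_ratio(schedule_conservativeness: float = 1.0):
    assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
    # A larger conservativeness raises the floor for the estimated new-token ratio, capped at 1.0.
    min_ratio = min(MIN_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)
    # The initial ratio never drops below the adjusted floor.
    init_ratio = max(INIT_NEW_TOKEN_RATIO, min_ratio)
    return min_ratio, init_ratio

print(adjust_new_token_ratio(1.0))   # (0.1, 0.7)
print(adjust_new_token_ratio(10.0))  # (1.0, 1.0)
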
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
@@ -222,7 +222,7 @@ def __init__(
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
self.is_inflight_req = 0
self.is_being_chunked = False

# Logprobs (arguments)
self.return_logprob = False
@@ -906,15 +906,14 @@ def prepare_for_decode(self, enable_overlap: bool = False):

def filter_batch(
self,
current_inflight_req: Optional[Req] = None,
being_chunked_req: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] is not current_inflight_req
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
]

if keep_indices is None or len(keep_indices) == 0:
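
A minimal sketch of the filtering rule above, with a toy request class standing in for Req (the real ScheduleBatch carries far more state); it only shows that the being-chunked request is kept out of the filtered batch even though it is not finished:

from typing import List, Optional

class ToyReq:
    """Stand-in for Req with only the fields used by the filter."""

    def __init__(self, rid: str, done: bool = False):
        self.rid = rid
        self._done = done
        self.is_being_chunked = False

    def finished(self) -> bool:
        return self._done

def keep_indices_for(reqs: List[ToyReq], being_chunked_req: Optional[ToyReq]) -> List[int]:
    # Same rule as filter_batch: drop finished requests and the one still being chunked.
    return [
        i
        for i in range(len(reqs))
        if not reqs[i].finished() and reqs[i] is not being_chunked_req
    ]

reqs = [ToyReq("a"), ToyReq("b", done=True), ToyReq("c")]
reqs[2].is_being_chunked = True
print(keep_indices_for(reqs, reqs[2]))  # [0]
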
25 changes: 18 additions & 7 deletions python/sglang/srt/managers/schedule_policy.py
@@ -136,7 +136,7 @@ def __init__(

self.req_states = None
self.can_run_list = []
self.new_inflight_req = None
self.new_chunked_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0

@@ -176,7 +176,7 @@ def _prefill_one_req(
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

def add_inflight_req(self, req: Req):
def add_being_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,8 +192,13 @@ def add_inflight_req(self, req: Req):
),
)

# Return if chunked prefill not finished
return req if truncated else None
if truncated:
# Continue to chunk the request
assert req.is_being_chunked
self.new_chunked_req = req
else:
# Release the being chunked status
req.is_being_chunked = False

@contextmanager
def _lock_node(self, last_node: TreeNode):
@@ -262,11 +267,14 @@ def add_req_state(r, insert_sort=False):
)
else:
# Chunked prefill
assert self.new_chunked_req is None

trunc_len = self.rem_chunk_tokens
req.extend_input_len = trunc_len
req.is_being_chunked = True
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)

return self.budget_state()
@@ -305,15 +313,18 @@ def add_one_req(self, req: Req):
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER

# Chunked prefill
assert self.new_chunked_req is None

req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
req.is_being_chunked = True
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)

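
To make the state machine above concrete, here is a toy walkthrough (the numbers are invented for illustration) of how a long prefill is split across rounds: the request stays in the is_being_chunked state until the final chunk fits within rem_chunk_tokens, mirroring add_being_chunked_req:

class ToyReq:
    def __init__(self, input_len: int):
        self.remaining = input_len      # tokens still to prefill
        self.is_being_chunked = False

rem_chunk_tokens = 4096                 # per-round chunk budget (assumed)
req = ToyReq(10_000)                    # 10k-token prompt (assumed)

chunks = []
while req.remaining > 0:
    chunk = min(req.remaining, rem_chunk_tokens)
    req.remaining -= chunk
    # Stay in the being-chunked state until the last piece has been scheduled.
    req.is_being_chunked = req.remaining > 0
    chunks.append(chunk)

print(chunks)                # [4096, 4096, 1808]
print(req.is_being_chunked)  # False -> the request leaves the chunked state
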
95 changes: 38 additions & 57 deletions python/sglang/srt/managers/scheduler.py
@@ -219,35 +219,28 @@ def __init__(

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.being_chunked_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)

# Init the FSM cache for constrained generation
if not server_args.skip_tokenizer_init:
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.jump_forward_cache = JumpForwardCache()

# Init new token estimation
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.min_new_token_ratio = min(
global_config.base_min_new_token_ratio
* server_args.schedule_conservativeness,
1.0,
self.min_new_token_ratio, self.init_new_token_ratio = (
global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.new_token_ratio = self.init_new_token_ratio
self.batch_is_full = False

# Init profiler
@@ -294,7 +287,7 @@ def event_loop_normal(self):
self.process_batch_result(batch, result)
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
self.new_token_ratio = self.init_new_token_ratio

self.last_batch = batch

@@ -321,7 +314,7 @@ def event_loop_overlap(self):
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
self.new_token_ratio = self.init_new_token_ratio

self.last_batch = batch

@@ -499,20 +492,18 @@ def check_memory(self):
)
exit(1) if crash_on_warning else None

def get_next_batch_to_run(self):
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.current_inflight_req:
self.last_batch.filter_batch(
current_inflight_req=self.current_inflight_req
)
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
if self.being_chunked_req:
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Being chunked request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False
if not self.last_batch.is_empty():
if self.running_batch is None:
@@ -543,7 +534,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
) and self.being_chunked_req is None:
return None

running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,22 +557,20 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
num_mixed_running,
)

has_inflight = self.current_inflight_req is not None
if has_inflight:
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)

if self.lora_paths:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)

# NOTE: if there is a request being chunked, we always add it first
has_being_chunked = self.being_chunked_req is not None
if has_being_chunked:
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
self.being_chunked_req.init_next_round_input()
adder.add_being_chunked_req(self.being_chunked_req)

# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
@@ -615,12 +604,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
x for x in self.waiting_queue if x not in set(can_run_list)
]

if adder.new_inflight_req is not None:
assert self.current_inflight_req is None
self.current_inflight_req = adder.new_inflight_req

if self.current_inflight_req:
self.current_inflight_req.is_inflight_req += 1
# Update new round being chunked request
self.being_chunked_req = adder.new_chunked_req

# Print stats
if self.tp_rank == 0:
@@ -649,7 +634,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)
else:
logger.info(
@@ -660,7 +645,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)

# Create a new batch
@@ -709,7 +694,7 @@ def update_running_batch(self):
self.waiting_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.new_token_ratio - global_config.new_token_ratio_decay,
self.min_new_token_ratio,
)

@@ -783,10 +768,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
logprob_pt = 0
for i, req in enumerate(batch.reqs):
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
@@ -812,10 +795,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
# dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()
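
The leak itself is easiest to see with a toy pool (hypothetical alloc/free API, not sglang's actual ReqToTokenPool): every chunked-prefill round gives the same request a new req_pool_idx, so the slot from the previous round must be returned first, which is what the added free() call in get_next_batch_to_run does:

class ToyReqToTokenPool:
    """Hypothetical slot pool; the real pool maps req_pool_idx to token buffers."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self) -> int:
        return self.free_slots.pop()

    def free(self, idx: int) -> None:
        self.free_slots.append(idx)

pool = ToyReqToTokenPool(size=4)

idx = pool.alloc()          # first chunk of the request
for _ in range(3):          # three more chunks of the same request
    pool.free(idx)          # the fix: release the slot from the previous round
    idx = pool.alloc()      # the new round re-allocates a fresh slot

print(len(pool.free_slots))  # 3 -> the request still holds exactly one of the 4 slots

Without the free() call, each extra round would leave one slot permanently allocated, which is the memory leak this PR fixes.
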
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -658,6 +658,7 @@ def run_mmlu_test(
chunked_prefill_size=32,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
other_args += ["--mem-fraction-static", "0.85"]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
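
A condensed sketch of how the default test arguments now come together (build_other_args is a hypothetical helper extracted only for illustration; the real run_mmlu_test goes on to launch a server with these flags):

def build_other_args(chunked_prefill_size: int = 32, disable_radix_cache: bool = False):
    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
    other_args += ["--mem-fraction-static", "0.85"]  # new default added in this PR
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    return other_args

print(build_other_args())
# ['--chunked-prefill-size', '32', '--mem-fraction-static', '0.85']
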
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -5,6 +5,7 @@

suites = {
"minimal": [
"test_radix_attention.py",
"models/test_embedding_models.py",
"models/test_generation_models.py",
"models/test_lora.py",