Fix memory leak when doing chunked prefill (#1787)
hnyls2002 authored Oct 25, 2024
1 parent 2148914 commit a2f5e75
Showing 7 changed files with 184 additions and 69 deletions.
12 changes: 11 additions & 1 deletion python/sglang/global_config.py
@@ -15,7 +15,7 @@ def __init__(self):

# Runtime constants: New generation token ratio estimation
self.init_new_token_ratio = 0.7
self.base_min_new_token_ratio = 0.1
self.min_new_token_ratio = 0.1
self.new_token_ratio_decay = 0.001

# Runtime constants: others
@@ -32,5 +32,15 @@ def __init__(self):
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True

def adjust_new_token_ratio(self, schedule_conservativeness=1):
assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
min_new_token_ratio = min(
self.min_new_token_ratio * schedule_conservativeness,
1.0,
)
init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)

return min_new_token_ratio, init_new_token_ratio


global_config = GlobalConfig()
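
For context, a hedged sketch (not part of the diff) of how the new adjust_new_token_ratio helper is consumed; it mirrors the scheduler changes further down, and the decay step uses the new_token_ratio_decay constant defined above.

from sglang.global_config import global_config

# Scale the (min, init) pair by the server's schedule_conservativeness.
min_ratio, init_ratio = global_config.adjust_new_token_ratio(schedule_conservativeness=1.0)

new_token_ratio = init_ratio  # optimistic estimate at startup or after an idle step
# After each decode step, decay toward the conservative lower bound.
new_token_ratio = max(new_token_ratio - global_config.new_token_ratio_decay, min_ratio)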
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
@@ -222,7 +222,7 @@ def __init__(
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
self.is_inflight_req = 0
self.is_being_chunked = False

# Logprobs (arguments)
self.return_logprob = False
@@ -906,15 +906,14 @@ def prepare_for_decode(self, enable_overlap: bool = False):

def filter_batch(
self,
current_inflight_req: Optional[Req] = None,
being_chunked_req: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] is not current_inflight_req
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
]

if keep_indices is None or len(keep_indices) == 0:
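
For illustration, a minimal hedged sketch (not part of the diff) of the renamed filter_batch keyword; batch and chunked_req are assumed to be a ScheduleBatch and its in-progress request. The being-chunked request is dropped from the batch alongside finished requests, since its remaining chunks go into a fresh prefill batch rather than moving on to decode.

# Illustrative only: drop finished requests and the being-chunked request.
batch.filter_batch(being_chunked_req=chunked_req)

# Equivalent explicit form using keep_indices:
keep = [
    i for i, r in enumerate(batch.reqs)
    if not r.finished() and r is not chunked_req
]
batch.filter_batch(keep_indices=keep)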
25 changes: 18 additions & 7 deletions python/sglang/srt/managers/schedule_policy.py
@@ -136,7 +136,7 @@ def __init__(

self.req_states = None
self.can_run_list = []
self.new_inflight_req = None
self.new_chunked_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0

@@ -176,7 +176,7 @@ def _prefill_one_req(
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

def add_inflight_req(self, req: Req):
def add_being_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,8 +192,13 @@ def add_inflight_req(self, req: Req):
),
)

# Return if chunked prefill not finished
return req if truncated else None
if truncated:
# Continue to chunk the request
assert req.is_being_chunked
self.new_chunked_req = req
else:
# Release the being chunked status
req.is_being_chunked = False

@contextmanager
def _lock_node(self, last_node: TreeNode):
@@ -262,11 +267,14 @@ def add_req_state(r, insert_sort=False):
)
else:
# Chunked prefill
assert self.new_chunked_req is None

trunc_len = self.rem_chunk_tokens
req.extend_input_len = trunc_len
req.is_being_chunked = True
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)

return self.budget_state()
@@ -305,15 +313,18 @@ def add_one_req(self, req: Req):
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER

# Chunked prefill
assert self.new_chunked_req is None

req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
req.is_being_chunked = True
self.can_run_list.append(req)
self.new_inflight_req = req
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)

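
To tie the adder changes together, a hedged sketch (simplified, not part of the diff) of how a scheduler drives the new interface: the being-chunked request is always added first, at most one request per batch ends up chunked, and the adder now reports it through new_chunked_req instead of a return value. PrefillAdder construction and the handling of the AddReqResult return value are elided.

# Simplified driving loop; `adder`, `being_chunked_req`, and `waiting_queue`
# are assumed to be the scheduler's PrefillAdder and state.
if being_chunked_req is not None:
    being_chunked_req.init_next_round_input()
    adder.add_being_chunked_req(being_chunked_req)  # always scheduled first

for req in waiting_queue:
    result = adder.add_one_req(req)
    # The real loop inspects `result` (an AddReqResult) to stop when the
    # token or chunk budget is exhausted; elided here.

# At most one request may be left partially prefilled; carry it to the next round.
being_chunked_req = adder.new_chunked_req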
95 changes: 38 additions & 57 deletions python/sglang/srt/managers/scheduler.py
@@ -219,35 +219,28 @@ def __init__(

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.being_chunked_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)

# Init the FSM cache for constrained generation
if not server_args.skip_tokenizer_init:
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.jump_forward_cache = JumpForwardCache()

# Init new token estimation
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.min_new_token_ratio = min(
global_config.base_min_new_token_ratio
* server_args.schedule_conservativeness,
1.0,
self.min_new_token_ratio, self.init_new_token_ratio = (
global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.new_token_ratio = self.init_new_token_ratio
self.batch_is_full = False

# Init profiler
@@ -294,7 +287,7 @@ def event_loop_normal(self):
self.process_batch_result(batch, result)
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
self.new_token_ratio = self.init_new_token_ratio

self.last_batch = batch

@@ -321,7 +314,7 @@ def event_loop_overlap(self):
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
self.new_token_ratio = self.init_new_token_ratio

self.last_batch = batch

@@ -499,20 +492,18 @@ def check_memory(self):
)
exit(1) if crash_on_warning else None

def get_next_batch_to_run(self):
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.current_inflight_req:
self.last_batch.filter_batch(
current_inflight_req=self.current_inflight_req
)
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
if self.being_chunked_req:
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Being chunked request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False
if not self.last_batch.is_empty():
if self.running_batch is None:
@@ -543,7 +534,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
) and self.being_chunked_req is None:
return None

running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,22 +557,20 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
num_mixed_running,
)

has_inflight = self.current_inflight_req is not None
if has_inflight:
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)

if self.lora_paths:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)

# NOTE: if there is request being chunked, we always add it first
has_being_chunked = self.being_chunked_req is not None
if has_being_chunked:
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
self.being_chunked_req.init_next_round_input()
adder.add_being_chunked_req(self.being_chunked_req)

# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
@@ -615,12 +604,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
x for x in self.waiting_queue if x not in set(can_run_list)
]

if adder.new_inflight_req is not None:
assert self.current_inflight_req is None
self.current_inflight_req = adder.new_inflight_req

if self.current_inflight_req:
self.current_inflight_req.is_inflight_req += 1
# Update new round being chunked request
self.being_chunked_req = adder.new_chunked_req

# Print stats
if self.tp_rank == 0:
@@ -649,7 +634,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)
else:
logger.info(
@@ -660,7 +645,7 @@
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
)

# Create a new batch
@@ -709,7 +694,7 @@ def update_running_batch(self):
self.waiting_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
self.new_token_ratio - global_config.new_token_ratio_decay,
self.min_new_token_ratio,
)

@@ -783,10 +768,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
logprob_pt = 0
for i, req in enumerate(batch.reqs):
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
@@ -812,10 +795,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished
if not req.is_being_chunked:
# Being chunked reqs' prefill is not finished
# dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()
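
For reference, the merge step from get_next_batch_to_run above, condensed here as a hedged sketch with the scheduler's attributes assumed in scope: the being-chunked request is dropped from the finished prefill batch, its computed prefix is cached, and its old req_pool_idx is released, because it keeps its rid but receives a fresh req_pool_idx when its next chunk is scheduled.

# Condensed from the hunk above (hedged; `self` is the scheduler instance).
if self.being_chunked_req:
    # Drop the being-chunked request from the finished prefill batch.
    self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
    # Cache the prefix computed so far for reuse by the next chunk.
    self.tree_cache.cache_unfinished_req(self.being_chunked_req)
    # It keeps its rid but gets a new req_pool_idx, so release the old slot.
    self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
    self.batch_is_full = False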
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -660,6 +660,7 @@ def run_mmlu_test(
chunked_prefill_size=32,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
other_args += ["--mem-fraction-static", "0.85"]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
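
For reference, a hedged example (values illustrative) of the argument list this test now builds before launching the server: chunked prefill capped at 32 tokens, static memory fraction lowered to 0.85, and optionally the radix cache disabled.

# Equivalent flag list with the default chunked_prefill_size=32 and
# disable_radix_cache=True (illustrative only).
other_args = [
    "--chunked-prefill-size", "32",
    "--mem-fraction-static", "0.85",
    "--disable-radix-cache",
]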
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -5,6 +5,7 @@

suites = {
"minimal": [
"test_radix_attention.py",
"models/test_embedding_models.py",
"models/test_generation_models.py",
"models/test_lora.py",