req_pool slots: gpu -> cpu
hnyls2002 committed Aug 2, 2024
1 parent 5256a52 commit 4b2dc62
Showing 3 changed files with 23 additions and 28 deletions.
24 changes: 12 additions & 12 deletions python/sglang/srt/managers/schedule_batch.py
```diff
@@ -409,11 +409,9 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
-        req_pool_indices_cpu = req_pool_indices.tolist()
-
         pt = 0
         for i, req in enumerate(reqs):
-            reqs[i].req_pool_idx = req_pool_indices_cpu[i]
+            reqs[i].req_pool_idx = req_pool_indices[i]
 
             extend_lens.append(len(input_ids[i]))
             prefix_lens.append(len(prefix_indices[i]))
@@ -426,23 +424,25 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Set fields
-        self.input_ids = torch.tensor(
-            sum(input_ids, []), dtype=torch.int32, device=device
-        )
+        # Image auxiliary
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
         self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
+        with torch.device(device):
+            # Batched tensors
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+            self.out_cache_loc = out_cache_loc
+
         self.batch_sampling_params(vocab_size, int_token_logit_bias)
 
     def check_decode_mem(self):
```
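This hunk is the heart of the commit: `alloc_req_slots` now hands back plain Python ints, so assigning `req.req_pool_idx` is a list lookup instead of a read from a GPU tensor (the old `.tolist()` device sync disappears), and the batched tensors are built in one place under `torch.device(device)`. A minimal sketch of that context-manager pattern, with illustrative names and data (not sglang code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Slot indices now live on the host as plain ints (e.g. from a free list).
req_pool_indices = [3, 7, 9]
seq_lens = [128, 64, 256]

# Since PyTorch 2.0, torch.device(...) works as a context manager: tensor
# factory calls inside the block default to that device, so no per-call
# device= argument is needed.
with torch.device(device):
    req_pool_indices_t = torch.tensor(req_pool_indices, dtype=torch.int32)
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32)

assert req_pool_indices_t.device.type == device
```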
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/tp_worker.py
```diff
@@ -290,10 +290,10 @@ def check_memory(self):
                 "KV cache pool leak detected!"
             )
 
-        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
```
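The check is unchanged in spirit; only the bookkeeping it reads has moved. `can_use_mem_size` was a counter maintained alongside a GPU mask and could in principle drift out of sync with it, whereas `len(free_slots)` is derived from the free list itself, giving the leak check a single source of truth. A toy illustration of that invariant (illustrative values, not sglang code):

```python
free_slots = list(range(8))        # 8 slots, all available

# Take 3 slots; availability is derived from the list, not tracked
# separately, so it cannot drift.
taken, free_slots = free_slots[:3], free_slots[3:]
assert len(free_slots) == 5

free_slots.extend(taken)           # return the slots
assert len(free_slots) == 8        # the idle-state invariant check_memory enforces
```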
23 changes: 9 additions & 14 deletions python/sglang/srt/mem_cache/memory_pool.py
```diff
@@ -16,6 +16,7 @@
 """Memory pool."""
 
 import logging
+from typing import List
 
 import torch
 
@@ -27,34 +28,28 @@ class ReqToTokenPool:
 
     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size
 
-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None
 
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
         return select_index
 
     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += free_index.shape[0]
+            self.free_slots.extend(free_index)
 
     def clear(self):
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))
 
 
 class TokenToKVPool:
```
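For completeness, a short usage sketch of the rewritten pool as it stands at this commit (sizes are illustrative; running it needs a CUDA device because `__init__` still places the `req_to_token` table on `"cuda"`):

```python
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

pool = ReqToTokenPool(size=4, max_context_len=2048)

slots = pool.alloc(3)                      # plain Python ints: [0, 1, 2]
assert pool.alloc(2) is None               # only one slot left -> allocation fails

pool.free(slots[0])                        # int path: append one slot
pool.free(slots[1:])                       # iterable path: extend with a batch
assert len(pool.free_slots) == pool.size   # every slot returned: no leak
```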
