
Commit

Move status check in the memory pool to CPU (#1557)
merrymercy authored Oct 3, 2024
1 parent 317631c commit 4ae0969
Showing 2 changed files with 27 additions and 42 deletions.
62 changes: 21 additions & 41 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -19,6 +19,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union

+import numpy as np
 import torch

 logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ def __init__(
         else:
             self.store_dtype = dtype

-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()

     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)

     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None

-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device="cuda")

     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)

     @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -152,19 +124,25 @@ def __init__(
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)

         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +188,17 @@ def __init__(
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)

         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
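Side note (not part of the commit): to see the new allocation scheme in isolation, here is a minimal sketch of the CPU-side free-slot bookkeeping that the hunks above switch to. The old version kept a boolean mem_state mask on the GPU and ran torch.nonzero on it to find free slots; the new version keeps the free list as a host-side NumPy array and only copies the chosen indices to the device. The class name FreeSlotPool is invented for this note and it falls back to CPU tensors when no GPU is present; the real base pool class in memory_pool.py also owns the KV-buffer accessors.

import numpy as np
import torch


class FreeSlotPool:
    # Illustrative stand-in for the CPU-side free-slot bookkeeping above.
    def __init__(self, size: int):
        self.size = size
        self.free_slots = None
        self.clear()

    def available_size(self) -> int:
        # The status check is a plain len() on a host-side array,
        # so no GPU kernel or device sync is needed to inspect occupancy.
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Only the indices actually handed out are copied to the device.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tensor(select_index, dtype=torch.int32, device=device)

    def free(self, free_index: torch.Tensor):
        # Returned slots go back onto the host-side free list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

    def clear(self):
        # Slot 0 is reserved for dummy outputs from padded tokens, so it is never handed out.
        self.free_slots = np.arange(1, self.size + 1)


pool = FreeSlotPool(size=8)
idx = pool.alloc(3)            # int32 tensor of 3 slot indices, e.g. [1, 2, 3]
pool.free(idx)                 # the same slots become available again
print(pool.available_size())   # 8

The upshot of the design change is that availability checks and slot selection no longer touch the GPU at all; a device tensor is created only for the indices that are actually returned.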
7 changes: 6 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -409,8 +409,11 @@ def init_memory_pool(
             4096,
         )

+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "
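Side note (not part of the commit): the model_runner.py hunks are plumbing — the device string is chosen once in init_memory_pool and passed down to the pools instead of each torch.empty call hard-coding device="cuda". Below is a minimal sketch of that pattern with an invented class name (KVBufferSketch) and toy sizes; the real constructors are the MHATokenToKVPool and MLATokenToKVPool __init__ methods shown in the first file.

import torch


class KVBufferSketch:
    # Invented stand-in for the KV pools: the caller decides the device once
    # and passes it in, rather than every buffer pinning device="cuda".
    def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, device: str):
        # size + 1 because slot 0 is the padded dummy slot.
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=torch.float16, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=torch.float16, device=device)
            for _ in range(layer_num)
        ]


device = "cuda" if torch.cuda.is_available() else "cpu"
pool = KVBufferSketch(size=16, head_num=2, head_dim=8, layer_num=2, device=device)
print(pool.k_buffer[0].shape, pool.k_buffer[0].device)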
