Unify index operations (#620)
hnyls2002 authored Jul 14, 2024
1 parent 564a898 commit a56858b
Showing 1 changed file with 18 additions and 40 deletions.
python/sglang/srt/memory_pool.py (58 changes: 18 additions & 40 deletions)
@@ -20,31 +20,33 @@ def alloc(self, need_size):
             return None
 
         select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
-        self.mem_state[select_index] = 0
+        self.mem_state[select_index] = False
         self.can_use_mem_size -= need_size
 
         return select_index.to(torch.int32)
 
     def free(self, free_index):
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
         else:
             self.can_use_mem_size += free_index.shape[0]
-        self.mem_state[free_index] = 1
+
+        self.mem_state[free_index] = True
 
     def clear(self):
-        self.mem_state.fill_(1)
+        self.mem_state.fill_(True)
         self.can_use_mem_size = len(self.mem_state)
 
 
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.size = size
 
-        # This can be promised:
-        # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
-        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
-        self.total_size = self.size
-        self.total_alloc = 0
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
+        self.can_use_mem_size = self.size
 
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
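Note on the hunk above: the commit converges both pools on a single free-mask idiom: a bool tensor where True means free, torch.nonzero to gather free slots, and plain True/False writes instead of 0/1 integers. A minimal runnable sketch of that idiom (the class name and CPU-default device are illustrative, not from the repo):

import torch

class BoolSlotAllocator:
    """Illustrative free-mask allocator: True = free, False = in use."""

    def __init__(self, size: int):
        self.mem_state = torch.ones((size,), dtype=torch.bool)
        self.can_use_mem_size = size

    def alloc(self, need_size: int):
        if need_size > self.can_use_mem_size:
            return None  # not enough free slots
        # nonzero() on a bool mask yields the indices of free slots.
        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
        self.mem_state[select_index] = False
        self.can_use_mem_size -= need_size
        return select_index.to(torch.int32)

    def free(self, free_index: torch.Tensor):
        self.can_use_mem_size += free_index.shape[0]
        # Cast to int64: advanced indexing expects long indices.
        self.mem_state[free_index.long()] = True

A dense nonzero() scan per allocation trades some work for trivially simple bookkeeping, which is what makes the free and clear paths one-liners.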
@@ -73,9 +75,8 @@ def alloc(self, need_size):
 
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
-        )
+        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
+        select_index = select_index.to(torch.int32)
 
         if select_index.shape[0] < addition_size:
             return None
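This hunk keeps the prefetch trick intact: alloc scans the mask once for max(addition_size, prefetch_chunk_size) slots and parks the surplus in self.prefetch_buffer, so a burst of small requests pays for one torch.nonzero scan instead of many. A self-contained sketch of that buffering scheme (prefetch_buffer and prefetch_chunk_size follow the diff; the class and defaults are assumptions):

import torch

class PrefetchPool:
    """Illustrative pool with chunked prefetch, not the repo's exact class."""

    def __init__(self, size: int, prefetch_chunk_size: int = 512):
        self.mem_state = torch.ones((size,), dtype=torch.bool)  # True = free
        self.can_use_mem_size = size
        self.prefetch_chunk_size = prefetch_chunk_size
        self.prefetch_buffer = torch.empty((0,), dtype=torch.int32)

    def alloc(self, need_size: int):
        buffer_len = len(self.prefetch_buffer)
        if need_size > buffer_len:
            # One nonzero() scan grabs a whole chunk, not just the shortfall.
            addition_size = need_size - buffer_len
            alloc_size = max(addition_size, self.prefetch_chunk_size)
            select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
            if select_index.shape[0] < addition_size:
                return None  # pool exhausted
            self.mem_state[select_index] = False
            self.can_use_mem_size -= select_index.shape[0]
            self.prefetch_buffer = torch.cat(
                [self.prefetch_buffer, select_index.to(torch.int32)]
            )

        # Serve the request from the (now refilled) buffer.
        ret_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return ret_index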
@@ -88,43 +89,20 @@
 
         return ret_index
 
-    def alloc_contiguous(self, need_size):
-        # NOTE: This function is deprecated.
-        empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if empty_index.shape[0] < need_size:
-            return None
-        empty_size = len(empty_index)
-        loc_sum = (
-            empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
-        )
-        can_used_loc = empty_index[: empty_size - (need_size - 1)][
-            loc_sum == need_size - 1
-        ]
-        if can_used_loc.shape[0] == 0:
-            return None
-
-        start_loc = can_used_loc[0].item()
-        select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
-        self.add_refs(select_index)
-        return select_index.to(torch.int32), start_loc, start_loc + need_size
-
-    def used_size(self):
-        return self.total_alloc
-
     def available_size(self):
-        return self.total_size - self.total_alloc + len(self.prefetch_buffer)
+        return self.can_use_mem_size + len(self.prefetch_buffer)
 
     def add_refs(self, token_index: torch.Tensor):
-        self.total_alloc += len(token_index)
-        self.mem_state[token_index] ^= True
+        self.can_use_mem_size -= len(token_index)
+        self.mem_state[token_index] = False
 
     def dec_refs(self, token_index: torch.Tensor):
-        self.total_alloc -= len(token_index)
-        self.mem_state[token_index] ^= True
+        self.can_use_mem_size += len(token_index)
+        self.mem_state[token_index] = True
 
     def clear(self):
-        self.mem_state.fill_(0)
-        self.total_alloc = 0
+        self.mem_state.fill_(True)
+        self.can_use_mem_size = self.size
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = True
+        self.mem_state[0] = False
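With total_alloc and total_size gone, one invariant summarizes the unified scheme: can_use_mem_size always equals the number of True entries in mem_state (less the reserved dummy slot that clear() marks used). A quick illustrative check against the BoolSlotAllocator sketch above (not a repo test):

pool = BoolSlotAllocator(size=16)
idx = pool.alloc(4)

assert idx is not None
assert pool.can_use_mem_size == 12
assert pool.can_use_mem_size == int(pool.mem_state.sum())  # counter matches mask

pool.free(idx)
assert pool.can_use_mem_size == 16
assert bool(pool.mem_state.all())  # every slot free again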
