[Fix] Fix cuda graph padding for triton attention backend (#1782)
merrymercy authored Oct 24, 2024
1 parent 0089c4b commit fc82f5a
Showing 6 changed files with 3 additions and 19 deletions.
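For orientation, a minimal, self-contained sketch of the replay-time batch padding this commit touches. It mirrors the cuda_graph_runner.py hunk further down; the standalone names here (capture_bs, seq_len_fill_value, pad_batch) are illustrative stand-ins, not sglang's actual API.

import bisect

import torch

# Illustration only: mirrors the padding logic in the replay() hunk below.
capture_bs = [1, 2, 4, 8]       # batch sizes with captured CUDA graphs
seq_len_fill_value = 1          # hardcoded to 1 by this commit
seq_lens = torch.empty(max(capture_bs), dtype=torch.int32)

def pad_batch(raw_bs: int, raw_seq_lens: torch.Tensor) -> int:
    """Round raw_bs up to the nearest captured size; pad seq_lens for the extra slots."""
    index = bisect.bisect_left(capture_bs, raw_bs)
    bs = capture_bs[index]
    if bs != raw_bs:
        seq_lens.fill_(seq_len_fill_value)
    seq_lens[:raw_bs] = raw_seq_lens
    return bs

padded_bs = pad_batch(3, torch.tensor([5, 7, 2], dtype=torch.int32))
print(padded_bs, seq_lens[:padded_bs].tolist())  # 4 [5, 7, 2, 1]
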
4 changes: 0 additions & 4 deletions python/sglang/srt/layers/attention/__init__.py
@@ -41,10 +41,6 @@ def init_forward_metadata_replay_cuda_graph(
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
-        raise NotImplementedError()
-
     def forward(
         self,
         q: torch.Tensor,

3 changes: 0 additions & 3 deletions python/sglang/srt/layers/attention/double_sparsity_backend.py
@@ -161,9 +161,6 @@ def init_forward_metadata_replay_cuda_graph(
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

3 changes: 0 additions & 3 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -210,9 +210,6 @@ def init_forward_metadata_replay_cuda_graph(
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

3 changes: 0 additions & 3 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -108,9 +108,6 @@ def init_forward_metadata_replay_cuda_graph(
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
@@ -38,7 +38,7 @@ def __init__(self, size: int, max_context_len: int, device: str, use_records: bo
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))

7 changes: 2 additions & 5 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -133,11 +133,8 @@ def __init__(self, model_runner: "ModelRunner"):
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        self.seq_len_fill_value = (
-            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
-        )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.seq_len_fill_value = 1
         self.encoder_len_fill_value = 0
 
         if self.use_torch_compile:
@@ -290,7 +287,7 @@ def replay(self, forward_batch: ForwardBatch):
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs

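Taken together, my reading of the change (an interpretation, not text from the commit): padded slots in a replayed CUDA graph carry seq_len == seq_len_fill_value (1), so the attention kernels still look up req_to_token[padded_row, 0] for them. Zero-initializing req_to_token makes that lookup land on a valid pool slot instead of uninitialized memory. A rough sketch of the idea, with simplified stand-in sizes:

import torch

# Stand-in for ReqToTokenPool.req_to_token after this commit: zero-filled,
# so every padded request row maps token position 0 to KV pool slot 0.
size, max_context_len = 4, 16
req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)

kv_pool_size = 1024                      # illustrative KV cache size
padded_row = 3                           # a slot that exists only because of padding
token_index = int(req_to_token[padded_row, 0])

# With torch.empty this entry could be an arbitrary int32 and index far out of
# bounds; with torch.zeros it is always 0, i.e. a valid (if unused) slot.
assert 0 <= token_index < kv_pool_size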
