Remove cached triton launcher #656

Merged
merged 2 commits on Jul 19, 2024
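Context note (not from the PR itself): the deleted wrap_kernel_launcher helper cached a compiled Triton kernel and re-launched it through its low-level c_wrapper/run entry point to skip Triton's Python-side launch overhead, and it already returns None on Triton >= 3 because those internals changed. With the helper and the per-module cached_kernel globals gone, every call site simply takes the standard kernel[grid](...) launch path that remains in the diff. Below is a minimal sketch of that standard path, assuming Triton and a CUDA device are available; the vector-add kernel and the add wrapper are illustrative only, not code from this repository.

import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the tensors.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK = 1024
    grid = (triton.cdiv(n, BLOCK),)
    # Standard launch path: Triton compiles on first use and caches the
    # compiled kernel per specialization internally, so no hand-rolled
    # cached launcher is needed.
    _add_kernel[grid](x, y, out, n, BLOCK=BLOCK, num_warps=4)
    return out

Repeated _add_kernel[grid](...) launches reuse the compiled binary from Triton's own cache, so the main thing the removed wrapper saved was Python-side argument handling on each launch.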
Changes from all commits
29 changes: 0 additions & 29 deletions python/sglang/srt/layers/context_flashattention_nopad.py
@@ -4,8 +4,6 @@
import triton
import triton.language as tl

from sglang.srt.utils import wrap_kernel_launcher

CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -119,9 +117,6 @@ def _fwd_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


cached_kernel = None


def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
if CUDA_CAPABILITY[0] >= 8:
BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8

global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
)
return

_fwd_kernel[grid](
q,
k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
num_warps=num_warps,
num_stages=1,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
39 changes: 0 additions & 39 deletions python/sglang/srt/layers/extend_attention.py
@@ -3,7 +3,6 @@
import triton.language as tl

from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher

CUDA_CAPABILITY = torch.cuda.get_device_capability()

@@ -172,9 +171,6 @@ def _fwd_kernel(
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


cached_kernel = None


def extend_attention_fwd(
q_extend,
k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
num_warps = 4 if Lk <= 64 else 8
num_stages = 1

global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
)
return

_fwd_kernel[grid](
q_extend,
k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
num_stages=num_stages,
logit_cap=logit_cap,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)


def redundant_attention(
50 changes: 0 additions & 50 deletions python/sglang/srt/layers/token_attention.py
@@ -6,7 +6,6 @@
import triton.language as tl

from sglang.srt.server import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher

if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
tl.store(out_ptrs, acc)


cached_kernel_stage1 = None
cached_kernel_stage2 = None


def _token_att_m_fwd(
q,
k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
else:
num_warps = 2

global cached_kernel_stage1
if cached_kernel_stage1:
cached_kernel_stage1(
grid,
num_warps,
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
)
return

_fwd_kernel_stage1[grid](
q,
k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
num_warps=num_warps,
num_stages=1,
)
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@

num_warps = 1

global cached_kernel_stage2
if cached_kernel_stage2:
cached_kernel_stage2(
grid,
num_warps,
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
return

_fwd_kernel_stage2[grid](
logics,
v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
num_warps=num_warps,
num_stages=3,
)
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


def token_attention_fwd(
27 changes: 2 additions & 25 deletions python/sglang/srt/server.py
@@ -51,6 +51,7 @@
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
set_ulimit,
)
from sglang.utils import get_exception_traceback

@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
}


def _set_ulimit(target_soft_limit=65535):
import resource

resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)

if current_soft >= target_soft_limit:
logger.info(
f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
)
else:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
new_soft, new_hard = resource.getrlimit(resource_type)
logger.info(
f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
)
except ValueError as e:
logger.warn(f"Failed to set new limits: {e}")
logger.info(
f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
)


def launch_server(
server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
@@ -186,7 +163,7 @@ def launch_server(
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
_set_ulimit()
set_ulimit()
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
80 changes: 13 additions & 67 deletions python/sglang/srt/utils.py
@@ -5,6 +5,7 @@
import logging
import os
import random
import resource
import socket
import struct
import time
@@ -16,6 +17,7 @@
import psutil
import requests
import torch
import torch.distributed as dist
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
return logit_bias


def wrap_kernel_launcher(kernel):
"""A faster launcher for triton kernels."""
if int(triton.__version__.split(".")[0]) >= 3:
return None

gpu_id = torch.cuda.current_device()
kernels = kernel.cache[gpu_id].values()
kernel = next(iter(kernels))

# Different triton versions use different low-level names
if hasattr(kernel, "cu_function"):
kfunction = kernel.cu_function
else:
kfunction = kernel.function

if hasattr(kernel, "c_wrapper"):
run = kernel.c_wrapper
else:
run = kernel.run

add_cluster_dim = True

def ret_func(grid, num_warps, *args):
nonlocal add_cluster_dim

try:
if add_cluster_dim:
run(
grid[0],
grid[1],
grid[2],
num_warps,
1,
1,
1,
1,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
else:
run(
grid[0],
grid[1],
grid[2],
num_warps,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
except TypeError:
add_cluster_dim = not add_cluster_dim
ret_func(grid, num_warps, *args)

return ret_func


def is_multimodal_model(model):
from sglang.srt.model_config import ModelConfig

@@ -512,7 +449,6 @@ def get_ip_address(ifname):

def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1
import torch.distributed as dist

ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):

def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1
import torch.distributed as dist

ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):

dist.barrier()
dist.destroy_process_group()


def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)

if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")