failed to dispatch head_dim 96 #455

Open
ZX-ModelCloud opened this issue Aug 20, 2024 · 0 comments
@ZX-ModelCloud

```shell
env CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=10 python -m sglang.launch_server --model-path vonjack/Phi-3-mini-4k-instruct-LLaMAfied --port 30000
```

When loading `vonjack/Phi-3-mini-4k-instruct-LLaMAfied` with sglang, the following error occurs:

```text
server_args=ServerArgs(model_path='vonjack/Phi-3-mini-4k-instruct-LLaMAfied', tokenizer_path='vonjack/Phi-3-mini-4k-instruct-LLaMAfied', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', trust_remote_code=False, context_length=None, quantization=None, served_model_name='vonjack/Phi-3-mini-4k-instruct-LLaMAfied', chat_template=None, host='127.0.0.1', port=30000, additional_ports=[30001, 30002, 30003, 30004], mem_fraction_static=0.88, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=173762660, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_disk_cache=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, attention_reduce_in_fp32=False, efficient_weight_load=False, nccl_init_addr=None, nnodes=1, node_rank=None)
[gpu=0] Init nccl begin.
[gpu=0] Load weight begin. avail mem=94.87 GB
INFO 08-20 05:30:40 weight_utils.py:225] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.17it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.17it/s]

[gpu=0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=87.66 GB
[gpu=0] Memory pool end. avail mem=11.20 GB
[gpu=0] Capture cuda graph begin. This can take up to several minutes.
Process Process-1:
Initialization failed. controller_init_state: Traceback (most recent call last):
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 371, in init_cuda_graphs
    self.cuda_graph_runner.capture(batch_size_list)
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 162, in capture
    ) = self.capture_one_batch_size(bs, forward)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 214, in capture_one_batch_size
    update_flashinfer_indices(
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/forward_batch_info.py", line 289, in update_flashinfer_indices
    flashinfer_decode_wrapper.begin_forward(
  File "/root/projects/fanfiction-go/python/hub/flashinfer/python/flashinfer/decode.py", line 539, in begin_forward
    self._wrapper.begin_forward(
RuntimeError: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 96

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 150, in start_controller_process
    controller = ControllerSingle(
                 ^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 84, in __init__
    self.tp_server = ModelTpServer(
                     ^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 99, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 140, in __init__
    self.init_cuda_graphs()
  File "/root/miniconda3/envs/base2/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 373, in init_cuda_graphs
    raise Exception(
Exception: Capture cuda graph failed: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 96
```
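The dispatch error names `head_dim 96`. A quick check of where that number comes from, assuming the LLaMAfied checkpoint keeps Phi-3-mini's `hidden_size=3072` and `num_attention_heads=32` (values from the upstream Phi-3-mini config, not shown in this log). `PRECOMPILED_HEAD_DIMS` is a hypothetical placeholder for whatever head dimensions a prebuilt wheel ships, not flashinfer's actual list:

```python
# Hypothetical placeholder: head dims an ahead-of-time (prebuilt) wheel might ship.
# This is NOT the real flashinfer list; substitute the set for your installed build.
PRECOMPILED_HEAD_DIMS = {64, 128, 256}


def head_dim(hidden_size: int, num_attention_heads: int) -> int:
    """Per-head dimension of a standard multi-head attention layer."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden_size // num_attention_heads


def has_precompiled_kernel(dim: int) -> bool:
    """True if a prebuilt kernel exists for this head dimension (under the assumed set)."""
    return dim in PRECOMPILED_HEAD_DIMS


# Assumed Phi-3-mini config values: hidden_size=3072, num_attention_heads=32.
dim = head_dim(3072, 32)
print(dim, has_precompiled_kernel(dim))  # 96 False
```

When the dispatcher finds no kernel instantiated for the requested head dimension, it raises instead of falling back, which matches the `RuntimeError` above.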

yzh119 self-assigned this Aug 20, 2024
yzh119 added a commit that referenced this issue Oct 7, 2024
This PR implements JIT compilation (#170) for flashinfer. After this PR,
flashinfer compiles kernels just-in-time for the input data types and
shapes it encounters and caches the compiled kernels on disk, instead of
shipping a pre-compiled set of kernels in the wheel.
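The compile-on-first-use, cache-on-disk behaviour described above can be sketched as a small cache keyed by dtype and head dimension. This is a minimal illustration under stated assumptions, not flashinfer's implementation: `KernelCache`, the `compile_fn` callback and the JSON "artifact" are hypothetical stand-ins for real compiler output.

```python
import json
import os
import tempfile
from typing import Callable, Dict, Tuple

Key = Tuple[str, int]  # (dtype name, head_dim)


class KernelCache:
    """Hypothetical JIT cache: compile a kernel on first request, then reuse it.

    Artifacts are persisted under `cache_dir`, so a fresh process pointing at the
    same directory skips compilation entirely.
    """

    def __init__(self, cache_dir: str, compile_fn: Callable[[str, int], dict]):
        self.cache_dir = cache_dir
        self.compile_fn = compile_fn
        self.memory: Dict[Key, dict] = {}
        self.compiles = 0  # number of real compilations, for illustration
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, key: Key) -> str:
        dtype, head_dim = key
        return os.path.join(self.cache_dir, f"decode_{dtype}_hd{head_dim}.json")

    def get(self, dtype: str, head_dim: int) -> dict:
        key = (dtype, head_dim)
        if key in self.memory:  # 1) in-process hit
            return self.memory[key]
        path = self._path(key)
        if os.path.exists(path):  # 2) on-disk hit from an earlier run
            with open(path) as f:
                artifact = json.load(f)
        else:  # 3) miss: compile just-in-time, then persist to disk
            artifact = self.compile_fn(dtype, head_dim)
            self.compiles += 1
            with open(path, "w") as f:
                json.dump(artifact, f)
        self.memory[key] = artifact
        return artifact


def fake_compile(dtype: str, head_dim: int) -> dict:
    # Stand-in for invoking the CUDA compiler; a real artifact would be a shared library.
    return {"dtype": dtype, "head_dim": head_dim}


cache_dir = tempfile.mkdtemp()
cache = KernelCache(cache_dir, fake_compile)
cache.get("bf16", 96)  # compiled
cache.get("bf16", 96)  # served from memory
fresh = KernelCache(cache_dir, fake_compile)
fresh.get("bf16", 96)  # served from disk, no compilation
print(cache.compiles, fresh.compiles)  # 1 0
```

Keying the cache on the exact configuration means only the kernels a deployment actually uses are ever built, rather than every combination a wheel would have to anticipate.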

# Motivation
The pip wheel size is exploding as we add support for more data types,
more head dimensions, more attention variants and more kernel
implementations. Pre-compiling everything is not sustainable and slows
down development.

This PR refactors the codebase to use torch's [JIT Compiling
Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions)
feature instead of pre-compiling kernels into the wheel.
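With torch's JIT extensions, a common pattern is to generate source text for the exact configuration requested and pass it to a loader such as `torch.utils.cpp_extension.load_inline`, which accepts a `name` plus `cpp_sources`/`cuda_sources` strings. The sketch below covers only that specialization step and does not invoke the compiler; the template body, `render_decode_source` and the module-naming scheme are hypothetical, not flashinfer's generated code.

```python
from string import Template

# Hypothetical CUDA template; real JIT sources would instantiate library kernels.
DECODE_TEMPLATE = Template(
    "#include <cuda_bf16.h>\n"
    "#include <cuda_fp16.h>\n"
    "constexpr unsigned int HEAD_DIM = $head_dim;\n"
    "using DTypeKV = $ctype;\n"
    "// ... kernel instantiation for (DTypeKV, HEAD_DIM) goes here ...\n"
)

# Map short dtype names onto CUDA element types.
CTYPES = {"fp16": "half", "bf16": "__nv_bfloat16"}


def module_name(dtype: str, head_dim: int) -> str:
    """Unique extension name per configuration, so cached builds never collide."""
    return f"batch_decode_{dtype}_hd{head_dim}"


def render_decode_source(dtype: str, head_dim: int) -> str:
    """Specialize the template for one (dtype, head_dim) pair."""
    if dtype not in CTYPES:
        raise ValueError(f"unsupported dtype: {dtype}")
    return DECODE_TEMPLATE.substitute(head_dim=head_dim, ctype=CTYPES[dtype])


src = render_decode_source("bf16", 96)
print(module_name("bf16", 96))  # batch_decode_bf16_hd96
print(src.splitlines()[2])      # constexpr unsigned int HEAD_DIM = 96;
```

In this scheme a head dimension such as 96 is just another substituted value rather than a kernel that has to be present in a prebuilt binary.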

## Attention Variants
Following [FlexAttention](https://pytorch.org/blog/flexattention/), we
describe every attention variant as a template class. Each instance of
the struct can carry closure variables stored in local or shared memory.
Below are two examples, logits soft cap and ALiBi attention (the
programming interface is tentative and will be updated as we improve the
programmability of the JIT templates):

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

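  // Fold sm_scale / logits_soft_cap into q, so LogitsTransform receives dot(q, k) * sm_scale / soft_cap.
  template <typename T>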
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  // soft_cap * tanh(x), scaled by log2e for the base-2 (exp2) softmax.
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  // No extra masking: every (qo_idx, kv_idx) pair is kept.
  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  // Fold sm_scale and log2e into q (base-2 softmax domain).
  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  // Add the per-head ALiBi bias: slope * (kv_idx - qo_idx).
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  // No extra masking: every (qo_idx, kv_idx) pair is kept.
  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};
```
Users can define their own attention variants by writing custom
`ParamsT` and variant classes. We hope this refactor will make the
codebase more concise and extensible.
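To make the hook order concrete, here is a scalar Python reference that mirrors the three hooks of the CUDA structs above (`QueryTransform`, `LogitsTransform`, `LogitsMask`) for a single query row, plus a user-defined sliding-window variant. It is an illustrative sketch, not flashinfer code: `Params`, `attend` and `SlidingWindow` are hypothetical. Like the CUDA snippets, it treats transformed logits as base-2 exponents (softmax via `2**x`), which is why the transforms fold in `log2e`.

```python
import math
from dataclasses import dataclass, field
from typing import List

LOG2E = math.log2(math.e)  # same role as math::log2e in the CUDA snippets


@dataclass
class Params:
    """Hypothetical stand-in for a ParamsT: per-launch constants read by the hooks."""
    sm_scale: float
    logits_soft_cap: float = 1.0
    alibi_slopes: List[float] = field(default_factory=list)


class LogitsSoftCap:
    def query_transform(self, p, q):
        # q * sm_scale / soft_cap, as in the CUDA QueryTransform
        return [x * p.sm_scale / p.logits_soft_cap for x in q]

    def logits_transform(self, p, logit, qo_idx, kv_idx, head):
        return p.logits_soft_cap * LOG2E * math.tanh(logit)

    def logits_mask(self, p, qo_idx, kv_idx, head):
        return True


class ALIBIAttention:
    def query_transform(self, p, q):
        return [x * p.sm_scale * LOG2E for x in q]

    def logits_transform(self, p, logit, qo_idx, kv_idx, head):
        return logit + p.alibi_slopes[head] * (kv_idx - qo_idx)

    def logits_mask(self, p, qo_idx, kv_idx, head):
        return True


class SlidingWindow:
    """User-defined variant (hypothetical): plain scaled attention that keeps only
    the `window` most recent keys at or before the query position."""

    def __init__(self, window):
        self.window = window

    def query_transform(self, p, q):
        return [x * p.sm_scale * LOG2E for x in q]

    def logits_transform(self, p, logit, qo_idx, kv_idx, head):
        return logit

    def logits_mask(self, p, qo_idx, kv_idx, head):
        return qo_idx - self.window < kv_idx <= qo_idx


def attend(variant, p, q, keys, values, qo_idx, head=0):
    """Attention output for one query row, driven entirely by the variant's hooks.

    Assumes at least one key survives the mask.
    """
    q = variant.query_transform(p, q)
    scores = {}
    for kv_idx, k in enumerate(keys):
        if variant.logits_mask(p, qo_idx, kv_idx, head):
            logit = sum(a * b for a, b in zip(q, k))
            scores[kv_idx] = variant.logits_transform(p, logit, qo_idx, kv_idx, head)
    m = max(scores.values())  # subtract the max before exponentiating, for stability
    weights = {i: 2.0 ** (s - m) for i, s in scores.items()}  # base-2 softmax
    total = sum(weights.values())
    return [sum(w * values[i][d] for i, w in weights.items()) / total
            for d in range(len(values[0]))]
```

Because `2 ** (log2e * y)` equals `e ** y`, the soft-cap variant weights each key by `exp(soft_cap * tanh(s / soft_cap))` with `s = dot(q, k) * sm_scale`. Switching between soft cap, ALiBi and the custom window only changes the `variant` argument, which is the separation the template-class design aims for.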

# Roadmap

After this PR, we will add support for:
1. PyPI wheels: #153
2. fp8 tensor core attention: #502
3. different head dimensions: #142 #454 #455
4. FlashAttention-3: #369
5. multi-head latent attention: #237
6. generating `ParamsT` and attention variant descriptions from a Python DSL

The development of these features has been blocked by the wheel size
limit (a binary size >= 2GB triggers some linking issues). I hope this
PR will make development easier in the future.