NotImplementedError: There was no rule registered for HOP flex_attention and mode #70

Open
LeoXinhaoLee opened this issue Nov 2, 2024 · 2 comments

LeoXinhaoLee commented Nov 2, 2024

Hi, thank you for releasing this wonderful code base.

When I try to combine a causal mask with a cross-document mask, like below:

from torch.nn.attention.flex_attention import create_block_mask, flex_attention

document_ids = document_ids.unsqueeze(1)  # [B, S] -> [B, 1, S]

def document_causal_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_ids[b, h, q_idx] == document_ids[b, h, kv_idx]
    return causal_mask & document_mask

block_mask = create_block_mask(
    document_causal_mask,
    B=xq.shape[0], H=None,
    Q_LEN=xq.shape[2], KV_LEN=xk.shape[2], device=xq.device,
)

# xq/xk/xv shape: [B, H, S, d]
output = flex_attention(xq, xk, xv, block_mask=block_mask, score_mod=None)

I encountered this error:

    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 703, in flex_attention_autograd
      out, logsumexp = FlexAttentionAutogradOp.apply(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 591, in forward
      out, logsumexp = flex_attention(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 109, in __call__
      return super().__call__(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_ops.py", line 433, in __call__
      return wrapper()
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
      return fn(*args, **kwargs)
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_ops.py", line 424, in wrapper
      return torch.overrides.handle_torch_function(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/overrides.py", line 1717, in handle_torch_function
      result = mode.__torch_function__(public_api, types, args, kwargs)
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 85, in __torch_function__
      return func(*args, **(kwargs or {}))
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 109, in __call__
      return super().__call__(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_ops.py", line 433, in __call__
      return wrapper()
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
      return fn(*args, **kwargs)
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_ops.py", line 429, in wrapper
      return self.dispatch(
    File "/lustre/fsw/portfolios/llmservice/users/xiaolwang/software/miniforge3/envs/torchtitan-triton/lib/python3.10/site-packages/torch/_ops.py", line 334, in dispatch
      raise NotImplementedError(
  NotImplementedError: There was no rule registered for HOP flex_attention and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x14f2d67d1c90>. We recommend filing an issue.

Could you please help me with this? Thank you!

My env:
torch 2.5.0

LeoXinhaoLee (Author)

From the error message, I speculate it's because I'm using FlexAttention inside a gradient-checkpointed nn.Module. Is there a way to make FlexAttention compatible with that?
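
In case it helps while the fix is pending, here is a minimal sketch of one possible stopgap. It assumes the _CachingTorchDispatchMode in the traceback comes from selective activation checkpointing (a context_fn built with torch.utils.checkpoint.create_selective_checkpoint_contexts); the names attention_block and checkpointed_attention are made up for illustration and are not from this repo.

from torch.utils.checkpoint import checkpoint
from torch.nn.attention.flex_attention import flex_attention

def attention_block(xq, xk, xv, block_mask):
    # Plain eager flex_attention call; under full (non-selective) checkpointing
    # it is simply recomputed during backward instead of being intercepted by a
    # caching dispatch mode.
    return flex_attention(xq, xk, xv, block_mask=block_mask)

def checkpointed_attention(xq, xk, xv, block_mask):
    # No context_fn here: omitting the selective-checkpoint contexts means
    # _CachingTorchDispatchMode is never entered, at the cost of recomputing
    # the whole block in backward.
    return checkpoint(attention_block, xq, xk, xv, block_mask, use_reentrant=False)

The trade-off is extra recompute for everything inside the checkpointed function, so this is only meant as a temporary workaround until flex_attention is supported under that dispatch mode.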


Chillee (Contributor) commented Nov 12, 2024

@LeoXinhaoLee This is currently a gap in our support; we're working on it. See pytorch/pytorch#140322
