Fix SDPA in case attn_mask == None (#78)
alessandropalla authored Jun 26, 2024
1 parent 736afb5 commit 75bc3c9
Showing 4 changed files with 65 additions and 9 deletions.
3 changes: 2 additions & 1 deletion intel_npu_acceleration_library/backend/__init__.py
@@ -12,7 +12,7 @@
from .qlinear import QLinear
from .tensor import Tensor
from .factory import NNFactory
-from .sdpa import SDPA
+from .sdpa import SDPA, SimpleSDPA
from .runtime import run_matmul, run_factory, clear_cache

check_npu_and_driver_version()
@@ -27,6 +27,7 @@
"QLinear",
"Convolution",
"SDPA",
"SimpleSDPA",
"run_matmul",
"run_factory",
"clear_cache",
47 changes: 47 additions & 0 deletions intel_npu_acceleration_library/backend/sdpa.py
@@ -58,3 +58,50 @@ def run(
            np.ndarray: result
        """
        return super().run(query, key, value, mask)


class SimpleSDPA(NNFactory):
    """Implementation of a ScaledDotProductAttention NPU operation."""

    def __init__(
        self,
        query_shapes: Tuple[int, int],
        key_shapes: Tuple[int, int],
        value_shapes: Tuple[int, int],
        is_causal: bool = False,
        profile: bool = False,
        device: str = "NPU",
    ):
        """Initialize the SimpleSDPA operation.

        Args:
            query_shapes (Tuple[int, int]): shape of the query tensor
            key_shapes (Tuple[int, int]): shape of the key tensor
            value_shapes (Tuple[int, int]): shape of the value tensor
            is_causal (bool, optional): Whether the SDPA mask is causal. Defaults to False.
            profile (bool, optional): Enable/Disable profiling. Defaults to False.
            device (str, optional): Target device. Defaults to "NPU".
        """
        super().__init__(profile, device)

        self.query = self.parameter(query_shapes)
        self.key = self.parameter(key_shapes)
        self.value = self.parameter(value_shapes)

        _ = self.scaled_dot_product_attention_simple(  # type: ignore[attr-defined]
            self.query, self.key, self.value, is_causal
        )
        self.compile()

    def run(self, query: np.ndarray, key: np.ndarray, value: np.ndarray) -> np.ndarray:
        """Run the scaled dot product attention kernel.

        Args:
            query (np.ndarray): sdpa query tensor
            key (np.ndarray): sdpa key tensor
            value (np.ndarray): sdpa value tensor

        Returns:
            np.ndarray: result
        """
        return super().run(query, key, value)
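
A minimal usage sketch of the new backend class; the shapes, dtypes, and values below are illustrative assumptions, not taken from this diff:

import numpy as np
from intel_npu_acceleration_library.backend import SimpleSDPA

# Assumed example shapes: batch 1, 8 heads, sequence 128, head dim 64.
shape = (1, 8, 128, 64)
sdpa = SimpleSDPA(shape, shape, shape, is_causal=True)

query = np.random.rand(*shape).astype(np.float16)
key = np.random.rand(*shape).astype(np.float16)
value = np.random.rand(*shape).astype(np.float16)

# No mask argument: SimpleSDPA runs attention without an explicit attn_mask input.
result = sdpa.run(query, key, value)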
@@ -2,7 +2,7 @@
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#
-from intel_npu_acceleration_library.backend import run_factory, SDPA
+from intel_npu_acceleration_library.backend import run_factory, SDPA, SimpleSDPA
from typing import Optional
from functools import partial
import torch
@@ -34,10 +34,14 @@ def scaled_dot_product_attention(
    Returns:
        torch.Tensor: scaled dot product attention result
    """
-    backend_cls = partial(SDPA, is_causal=is_causal)
    if dropout_p != 0:
        raise RuntimeError("dropout_p != 0 is not supported yet")
    if scale is not None:
        raise RuntimeError("scale is not supported yet")

-    return run_factory([query, key, value, attn_mask], [], backend_cls)
+    if attn_mask is None:
+        backend_cls = partial(SimpleSDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value], [], backend_cls)
+    else:
+        backend_cls = partial(SDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value, attn_mask], [], backend_cls)
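
A short sketch of how the new dispatch behaves from the caller's side. The tensor shapes are illustrative, and scaled_dot_product_attention refers to the function patched in the hunk above (the module's file name and import path are not shown in this diff):

import torch

q = torch.rand(1, 8, 128, 64).to(torch.float16)
k = torch.rand(1, 8, 128, 64).to(torch.float16)
v = torch.rand(1, 8, 128, 64).to(torch.float16)

# attn_mask is None: the call is routed to the lighter SimpleSDPA backend.
out_causal = scaled_dot_product_attention(q, k, v, None, is_causal=True)

# Explicit additive mask: the original SDPA backend handles the call.
mask = torch.zeros(1, 8, 128, 128, dtype=torch.float16)
out_masked = scaled_dot_product_attention(q, k, v, mask)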
14 changes: 9 additions & 5 deletions test/python/test_sdpa.py
@@ -59,7 +59,8 @@ def test_sdpa(heads, sequence, dim, kv_cache, is_causal):
@pytest.mark.parametrize("dim", [512, 1024])
@pytest.mark.parametrize("kv_cache", [True, False])
@pytest.mark.parametrize("is_causal", [False, True])
-def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
+@pytest.mark.parametrize("use_mask", [False, True])
+def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal, use_mask):

    min_value = torch.finfo(torch.float16).min

@@ -68,10 +69,13 @@ def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
    )
    key = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
    value = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
-    mask = min_value * torch.ones(1, heads, 1 if kv_cache else sequence, sequence).to(
-        torch.float16
-    )
-    mask = torch.triu(mask)
+    if use_mask:
+        mask = min_value * torch.ones(
+            1, heads, 1 if kv_cache else sequence, sequence
+        ).to(torch.float16)
+        mask = torch.triu(mask)
+    else:
+        mask = None

    npu_result = scaled_dot_product_attention(
        query, key, value, mask, is_causal=is_causal
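
The remainder of this test is collapsed in the diff view; a plausible continuation (an assumption, not part of the commit) compares the NPU output against PyTorch's reference SDPA:

    import torch.nn.functional as F

    # PyTorch disallows an explicit attn_mask together with is_causal=True,
    # so the reference only forwards is_causal when no mask is used.
    reference = F.scaled_dot_product_attention(
        query, key, value, attn_mask=mask, is_causal=is_causal and mask is None
    )
    assert torch.allclose(npu_result.to(torch.float16), reference, atol=1e-2)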
