From 75bc3c97463daa4c81a62068a03111c96c434800 Mon Sep 17 00:00:00 2001
From: Alessandro Palla
Date: Wed, 26 Jun 2024 10:38:42 +0200
Subject: [PATCH] Fix SDPA in case attn_mask == None (#78)

---
 .../backend/__init__.py                       |  3 +-
 .../backend/sdpa.py                           | 47 +++++++++++++++++++
 .../scaled_dot_product_attention.py           | 10 ++--
 test/python/test_sdpa.py                      | 14 ++++--
 4 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/intel_npu_acceleration_library/backend/__init__.py b/intel_npu_acceleration_library/backend/__init__.py
index 5cab810..f8fbeaa 100644
--- a/intel_npu_acceleration_library/backend/__init__.py
+++ b/intel_npu_acceleration_library/backend/__init__.py
@@ -12,7 +12,7 @@
 from .qlinear import QLinear
 from .tensor import Tensor
 from .factory import NNFactory
-from .sdpa import SDPA
+from .sdpa import SDPA, SimpleSDPA
 from .runtime import run_matmul, run_factory, clear_cache
 
 check_npu_and_driver_version()
@@ -27,6 +27,7 @@
     "QLinear",
     "Convolution",
     "SDPA",
+    "SimpleSDPA",
     "run_matmul",
     "run_factory",
     "clear_cache",
diff --git a/intel_npu_acceleration_library/backend/sdpa.py b/intel_npu_acceleration_library/backend/sdpa.py
index 5374ed8..18e9636 100644
--- a/intel_npu_acceleration_library/backend/sdpa.py
+++ b/intel_npu_acceleration_library/backend/sdpa.py
@@ -58,3 +58,50 @@ def run(
             np.ndarray: result
         """
         return super().run(query, key, value, mask)
+
+
+class SimpleSDPA(NNFactory):
+    """Implementation of a ScaledDotProductAttention NPU operation."""
+
+    def __init__(
+        self,
+        query_shapes: Tuple[int, int],
+        key_shapes: Tuple[int, int],
+        value_shapes: Tuple[int, int],
+        is_causal: bool = False,
+        profile: bool = False,
+        device: str = "NPU",
+    ):
+        """Initialize the SDPA.
+
+        Args:
+            query_shapes (Tuple[int, int]): shape of the query tensor
+            key_shapes (Tuple[int, int]): shape of the key tensor
+            value_shapes (Tuple[int, int]): shape of the value tensor
+            is_causal (bool, optional): Whether the SDPA mask is causal or not. Defaults to False.
+            profile (bool, optional): Enable/Disable profiling. Defaults to False.
+            device (str, optional): Target device. Defaults to "NPU".
+        """
+        super().__init__(profile, device)
+
+        self.query = self.parameter(query_shapes)
+        self.key = self.parameter(key_shapes)
+        self.value = self.parameter(value_shapes)
+
+        _ = self.scaled_dot_product_attention_simple(  # type: ignore[attr-defined]
+            self.query, self.key, self.value, is_causal
+        )
+        self.compile()
+
+    def run(self, query: np.ndarray, key: np.ndarray, value: np.ndarray) -> np.ndarray:
+        """Run the scaled dot product attention kernel.
+
+        Args:
+            query (np.ndarray): sdpa query tensor
+            key (np.ndarray): sdpa key tensor
+            value (np.ndarray): sdpa value tensor
+
+        Returns:
+            np.ndarray: result
+        """
+        return super().run(query, key, value)
diff --git a/intel_npu_acceleration_library/functional/scaled_dot_product_attention.py b/intel_npu_acceleration_library/functional/scaled_dot_product_attention.py
index f79551f..142798b 100644
--- a/intel_npu_acceleration_library/functional/scaled_dot_product_attention.py
+++ b/intel_npu_acceleration_library/functional/scaled_dot_product_attention.py
@@ -2,7 +2,7 @@
 # Copyright © 2024 Intel Corporation
 # SPDX-License-Identifier: Apache 2.0
 #
-from intel_npu_acceleration_library.backend import run_factory, SDPA
+from intel_npu_acceleration_library.backend import run_factory, SDPA, SimpleSDPA
 from typing import Optional
 from functools import partial
 import torch
@@ -34,10 +34,14 @@ def scaled_dot_product_attention(
     Returns:
         torch.Tensor: _description_
     """
-    backend_cls = partial(SDPA, is_causal=is_causal)
     if dropout_p != 0:
         raise RuntimeError("dropout_p != 0 is not supported yet")
     if scale is not None:
         raise RuntimeError("scale != 0 is not supported yet")
 
-    return run_factory([query, key, value, attn_mask], [], backend_cls)
+    if attn_mask is None:
+        backend_cls = partial(SimpleSDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value], [], backend_cls)
+    else:
+        backend_cls = partial(SDPA, is_causal=is_causal)  # type: ignore
+        return run_factory([query, key, value, attn_mask], [], backend_cls)
diff --git a/test/python/test_sdpa.py b/test/python/test_sdpa.py
index e104c5e..86c0e99 100644
--- a/test/python/test_sdpa.py
+++ b/test/python/test_sdpa.py
@@ -59,7 +59,8 @@ def test_sdpa(heads, sequence, dim, kv_cache, is_causal):
 @pytest.mark.parametrize("dim", [512, 1024])
 @pytest.mark.parametrize("kv_cache", [True, False])
 @pytest.mark.parametrize("is_causal", [False, True])
-def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
+@pytest.mark.parametrize("use_mask", [False, True])
+def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal, use_mask):
 
     min_value = torch.finfo(torch.float16).min
 
@@ -68,10 +69,13 @@ def test_sdpa_runtime(heads, sequence, dim, kv_cache, is_causal):
     )
     key = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
    value = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
-    mask = min_value * torch.ones(1, heads, 1 if kv_cache else sequence, sequence).to(
-        torch.float16
-    )
-    mask = torch.triu(mask)
+    if use_mask:
+        mask = min_value * torch.ones(
+            1, heads, 1 if kv_cache else sequence, sequence
+        ).to(torch.float16)
+        mask = torch.triu(mask)
+    else:
+        mask = None
 
     npu_result = scaled_dot_product_attention(
         query, key, value, mask, is_causal=is_causal
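
Usage sketch (illustrative, not part of the patch): with this fix, calling the library's scaled_dot_product_attention without an attention mask dispatches to the new SimpleSDPA backend instead of handing a None mask to SDPA. The import path, tensor shapes, and expected output shape below are assumptions for illustration; an installed intel_npu_acceleration_library and a working NPU driver are assumed.

    import torch
    # Assumed export path; the function is defined in
    # intel_npu_acceleration_library/functional/scaled_dot_product_attention.py
    from intel_npu_acceleration_library.functional import scaled_dot_product_attention

    heads, sequence, dim = 8, 128, 512
    query = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
    key = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)
    value = torch.rand(1, heads, sequence, dim // heads).to(torch.float16)

    # attn_mask=None now routes to SimpleSDPA (query, key, value only);
    # passing an explicit mask still uses the original SDPA backend.
    output = scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=True)
    print(output.shape)  # expected: torch.Size([1, 8, 128, 64])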