Skip to content

Commit

Permalink
Merge pull request #1261 from IntelPython/experimental/more_fetch_phi…
Browse files Browse the repository at this point in the history
…_fns

Adds all fetch_* SPIR-V overloads to experimental
  • Loading branch information
Diptorup Deb authored Jan 10, 2024
2 parents 089402a + 3edf70a commit 315b6e7
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
]

context.extra_compile_options[LLVM_SPIRV_ARGS] = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
]

ptr_type = retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

Expand Down Expand Up @@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


def _atomic_sub_float_wrapper(gen_fn):
def gen(context, builder, sig, args):
# args is a tuple, which is immutable
# covert tuple to list obj first before replacing arg[1]
# with fneg and convert back to tuple again.
args_lst = list(args)
args_lst[1] = builder.fneg(args[1])
args = tuple(args_lst)

gen_fn(context, builder, sig, args)

return gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
    """Return the signature and code generator for ``fetch_sub``.

    For ``float32``/``float64`` references the "fetch_add" helper is used and
    its generator is wrapped so that the value is negated first. Every other
    dtype goes through the "fetch_sub" helper directly.
    """
    is_float_ref = ty_atomic_ref.dtype in (types.float32, types.float64)
    if not is_float_ref:
        return _intrinsic_helper(
            ty_context, ty_atomic_ref, ty_val, "fetch_sub"
        )

    # dpcpp does not support ``__spirv_AtomicFSubEXT``, so a float fetch_sub
    # is emitted as a fetch_add of the negated value, i.e.
    # A.fetch_sub(val) becomes A.fetch_add(-val).
    sig, add_gen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_add"
    )
    return sig, _atomic_sub_float_wrapper(add_gen)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_min``; defers to ``_intrinsic_helper``."""
    op_name = "fetch_min"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_max``; defers to ``_intrinsic_helper``."""
    op_name = "fetch_max"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_and``; defers to ``_intrinsic_helper``."""
    op_name = "fetch_and"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_or``; defers to ``_intrinsic_helper``."""
    op_name = "fetch_or"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic for ``fetch_xor``; defers to ``_intrinsic_helper``."""
    op_name = "fetch_xor"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(
ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument
Expand Down Expand Up @@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
return _intrinsic_fetch_add(atomic_ref, val)

return ol_fetch_add_impl


@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
    """Overload ``AtomicRef.fetch_sub`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_sub``; the work is delegated to
    ``_intrinsic_fetch_sub``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_sub``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to sub: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_sub_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_sub(atomic_ref, val)

    return ol_fetch_sub_impl


@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
    """Overload ``AtomicRef.fetch_min`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_min``; the work is delegated to
    ``_intrinsic_fetch_min``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_min``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to find min: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_min_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_min(atomic_ref, val)

    return ol_fetch_min_impl


@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
    """Overload ``AtomicRef.fetch_max`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_max``; the work is delegated to
    ``_intrinsic_fetch_max``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_max``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to find max: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_max_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_max(atomic_ref, val)

    return ol_fetch_max_impl


@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
    """Overload ``AtomicRef.fetch_and`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_and``; the work is delegated to
    ``_intrinsic_fetch_and``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val`` and be ``int32`` or ``int64``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_and``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type, or when that dtype is neither ``int32`` nor ``int64``.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to and: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # The type mismatch is reported first; only then is the dtype itself
    # checked against the integer types accepted for this operation.
    int_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in int_dtypes:
        raise errors.TypingError(
            "fetch_and operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_and_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_and(atomic_ref, val)

    return ol_fetch_and_impl


@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
    """Overload ``AtomicRef.fetch_or`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_or``; the work is delegated to
    ``_intrinsic_fetch_or``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val`` and be ``int32`` or ``int64``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_or``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type, or when that dtype is neither ``int32`` nor ``int64``.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to or: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # The type mismatch is reported first; only then is the dtype itself
    # checked against the integer types accepted for this operation.
    int_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in int_dtypes:
        raise errors.TypingError(
            "fetch_or operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_or_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_or(atomic_ref, val)

    return ol_fetch_or_impl


@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
    """Overload ``AtomicRef.fetch_xor`` for the experimental kernel target.

    SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.
    Intended to generate the same LLVM IR instruction as dpcpp does for
    ``atomic_ref::fetch_xor``; the work is delegated to
    ``_intrinsic_fetch_xor``.

    Args:
        atomic_ref: Type of the atomic reference; its ``dtype`` must equal
            ``val`` and be ``int32`` or ``int64``.
        val: Type of the value operand.

    Returns:
        The implementation function that calls ``_intrinsic_fetch_xor``.

    Raises:
        TypingError: When ``val`` does not match the dtype of the AtomicRef
            type, or when that dtype is neither ``int32`` nor ``int64``.
    """
    if atomic_ref.dtype != val:
        mismatch_msg = (
            f"Type of value to xor: {val} does not match the type of the "
            f"reference: {atomic_ref.dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # The type mismatch is reported first; only then is the dtype itself
    # checked against the integer types accepted for this operation.
    int_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in int_dtypes:
        raise errors.TypingError(
            "fetch_xor operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_xor_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_xor(atomic_ref, val)

    return ol_fetch_xor_impl
5 changes: 4 additions & 1 deletion numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def finalize(self):
# TODO: find better approach to set SPIRV compiler arguments. Workaround
# against caching intrinsic that sets this argument.
# https://github.com/IntelPython/numba-dpex/issues/1262
llvm_spirv_args = ["--spirv-ext=+SPV_EXT_shader_atomic_float_add"]
llvm_spirv_args = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
]
for key in list(self.context.extra_compile_options.keys()):
if key == LLVM_SPIRV_ARGS:
llvm_spirv_args = self.context.extra_compile_options[key]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
Expand All @@ -14,30 +15,80 @@
no_bool=True, no_float16=True, no_none=True, no_complex=True
)

# Names of the AtomicRef ``fetch_*`` methods used as the parameter set of the
# ``fetch_phi_fn`` fixture; tests look each one up by name with ``getattr``.
list_of_fetch_phi_funcs = [
"fetch_add",
"fetch_sub",
"fetch_min",
"fetch_max",
"fetch_and",
"fetch_or",
"fetch_xor",
]


@pytest.fixture(params=list_of_fetch_phi_funcs)
def fetch_phi_fn(request):
    """Provide each name in ``list_of_fetch_phi_funcs`` as a fixture value."""
    fn_name = request.param
    return fn_name


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.ones(N, dtype=request.param)
b = dpnp.zeros(N, dtype=request.param)
a = dpnp.arange(N, dtype=request.param)
b = dpnp.ones(N, dtype=request.param)
return a, b


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_add(input_arrays, ref_index):
def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
"""A test for all fetch_phi atomic functions."""

@dpex_exp.kernel
def atomic_ref_kernel(a, b, ref_index):
def _kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_add(a[i])
getattr(v, fetch_phi_fn)(a[i])

a, b = input_arrays

dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)
if (
fetch_phi_fn in ["fetch_and", "fetch_or", "fetch_xor"]
and issubclass(a.dtype.type, dpnp.floating)
and issubclass(b.dtype.type, dpnp.floating)
):
# fetch_and, fetch_or, fetch_xor accept only int arguments.
# test for TypingError when float arguments are passed.
with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
else:
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
# Verify that `a` accumulated at b[ref_index] by kernel
# matches the `a` accumulated at b[ref_index+1] using Python
for i in range(a.size):
v = AtomicRef(b, index=ref_index + 1)
getattr(v, fetch_phi_fn)(a[i])

assert b[ref_index] == b[ref_index + 1]


def test_fetch_phi_diff_types(fetch_phi_fn):
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value to be added are of different types.
"""

@dpex_exp.kernel
def _kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
getattr(v, fetch_phi_fn)(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

# Verify that `a` was accumulated at b[ref_index]
assert b[ref_index] == 10
with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b)


@dpex_exp.kernel
Expand All @@ -54,7 +105,7 @@ def atomic_ref_1(a):
v.fetch_add(a[i + 2])


def test_spirv_compiler_flags():
def test_spirv_compiler_flags_add():
"""Check if float atomic flag is being populated from intrinsic for the
second call.
Expand All @@ -68,3 +119,36 @@ def test_spirv_compiler_flags():

assert a[0] == N - 1
assert a[1] == N - 1


@dpex_exp.kernel
def atomic_max_0(a):
    # Every work item except the first applies fetch_max with its own
    # element of ``a`` to an atomic reference on ``a[0]``.
    gid = dpex.get_global_id(0)
    ref = AtomicRef(a, index=0)
    if gid != 0:
        ref.fetch_max(a[gid])


@dpex_exp.kernel
def atomic_max_1(a):
    # Every work item except the first applies fetch_max with its own
    # element of ``a`` to an atomic reference on ``a[0]``.
    gid = dpex.get_global_id(0)
    ref = AtomicRef(a, index=0)
    if gid != 0:
        ref.fetch_max(a[gid])


def test_spirv_compiler_flags_max():
    """Run ``fetch_max`` on float32 data through two separate kernels.

    Checks that the float atomic flag is being populated from the intrinsic
    for the second call as well.
    https://github.com/IntelPython/numba-dpex/issues/1262
    """
    n_items = 10
    first = dpnp.arange(n_items, dtype=dpnp.float32)
    second = dpnp.arange(n_items, dtype=dpnp.float32)

    dpex_exp.call_kernel(atomic_max_0, dpex.Range(n_items), first)
    dpex_exp.call_kernel(atomic_max_1, dpex.Range(n_items), second)

    # Both arrays hold 0..n_items-1, so the maximum expected at index 0 is
    # the last value of the range.
    largest = n_items - 1
    assert first[0] == largest
    assert second[0] == largest

0 comments on commit 315b6e7

Please sign in to comment.