Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug in reduction mul operation for dpjit. #1048

Merged
merged 1 commit into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions numba_dpex/core/utils/kernel_templates/reduction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def _generate_kernel_stub_as_string(self):
for redvar in self._redvars:
legal_redvar = self._redvars_dict[redvar]
gufunc_txt += " "
gufunc_txt += legal_redvar + " = 0\n"
gufunc_txt += legal_redvar + " = "
gufunc_txt += f"{self._parfor_reddict[redvar].init_val} \n"

gufunc_txt += " "
gufunc_txt += self._sentinel_name + " = 0\n"

Expand Down Expand Up @@ -265,8 +267,15 @@ def _generate_kernel_stub_as_string(self):
)

for i, redvar in enumerate(self._redvars):
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
{self._partial_sum_var_name[i]}[j]\n"
redop = self._parfor_reddict[redvar].redop
if redop == operator.iadd:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
{self._partial_sum_var_name[i]}[j]\n"
elif redop == operator.imul:
gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \
{self._partial_sum_var_name[i]}[j]\n"
else:
raise NotImplementedError

gufunc_txt += (
f" for j in range ({self._global_size_mod_var_name[0]}) :\n"
Expand All @@ -275,7 +284,8 @@ def _generate_kernel_stub_as_string(self):
for redvar in self._redvars:
legal_redvar = self._redvars_dict[redvar]
gufunc_txt += " "
gufunc_txt += legal_redvar + " = 0\n"
gufunc_txt += legal_redvar + " = "
gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n"

gufunc_txt += (
" "
Expand Down
42 changes: 31 additions & 11 deletions numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

import numba_dpex as dpex

N = 100
N = 10


@dpex.dpjit
def vecadd_prange(a, b):
def vecadd_prange1(a, b):
s = 0
t = 0
for i in nb.prange(a.shape[0]):
Expand All @@ -24,13 +24,21 @@ def vecadd_prange(a, b):


@dpex.dpjit
def vecmul_prange(a, b):
def vecadd_prange2(a, b):
t = 0
for i in nb.prange(a.shape[0]):
t += a[i] * b[i]
return t


@dpex.dpjit
def vecmul_prange(a, b):
t = 1
for i in nb.prange(a.shape[0]):
t *= a[i] + b[i]
return t


@dpex.dpjit
def vecadd_prange_float(a, b):
s = numpy.float32(0)
Expand All @@ -57,30 +65,42 @@ def input_arrays(request):
return a, b


def test_dpjit_array_arg_types(input_arrays):
def test_dpjit_array_arg_types_add1(input_arrays):
"""Tests passing float and int type dpnp arrays to a dpjit
prange function.

Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
s = 200

s = 20
a, b = input_arrays

c = vecadd_prange(a, b)
c = vecadd_prange1(a, b)

assert s == c


def test_dpjit_array_arg_types_add2(input_arrays):
"""Tests passing float and int type dpnp arrays to a dpjit
prange function.

Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
t = 45
a, b = input_arrays
d = vecadd_prange2(a, b)

assert t == d


def test_dpjit_array_arg_types_mul(input_arrays):
"""Tests passing float and int type dpnp arrays to a dpjit
prange function.

Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
s = 4950
s = 3628800

a, b = input_arrays

Expand All @@ -97,8 +117,8 @@ def test_dpjit_array_arg_float32_types(input_arrays):
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
s = 9900
a = dpnp.arange(N, dtype=dpnp.float32)
b = dpnp.arange(N, dtype=dpnp.float32)
a = dpnp.arange(100, dtype=dpnp.float32)
b = dpnp.arange(100, dtype=dpnp.float32)

c = vecadd_prange_float(a, b)

Expand Down