diff --git a/numba_dpex/core/utils/kernel_templates/reduction_template.py b/numba_dpex/core/utils/kernel_templates/reduction_template.py index 3aefb202ef..882570bce1 100644 --- a/numba_dpex/core/utils/kernel_templates/reduction_template.py +++ b/numba_dpex/core/utils/kernel_templates/reduction_template.py @@ -99,7 +99,9 @@ def _generate_kernel_stub_as_string(self): for redvar in self._redvars: legal_redvar = self._redvars_dict[redvar] gufunc_txt += " " - gufunc_txt += legal_redvar + " = 0\n" + gufunc_txt += legal_redvar + " = " + gufunc_txt += f"{self._parfor_reddict[redvar].init_val} \n" + gufunc_txt += " " gufunc_txt += self._sentinel_name + " = 0\n" @@ -265,8 +267,15 @@ def _generate_kernel_stub_as_string(self): ) for i, redvar in enumerate(self._redvars): - gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \ - {self._partial_sum_var_name[i]}[j]\n" + redop = self._parfor_reddict[redvar].redop + if redop == operator.iadd: + gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \ + {self._partial_sum_var_name[i]}[j]\n" + elif redop == operator.imul: + gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \ + {self._partial_sum_var_name[i]}[j]\n" + else: + raise NotImplementedError gufunc_txt += ( f" for j in range ({self._global_size_mod_var_name[0]}) :\n" @@ -275,7 +284,8 @@ def _generate_kernel_stub_as_string(self): for redvar in self._redvars: legal_redvar = self._redvars_dict[redvar] gufunc_txt += " " - gufunc_txt += legal_redvar + " = 0\n" + gufunc_txt += legal_redvar + " = " + gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n" gufunc_txt += ( " " diff --git a/numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py b/numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py index 64da9351b7..06fb8f2671 100644 --- a/numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py +++ b/numba_dpex/tests/dpjit_tests/test_dpjit_reduction.py @@ -9,11 +9,11 @@ import numba_dpex as dpex -N = 100 +N = 10 @dpex.dpjit -def vecadd_prange(a, b): +def vecadd_prange1(a, b): s = 0 t = 0 for i in nb.prange(a.shape[0]): @@ -24,13 +24,21 @@ def vecadd_prange(a, b): @dpex.dpjit -def vecmul_prange(a, b): +def vecadd_prange2(a, b): t = 0 for i in nb.prange(a.shape[0]): t += a[i] * b[i] return t +@dpex.dpjit +def vecmul_prange(a, b): + t = 1 + for i in nb.prange(a.shape[0]): + t *= a[i] + b[i] + return t + + @dpex.dpjit def vecadd_prange_float(a, b): s = numpy.float32(0) @@ -57,22 +65,34 @@ def input_arrays(request): return a, b -def test_dpjit_array_arg_types(input_arrays): +def test_dpjit_array_arg_types_add1(input_arrays): """Tests passing float and int type dpnp arrays to a dpjit prange function. Args: input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. """ - s = 200 - + s = 20 a, b = input_arrays - - c = vecadd_prange(a, b) + c = vecadd_prange1(a, b) assert s == c +def test_dpjit_array_arg_types_add2(input_arrays): + """Tests passing float and int type dpnp arrays to a dpjit + prange function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + t = 45 + a, b = input_arrays + d = vecadd_prange2(a, b) + + assert t == d + + def test_dpjit_array_arg_types_mul(input_arrays): """Tests passing float and int type dpnp arrays to a dpjit prange function. @@ -80,7 +100,7 @@ def test_dpjit_array_arg_types_mul(input_arrays): Args: input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. """ - s = 4950 + s = 3628800 a, b = input_arrays @@ -97,8 +117,8 @@ def test_dpjit_array_arg_float32_types(input_arrays): input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. """ s = 9900 - a = dpnp.arange(N, dtype=dpnp.float32) - b = dpnp.arange(N, dtype=dpnp.float32) + a = dpnp.arange(100, dtype=dpnp.float32) + b = dpnp.arange(100, dtype=dpnp.float32) c = vecadd_prange_float(a, b)