diff --git a/numba_dpex/core/parfors/kernel_templates/reduction_template.py b/numba_dpex/core/parfors/kernel_templates/reduction_template.py index 96b8913106..34b637412f 100644 --- a/numba_dpex/core/parfors/kernel_templates/reduction_template.py +++ b/numba_dpex/core/parfors/kernel_templates/reduction_template.py @@ -278,10 +278,13 @@ def _generate_kernel_stub_as_string(self): ) for redvar in self._redvars: + rtyp = str(self._typemap[redvar]) legal_redvar = self._redvars_dict[redvar] gufunc_txt += " " gufunc_txt += legal_redvar + " = " - gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n" + gufunc_txt += ( + f"dpnp.{rtyp}({self._parfor_reddict[redvar].init_val})\n" + ) gufunc_txt += ( " " @@ -290,32 +293,17 @@ def _generate_kernel_stub_as_string(self): + f"{self._global_size_var_name[0]} + j\n" ) - for redvar in self._redvars: - rtyp = str(self._typemap[redvar]) - redvar = self._redvars_dict[redvar] - gufunc_txt += ( - " " - + f"local_sums_{redvar} = " - + f"dpex.local.array(1, dpnp.{rtyp})\n" - ) - gufunc_txt += " " + self._sentinel_name + " = 0\n" - for i, redvar in enumerate(self._redvars): - legal_redvar = self._redvars_dict[redvar] - gufunc_txt += ( - " " + f"local_sums_{legal_redvar}[0] = {legal_redvar}\n" - ) - for i, redvar in enumerate(self._redvars): legal_redvar = self._redvars_dict[redvar] redop = self._parfor_reddict[redvar].redop if redop == operator.iadd: gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \ - local_sums_{legal_redvar}[0]\n" + {legal_redvar}\n" elif redop == operator.imul: gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \ - local_sums_{legal_redvar}[0]\n" + {legal_redvar}\n" else: raise NotImplementedError