Skip to content

Commit

Permalink
Add lower function
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Apr 17, 2023
1 parent 4218ae1 commit 44b81c0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
4 changes: 2 additions & 2 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:


numba_version = parse_sem_version(numba.__version__)
if numba_version < (0, 56, 4):
if numba_version < (0, 57, 0):
logging.warning(
"numba_dpex needs numba 0.56.4, using "
"numba_dpex needs numba 0.57.0, using "
f"numba={numba_version} may cause unexpected behavior"
)

Expand Down
20 changes: 15 additions & 5 deletions numba_dpex/dpnp_iface/dpnpimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@
# SPDX-License-Identifier: Apache-2.0

import dpnp
from numba.core.imputils import Registry
from numba.np import npyimpl

from numba_dpex.core.typing.dpnpdecl import _unsupported
from numba_dpex.dpnp_iface import dpnp_ufunc_db


def _register_dpnp_ufuncs():
registry = Registry("npyimpl")
lower = registry.lower

kernels = {}
# NOTE: Assuming ufunc implementation for the CPUContext.
for ufunc in dpnp_ufunc_db.get_ufuncs():
kernels[ufunc] = npyimpl.register_ufunc_kernel(
ufunc, npyimpl._ufunc_db_function(ufunc)
ufunc,
npyimpl._ufunc_db_function(ufunc),
lower,
)

for _op_map in (
Expand All @@ -27,9 +33,13 @@ def _register_dpnp_ufuncs():
ufunc = getattr(dpnp, ufunc_name)
kernel = kernels[ufunc]
if ufunc.nin == 1:
npyimpl.register_unary_operator_kernel(operator, ufunc, kernel)
npyimpl.register_unary_operator_kernel(
operator, ufunc, kernel, lower
)
elif ufunc.nin == 2:
npyimpl.register_binary_operator_kernel(operator, ufunc, kernel)
npyimpl.register_binary_operator_kernel(
operator, ufunc, kernel, lower
)
else:
raise RuntimeError(
"There shouldn't be any non-unary or binary operators"
Expand All @@ -43,11 +53,11 @@ def _register_dpnp_ufuncs():
kernel = kernels[ufunc]
if ufunc.nin == 1:
npyimpl.register_unary_operator_kernel(
operator, ufunc, kernel, inplace=True
operator, ufunc, kernel, lower, inplace=True
)
elif ufunc.nin == 2:
npyimpl.register_binary_operator_kernel(
operator, ufunc, kernel, inplace=True
operator, ufunc, kernel, lower, inplace=True
)
else:
raise RuntimeError(
Expand Down

0 comments on commit 44b81c0

Please sign in to comment.