-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support dpnp ufuncs inside dpjit functions #931
Comments
Calling a modified version of Numba's ufunc registration code (below):
import numba
import dpnp
import numba_dpex as dpex
import numpy as np
from numba.np import npyimpl
from numba.core.typing import npydecl
from numba.np import ufunc_db
# Names of the ufuncs that exist in both NumPy and dpnp and that should be
# made usable inside dpjit functions. The dpnp objects are monkey-patched
# with the `nin`, `nout` and `nargs` attributes of their NumPy counterparts.
list_of_ufuncs = ["add", "subtract"]


def fill_ufunc_db_with_dpnp_ufuncs(ufunc_names=None):
    """Make the named dpnp ufuncs reuse the entries of Numba's NumPy ufunc db.

    For every name, the dpnp object gets the ``nin``, ``nout`` and ``nargs``
    attributes of the NumPy ufunc of the same name. Numba's internal ufunc
    database is then extended so the dpnp object maps to the same entry as the
    NumPy ufunc.

    Args:
        ufunc_names: iterable of ufunc names to patch. Defaults to
            ``list_of_ufuncs``.

    Raises:
        AttributeError: if a name is missing from ``dpnp`` or ``numpy``.
        KeyError: if the NumPy ufunc has no entry in Numba's ufunc database.
    """
    # Private Numba API (leading underscore). The database is initialised by
    # _lazy_init_db(), so call it before importing the dict it fills.
    from numba.np.ufunc_db import _lazy_init_db

    _lazy_init_db()
    # Bound to a distinct local name so it does not shadow the module-level
    # `ufunc_db` module, which _register_dpnp_ufuncs() uses as a module.
    from numba.np.ufunc_db import _ufunc_db as db

    if ufunc_names is None:
        ufunc_names = list_of_ufuncs
    for name in ufunc_names:
        dpnp_op = getattr(dpnp, name)
        np_op = getattr(np, name)
        # The registration code reads `ufunc.nin` to pick unary vs binary
        # operator kernels, so the dpnp object needs these attributes.
        dpnp_op.nin = np_op.nin
        dpnp_op.nout = np_op.nout
        dpnp_op.nargs = np_op.nargs
        db[dpnp_op] = db[np_op]
def _register_dpnp_ufuncs(ufunc_names=None):
    """Register Numba lowering for dpnp ufuncs and the operators mapped to them.

    For each dpnp ufunc named in ``ufunc_names`` (default ``list_of_ufuncs``),
    a ufunc kernel is registered via ``npyimpl.register_ufunc_kernel``. Then
    every operator in Numba's unary/binary array-operator maps whose ufunc
    name is in that set is registered against the dpnp ufunc.

    Must run after fill_ufunc_db_with_dpnp_ufuncs(), which gives the dpnp
    objects the ``nin`` attribute and ufunc-db entries relied on here.

    NOTE: assumes the ufunc implementation for the CPUContext.

    Args:
        ufunc_names: iterable of ufunc names to register. Defaults to
            ``list_of_ufuncs``.

    Raises:
        RuntimeError: if a matched ufunc is neither unary nor binary.
    """
    if ufunc_names is None:
        ufunc_names = list_of_ufuncs
    # Deduplicate while keeping a deterministic registration order.
    names = tuple(dict.fromkeys(ufunc_names))

    # Build kernels only for the dpnp ufuncs. The operator loop below never
    # looks up any other ufunc.
    # NOTE(review): the original iterated every ufunc in Numba's db,
    # re-registering NumPy kernels that Numba presumably registers itself
    # at import. Confirm nothing depended on that.
    kernels = {}
    for name in names:
        ufunc = getattr(dpnp, name)
        kernels[ufunc] = npyimpl.register_ufunc_kernel(
            ufunc, npyimpl._ufunc_db_function(ufunc)
        )

    for op_map in (
        npydecl.NumpyRulesUnaryArrayOperator._op_map,
        npydecl.NumpyRulesArrayOperator._op_map,
    ):
        for op, ufunc_name in op_map.items():
            if ufunc_name not in names:
                continue
            ufunc = getattr(dpnp, ufunc_name)
            kernel = kernels[ufunc]
            if ufunc.nin == 1:
                npyimpl.register_unary_operator_kernel(op, ufunc, kernel)
            elif ufunc.nin == 2:
                npyimpl.register_binary_operator_kernel(op, ufunc, kernel)
            else:
                raise RuntimeError(
                    "There shouldn't be any non-unary or binary operators"
                )

    # TODO: in-place operators (npydecl.NumpyRulesInplaceArrayOperator._op_map)
    # are not registered yet. Doing so needs the same name filter, the dpnp
    # (not NumPy) ufunc, and inplace=True passed to the
    # register_*_operator_kernel calls.
# Order matters: the database must be filled before kernels are registered.
fill_ufunc_db_with_dpnp_ufuncs()
_register_dpnp_ufuncs()


@dpex.dpjit
def foo(a, b):
    """Return ``dpnp.add(a, b)`` computed inside a dpjit-compiled function."""
    return dpnp.add(a, b)


a = dpnp.ones(10)
b = dpnp.ones(10)
c = foo(a, b)
5 tasks
Done as part of #957.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Numba supports most NumPy ufuncs inside `njit` functions and can compile these functions to LLVM IR. The same functionality needs to be supported for dpnp ufuncs inside `dpjit` functions.
Steps involved:
Register the ufuncs
The text was updated successfully, but these errors were encountered: