Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dpnp ufuncs inside dpjit functions #931

Closed
diptorupd opened this issue Feb 24, 2023 · 2 comments
Closed

Support dpnp ufuncs inside dpjit functions #931

diptorupd opened this issue Feb 24, 2023 · 2 comments
Assignees

Comments

@diptorupd
Copy link
Collaborator

diptorupd commented Feb 24, 2023

Numba supports most NumPy ufuncs inside njit functions and is able to compile these functions to LLVM. The same functionality has to be supported for dpnp.

Steps involved:

  1. Add dpnp ufuncs to the numba ufunc_db
# Monkey patch dpnp's ufuncs to have the `nin`, `nout`, and `nargs`
# attributes of the same-named NumPy ufuncs.
list_of_ufuncs = ["add", "subtract"]

# NOTE(review): `_ufunc_db` is imported only after `_lazy_init_db()` has been
# called, as in the original snippet -- presumably the db is populated lazily;
# confirm against numba.np.ufunc_db.
from numba.np.ufunc_db import _lazy_init_db
_lazy_init_db()
from numba.np.ufunc_db import _ufunc_db as ufunc_db

for ufuncop in list_of_ufuncs:
    op = getattr(dpnp, ufuncop)
    npop = getattr(np, ufuncop)
    op.nin = npop.nin
    op.nout = npop.nout
    op.nargs = npop.nargs
    # Enter the dpnp ufunc into the db under the very same entry object as
    # its NumPy counterpart (the entry is shared, not copied).
    ufunc_db[op] = ufunc_db[npop]
  2. Update the ufunc_db for dpnp functions where we want to use OCL math intrinsics
    def replace_numpy_ufunc_with_opencl_supported_functions(self):
        """Point dpnp ufunc signatures at the OpenCL math implementations.

        For each (name, dpnp ufunc) pair listed below, every signature
        registered for that ufunc in ``self.ufunc_db`` is overwritten with the
        matching ``lower_ocl_impl`` entry, provided ``sig_mapper`` contains
        the signature and ``lower_ocl_impl`` has an entry keyed by
        ``(name, sig_mapper[signature])``. Signatures without such a match
        keep their current implementation.
        """
        from numba_dpex.ocl.mathimpl import lower_ocl_impl, sig_mapper

        # (name used to key ``lower_ocl_impl``, dpnp ufunc to re-point)
        ocl_backed_ufuncs = (
            ("fabs", dpnp.fabs),
            ("exp", dpnp.exp),
            ("log", dpnp.log),
            ("log10", dpnp.log10),
            ("expm1", dpnp.expm1),
            ("log1p", dpnp.log1p),
            ("sqrt", dpnp.sqrt),
            ("sin", dpnp.sin),
            ("cos", dpnp.cos),
            ("tan", dpnp.tan),
            ("asin", dpnp.arcsin),
            ("acos", dpnp.arccos),
            ("atan", dpnp.arctan),
            ("atan2", dpnp.arctan2),
            ("sinh", dpnp.sinh),
            ("cosh", dpnp.cosh),
            ("tanh", dpnp.tanh),
            ("asinh", dpnp.arcsinh),
            ("acosh", dpnp.arccosh),
            ("atanh", dpnp.arctanh),
            ("ldexp", dpnp.ldexp),
            ("floor", dpnp.floor),
            ("ceil", dpnp.ceil),
            ("trunc", dpnp.trunc),
            ("hypot", dpnp.hypot),
            ("exp2", dpnp.exp2),
            ("log2", dpnp.log2),
        )

        for ocl_name, dpnp_ufunc in ocl_backed_ufuncs:
            impls_by_sig = self.ufunc_db[dpnp_ufunc]
            for signature in impls_by_sig:
                if signature not in sig_mapper:
                    continue
                ocl_key = (ocl_name, sig_mapper[signature])
                if ocl_key in lower_ocl_impl:
                    # Only existing keys are reassigned, so the table being
                    # iterated never changes size.
                    impls_by_sig[signature] = lower_ocl_impl[ocl_key]
  3. Register the ufuncs

    Refer to:
    https://github.com/numba/numba/blob/720b357320d99eceed149be5f2a7ae20ec67642c/numba/np/npyimpl.py#L560
    https://github.com/numba/numba/blob/720b357320d99eceed149be5f2a7ae20ec67642c/numba/core/typing/npydecl.py

@diptorupd
Copy link
Collaborator Author

Calling a modified register_ufunc does not do the trick 😭

import numba
import dpnp
import numba_dpex as dpex
import numpy as np

from numba.np import npyimpl
from numba.core.typing import npydecl
from numba.np import ufunc_db

# Monkey patch dpnp's ufuncs to have the `nin`, `nout`, and `nargs` attributes.
# Names of the ufuncs to patch and register; each name is looked up on both
# dpnp and numpy.
list_of_ufuncs = ["add", "subtract"]


def fill_ufunc_db_with_dpnp_ufuncs():
    """Enter the dpnp ufuncs named in ``list_of_ufuncs`` into Numba's ufunc db.

    Each dpnp ufunc first receives the ``nin``, ``nout`` and ``nargs``
    attributes of the same-named NumPy ufunc. It is then stored in the db
    under the very same entry object as that NumPy ufunc (shared, not copied).
    """
    from numba.np.ufunc_db import _lazy_init_db

    _lazy_init_db()
    # Imported only after _lazy_init_db() has been called, as the original
    # ordering requires.
    from numba.np.ufunc_db import _ufunc_db as registry

    for name in list_of_ufuncs:
        dpnp_ufunc = getattr(dpnp, name)
        numpy_ufunc = getattr(np, name)
        for attr in ("nin", "nout", "nargs"):
            setattr(dpnp_ufunc, attr, getattr(numpy_ufunc, attr))
        registry[dpnp_ufunc] = registry[numpy_ufunc]


def _register_dpnp_ufuncs():
    """Register ufunc kernels and bind operators to the dpnp ufuncs.

    A kernel is registered for every ufunc returned by
    ``ufunc_db.get_ufuncs()``. Then, for each entry of the unary and binary
    array-operator maps whose ufunc name is in ``list_of_ufuncs``, the
    same-named dpnp ufunc and its kernel are registered for that operator.

    Raises:
        RuntimeError: if a selected dpnp ufunc has ``nin`` other than 1 or 2.
    """
    # NOTE: Assuming ufunc implementation for the CPUContext.
    kernels = {
        db_ufunc: npyimpl.register_ufunc_kernel(
            db_ufunc, npyimpl._ufunc_db_function(db_ufunc)
        )
        for db_ufunc in ufunc_db.get_ufuncs()
    }

    operator_tables = (
        npydecl.NumpyRulesUnaryArrayOperator._op_map,
        npydecl.NumpyRulesArrayOperator._op_map,
    )
    for table in operator_tables:
        for op_symbol, ufunc_name in table.items():
            # Only operators backed by one of the patched ufuncs are handled.
            if ufunc_name not in list_of_ufuncs:
                continue

            dpnp_ufunc = getattr(dpnp, ufunc_name)
            kernel = kernels[dpnp_ufunc]
            arity = dpnp_ufunc.nin

            if arity == 1:
                register = npyimpl.register_unary_operator_kernel
            elif arity == 2:
                register = npyimpl.register_binary_operator_kernel
            else:
                raise RuntimeError(
                    "There shouldn't be any non-unary or binary operators"
                )
            register(op_symbol, dpnp_ufunc, kernel)

    # for _op_map in (npydecl.NumpyRulesInplaceArrayOperator._op_map,):
    #     for operator, ufunc_name in _op_map.items():
    #         ufunc = getattr(np, ufunc_name)
    #         kernel = kernels[ufunc]
    #         if ufunc.nin == 1:
    #             npyimpl.register_unary_operator_kernel(
    #                 operator, ufunc, kernel, inplace=True
    #             )
    #         elif ufunc.nin == 2:
    #             npyimpl.register_binary_operator_kernel(
    #                 operator, ufunc, kernel, inplace=True
    #             )
    #         else:
    #             raise RuntimeError(
    #                 "There shouldn't be any non-unary or binary operators"
    #             )


# Order matters: the db must hold the dpnp ufuncs before registration, because
# _register_dpnp_ufuncs() looks up kernels[getattr(dpnp, name)] and builds
# `kernels` from ufunc_db.get_ufuncs() (presumably the same db filled here).
fill_ufunc_db_with_dpnp_ufuncs()
_register_dpnp_ufuncs()


@dpex.dpjit
def foo(a, b):
    """Return ``dpnp.add(a, b)``; decorated with ``dpex.dpjit``."""
    total = dpnp.add(a, b)
    return total


# Inputs for the reproducer, both created with dpnp.ones(10).
a = dpnp.ones(10)
b = dpnp.ones(10)

# Invoke the dpjit-decorated function on the dpnp inputs.
c = foo(a, b)

@mingjie-intel mingjie-intel linked a pull request Feb 26, 2023 that will close this issue
5 tasks
@diptorupd
Copy link
Collaborator Author

Done as part of #957

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants