Skip to content

Commit

Permalink
Init clean up and monkeypatch fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Mar 4, 2023
1 parent a2b1979 commit 2c1d8e8
Show file tree
Hide file tree
Showing 10 changed files with 455 additions and 311 deletions.
125 changes: 111 additions & 14 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,130 @@
"""
The numba-dpex extension module adds data-parallel offload support to Numba.
"""
import glob
import logging
import os
import platform as plt

import numba_dpex.core.dpjit_dispatcher
import numba_dpex.core.offload_dispatcher
import dpctl
import llvmlite.binding as ll
import numba
from numba.core import ir_utils
from numba.np import arrayobj
from numba.np.ufunc import array_exprs
from numba.np.ufunc.decorators import Vectorize

from numba_dpex._patches import _empty_nd_impl, _is_ufunc, _mk_alloc
from numba_dpex.vectorizers import Vectorize as DpexVectorize

# Monkey patches
array_exprs._is_ufunc = _is_ufunc
ir_utils.mk_alloc = _mk_alloc
arrayobj._empty_nd_impl = _empty_nd_impl


def load_dpctl_sycl_interface():
"""Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl.
The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions
that are directly invoked from the LLVM modules generated by numba_dpex.
We load the library once at the time of initialization using llvmlite's
load_library_permanently function.
Raises:
ImportError: If the ``DPCTLSyclInterface`` library could not be loaded.
"""

platform = plt.system()
if platform == "Windows":
paths = glob.glob(
os.path.join(
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll"
)
)
else:
paths = glob.glob(
os.path.join(
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0"
)
)

if len(paths) == 1:
ll.load_library_permanently(paths[0])
else:
raise ImportError

Vectorize.target_registry.ondemand["dpex"] = lambda: DpexVectorize


numba_version = tuple(map(int, numba.__version__.split(".")[:3]))
if numba_version < (0, 56, 4):
logging.warning(
"numba_dpex needs numba 0.56.4, using "
f"numba={numba_version} may cause unexpected behavior"
)


dpctl_version = tuple(map(int, dpctl.__version__.split(".")[:2]))
if dpctl_version < (0, 14):
logging.warning(
"numba_dpex needs dpctl 0.14 or greater, using "
f"dpctl={dpctl_version} may cause unexpected behavior"
)


import numba_dpex.core.dpjit_dispatcher # noqa E402
import numba_dpex.core.offload_dispatcher # noqa E402

# Initialize the _dpexrt_python extension
import numba_dpex.core.runtime
import numba_dpex.core.targets.dpjit_target
import numba_dpex.core.runtime # noqa E402
import numba_dpex.core.targets.dpjit_target # noqa E402

# Re-export types itself
import numba_dpex.core.types as types
from numba_dpex.core.kernel_interface.indexers import NdRange, Range
import numba_dpex.core.types as types # noqa E402
from numba_dpex import config # noqa E402
from numba_dpex.core.kernel_interface.indexers import ( # noqa E402
NdRange,
Range,
)

# Re-export all type names
from numba_dpex.core.types import *
from numba_dpex.retarget import offload_to_sycl_device

from . import config
from numba_dpex.core.types import * # noqa E402
from numba_dpex.retarget import offload_to_sycl_device # noqa E402

if config.HAS_NON_HOST_DEVICE:
from .device_init import *
# Re export
from .core.targets import dpjit_target, kernel_target
from .decorators import dpjit, func, kernel

# We are importing dpnp stub module to make Numba recognize the
# module when we rename Numpy functions.
from .dpnp_iface.stubs import dpnp
from .ocl.stubs import (
GLOBAL_MEM_FENCE,
LOCAL_MEM_FENCE,
atomic,
barrier,
get_global_id,
get_global_size,
get_group_id,
get_local_id,
get_local_size,
get_num_groups,
get_work_dim,
local,
mem_fence,
private,
sub_group_barrier,
)

DEFAULT_LOCAL_SIZE = []
load_dpctl_sycl_interface()
del load_dpctl_sycl_interface
else:
raise ImportError("No non-host SYCL device found to execute kernels.")


from ._version import get_versions
from numba_dpex._version import get_versions # noqa E402

__version__ = get_versions()["version"]
del get_versions

__all__ = ["offload_to_sycl_device"] + types.__all__ + ["Range", "NdRange"]
__all__ = types.__all__ + ["offload_to_sycl_device"] + ["Range", "NdRange"]
Loading

0 comments on commit 2c1d8e8

Please sign in to comment.