Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Init clean up and monkeypatch fix #954

Merged
merged 1 commit into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 111 additions & 14 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,130 @@
"""
The numba-dpex extension module adds data-parallel offload support to Numba.
"""
import glob
import logging
import os
import platform as plt

import numba_dpex.core.dpjit_dispatcher
import numba_dpex.core.offload_dispatcher
import dpctl
import llvmlite.binding as ll
import numba
from numba.core import ir_utils
from numba.np import arrayobj
from numba.np.ufunc import array_exprs
from numba.np.ufunc.decorators import Vectorize

from numba_dpex._patches import _empty_nd_impl, _is_ufunc, _mk_alloc
from numba_dpex.vectorizers import Vectorize as DpexVectorize

# Monkey patches
array_exprs._is_ufunc = _is_ufunc
ir_utils.mk_alloc = _mk_alloc
arrayobj._empty_nd_impl = _empty_nd_impl


def load_dpctl_sycl_interface():
"""Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl.
The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions
that are directly invoked from the LLVM modules generated by numba_dpex.
We load the library once at the time of initialization using llvmlite's
load_library_permanently function.
Raises:
ImportError: If the ``DPCTLSyclInterface`` library could not be loaded.
"""

platform = plt.system()
if platform == "Windows":
paths = glob.glob(
os.path.join(
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll"
)
)
else:
paths = glob.glob(
os.path.join(
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0"
)
)

if len(paths) == 1:
ll.load_library_permanently(paths[0])
else:
raise ImportError

Vectorize.target_registry.ondemand["dpex"] = lambda: DpexVectorize


numba_version = tuple(map(int, numba.__version__.split(".")[:3]))
if numba_version < (0, 56, 4):
logging.warning(
"numba_dpex needs numba 0.56.4, using "
f"numba={numba_version} may cause unexpected behavior"
)


dpctl_version = tuple(map(int, dpctl.__version__.split(".")[:2]))
if dpctl_version < (0, 14):
logging.warning(
"numba_dpex needs dpctl 0.14 or greater, using "
f"dpctl={dpctl_version} may cause unexpected behavior"
)


import numba_dpex.core.dpjit_dispatcher # noqa E402
import numba_dpex.core.offload_dispatcher # noqa E402

# Initialize the _dpexrt_python extension
import numba_dpex.core.runtime
import numba_dpex.core.targets.dpjit_target
import numba_dpex.core.runtime # noqa E402
import numba_dpex.core.targets.dpjit_target # noqa E402

# Re-export types itself
import numba_dpex.core.types as types
from numba_dpex.core.kernel_interface.indexers import NdRange, Range
import numba_dpex.core.types as types # noqa E402
from numba_dpex import config # noqa E402
from numba_dpex.core.kernel_interface.indexers import ( # noqa E402
NdRange,
Range,
)

# Re-export all type names
from numba_dpex.core.types import *
from numba_dpex.retarget import offload_to_sycl_device

from . import config
from numba_dpex.core.types import * # noqa E402
from numba_dpex.retarget import offload_to_sycl_device # noqa E402

if config.HAS_NON_HOST_DEVICE:
from .device_init import *
# Re export
from .core.targets import dpjit_target, kernel_target
from .decorators import dpjit, func, kernel

# We are importing dpnp stub module to make Numba recognize the
# module when we rename Numpy functions.
from .dpnp_iface.stubs import dpnp
from .ocl.stubs import (
GLOBAL_MEM_FENCE,
LOCAL_MEM_FENCE,
atomic,
barrier,
get_global_id,
get_global_size,
get_group_id,
get_local_id,
get_local_size,
get_num_groups,
get_work_dim,
local,
mem_fence,
private,
sub_group_barrier,
)

DEFAULT_LOCAL_SIZE = []
load_dpctl_sycl_interface()
del load_dpctl_sycl_interface
else:
raise ImportError("No non-host SYCL device found to execute kernels.")


from ._version import get_versions
from numba_dpex._version import get_versions # noqa E402

__version__ = get_versions()["version"]
del get_versions

__all__ = ["offload_to_sycl_device"] + types.__all__ + ["Range", "NdRange"]
__all__ = types.__all__ + ["offload_to_sycl_device"] + ["Range", "NdRange"]
Loading