Skip to content

Commit

Permalink
Merge pull request #1023 from IntelPython/fix/rambo-dpnp-impl
Browse files Browse the repository at this point in the history
Fixes stride calculation when unboxing a dpnp ndarray
  • Loading branch information
Diptorup Deb authored Apr 28, 2023
2 parents 094f852 + 43920fb commit 7e62861
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 35 deletions.
28 changes: 21 additions & 7 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
arystruct_t *arystruct)
{
struct PyUSMArrayObject *arrayobj = NULL;
int i = 0, ndim = 0, exp = 0;
int i = 0, j = 0, k = 0, ndim = 0, exp = 0;
npy_intp *shape = NULL, *strides = NULL;
npy_intp *p = NULL, nitems;
void *data = NULL;
Expand Down Expand Up @@ -812,10 +812,12 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
}
}
else {
for (i = 1; i < ndim; ++i, ++p) {
*p = shape[i] << exp;
for (i = ndim * 2 - 1; i >= ndim; --i, ++p) {
*p = 1;
for (j = i, k = ndim - 1; j > ndim; --j, --k)
*p *= shape[k];
*p <<= exp;
}
*p = 1;
}

return 0;
Expand Down Expand Up @@ -859,13 +861,21 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
struct PyUSMArrayObject *arrayobj = NULL;
npy_intp itemsize = 0;

DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: In try_to_return_parent.\n"));
DPEXRT_DEBUG(
drt_debug_print("DPEXRT-DEBUG: In box_from_arystruct_parent.\n"));

if (!(arrayobj = PyUSMNdArray_ARRAYOBJ(arystruct->parent)))
if (!(arrayobj = PyUSMNdArray_ARRAYOBJ(arystruct->parent))) {
DPEXRT_DEBUG(
drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed from "
"parent as parent pointer is NULL.\n"));
return NULL;
}

if ((void *)UsmNDArray_GetData(arrayobj) != arystruct->data)
if ((void *)UsmNDArray_GetData(arrayobj) != arystruct->data) {
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed "
"from parent as data pointer is NULL.\n"));
return NULL;
}

if (UsmNDArray_GetNDim(arrayobj) != ndim)
return NULL;
Expand Down Expand Up @@ -985,6 +995,10 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
// return back to Python memory that was allocated inside Numba and let
// Python manage the lifetime of the memory.
if (arystruct->meminfo) {
DPEXRT_DEBUG(
drt_debug_print("DPEXRT-DEBUG: Set the base of the boxed array "
"from arystruct's meminfo pointer at %s, line %d\n",
__FILE__, __LINE__));
// wrap into MemInfoObject
if (!(miobj = PyObject_New(MemInfoObject, &MemInfoType))) {
PyErr_Format(PyExc_ValueError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import dpctl
import pytest

from numba_dpex import dpjit

Expand Down
27 changes: 0 additions & 27 deletions numba_dpex/tests/core/types/DpnpNdArray/test_boxing.py

This file was deleted.

47 changes: 47 additions & 0 deletions numba_dpex/tests/core/types/DpnpNdArray/test_boxing_unboxing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests for boxing for dpnp.ndarray
"""

import dpnp

from numba_dpex import dpjit


def test_boxing_unboxing():
"""Tests basic boxing and unboxing of a dpnp.ndarray object.
Checks if we can pass in and return a dpctl.ndarray object to and
from a dpjit decorated function.
"""

@dpjit
def func(a):
return a

a = dpnp.empty(10)
try:
b = func(a)
except:
assert False, "Failure during unbox/box of dpnp.ndarray"

assert a.shape == b.shape
assert a.device == b.device
assert a.strides == b.strides
assert a.dtype == b.dtype


def test_stride_calc_at_unboxing():
"""Tests if strides were correctly computed during unboxing."""

def _tester(a):
return a.strides

b = dpnp.empty((4, 16, 4))
strides = dpjit(_tester)(b)

# Numba computes strides as bytes
assert list(strides) == [512, 32, 8]
62 changes: 62 additions & 0 deletions numba_dpex/tests/dpjit_tests/test_slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests for boxing for dpnp.ndarray
"""

import dpnp
import numpy

from numba_dpex import dpjit


def test_1d_slicing():
"""Tests if dpjit properly computes strides and returns them to Python."""

def _tester(a):
return a[1:5]

a = dpnp.arange(10)
b = dpnp.asnumpy(dpjit(_tester)(a))

na = numpy.arange(10)
nb = _tester(na)

assert (b == nb).all()


def test_1d_slicing2():
"""Tests if dpjit properly computes strides and returns them to Python."""

def _tester(a):
b = a[1:4]
a[6:9] = b

a = dpnp.arange(10)
b = dpnp.asnumpy(dpjit(_tester)(a))

na = numpy.arange(10)
nb = _tester(na)

assert (b == nb).all()


def test_multidim_slicing():
"""Tests if dpjit properly slices strides and returns them to Python."""

def _tester(a, b):
b[:, :, 0] = a

a = dpnp.arange(64, dtype=numpy.int64)
a = a.reshape(4, 16)
b = dpnp.empty((4, 16, 4), dtype=numpy.int64)
dpjit(_tester)(a, b)

na = numpy.arange(64, dtype=numpy.int64)
na = na.reshape(4, 16)
nb = numpy.empty((4, 16, 4), dtype=numpy.int64)
_tester(na, nb)

assert (nb[:, :, 0] == dpnp.asnumpy(b)[:, :, 0]).all()

0 comments on commit 7e62861

Please sign in to comment.