Skip to content

Commit

Permalink
Fix Python dtype conversion for int64 on Windows. (iree-org#12880)
Browse files Browse the repository at this point in the history
Fixes iree-org#11080. The int64 and uint64
test cases here were failing on Windows as the element type mapping was
routing via the code `l`, which is a "C long int" - not an explicitly 64
bit type. This changes the mapping to always use the explicit "type
strings" (any string in `numpy.sctypeDict.keys()`, [shown in this
gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)).

Relates to iree-org#12872
  • Loading branch information
ScottTodd authored and NatashaKnk committed Jul 6, 2023
1 parent c94a22b commit 8bab0d3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 25 deletions.
2 changes: 0 additions & 2 deletions build_tools/cmake/ctest_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ if [[ "$OSTYPE" =~ ^msys ]]; then
"iree/tests/e2e/tensor_ops/check_vmvx_ukernel_local-task_unpack.mlir"
# TODO(#11070): Fix argument/result signature mismatch
"iree/tests/e2e/tosa_ops/check_vmvx_local-sync_microkernels_fully_connected.mlir"
# TODO(#11080): Fix arrays not matching in test_variant_list_buffers
"iree/runtime/bindings/python/vm_types_test"
)
elif [[ "$OSTYPE" =~ ^darwin ]]; then
excluded_tests+=(
Expand Down
37 changes: 21 additions & 16 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,54 +411,59 @@ HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri,
namespace {

py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
// See: https://docs.python.org/3/c-api/arg.html#numbers
// TODO: Handle dtypes that do not map to a code (i.e. fp16).
const char* dtype_code;
// See:
// * https://numpy.org/doc/stable/reference/arrays.dtypes.html
// * https://docs.python.org/3/c-api/arg.html#numbers
//
// Single letter codes can be ambiguous across platforms, so prefer explicit
// bit depth values, ("Type strings: Any string in numpy.sctypeDict.keys()").
// See https://github.com/pybind/pybind11/issues/1908
const char* dtype_string;
switch (element_type) {
case IREE_HAL_ELEMENT_TYPE_BOOL_8:
dtype_code = "?";
dtype_string = "?";
break;
case IREE_HAL_ELEMENT_TYPE_INT_8:
case IREE_HAL_ELEMENT_TYPE_SINT_8:
dtype_code = "b";
dtype_string = "int8";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_8:
dtype_code = "B";
dtype_string = "uint8";
break;
case IREE_HAL_ELEMENT_TYPE_INT_16:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
dtype_code = "h";
dtype_string = "int16";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
dtype_code = "H";
dtype_string = "uint16";
break;
case IREE_HAL_ELEMENT_TYPE_INT_32:
case IREE_HAL_ELEMENT_TYPE_SINT_32:
dtype_code = "i";
dtype_string = "int32";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_32:
dtype_code = "I";
dtype_string = "uint32";
break;
case IREE_HAL_ELEMENT_TYPE_INT_64:
case IREE_HAL_ELEMENT_TYPE_SINT_64:
dtype_code = "l";
dtype_string = "int64";
break;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
dtype_code = "L";
dtype_string = "uint64";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
dtype_code = "e";
dtype_string = "float16";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
dtype_code = "f";
dtype_string = "float32";
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
dtype_code = "d";
dtype_string = "float64";
break;
default:
throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
}
return py::dtype(dtype_code);
return py::dtype(dtype_string);
}

} // namespace
Expand Down
19 changes: 12 additions & 7 deletions runtime/bindings/python/tests/vm_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,18 @@ def test_variant_list_i64(self):
def test_variant_list_buffers(self):
device = rt.get_device("local-sync")
ET = rt.HalElementType
for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
(np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
(np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
(np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
(np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
# TODO: Unimplemented: (np.float16, ET.FLOAT_16)
for dt, et in (
(np.int8, ET.SINT_8), #
(np.int16, ET.SINT_16), #
(np.int32, ET.SINT_32), #
(np.int64, ET.SINT_64), #
(np.uint8, ET.UINT_8), #
(np.uint16, ET.UINT_16), #
(np.uint32, ET.UINT_32), #
(np.uint64, ET.UINT_64), #
(np.float16, ET.FLOAT_16), #
(np.float32, ET.FLOAT_32), #
(np.float64, ET.FLOAT_64)):
lst = rt.VmVariantList(5)
ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
bv1 = device.allocator.allocate_buffer_copy(
Expand All @@ -65,7 +71,6 @@ def test_variant_list_buffers(self):
lst.push_ref(bv1)
ary2 = rt.DeviceArray(device,
lst.get_as_object(0, rt.HalBufferView),
override_dtype=dt,
implicit_host_transfer=True)
np.testing.assert_array_equal(ary1, ary2)
with self.assertRaises(IndexError):
Expand Down

0 comments on commit 8bab0d3

Please sign in to comment.