Skip to content

Commit

Permalink
added support for complex numbers in python bindings (iree-org#12872)
Browse files Browse the repository at this point in the history
added support for complex numbers in python bindings

Co-authored-by: Elias Joseph <elias@nod-labs.com>
  • Loading branch information
2 people authored and NatashaKnk committed Jul 6, 2023
1 parent 5bd4dcf commit 04d8555
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 6 additions & 0 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
dtype_string = "float64";
break;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64:
dtype_string = "complex64";
break;
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128:
dtype_string = "complex128";
break;
default:
throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
}
Expand Down
4 changes: 3 additions & 1 deletion runtime/bindings/python/tests/vm_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_variant_list_buffers(self):
(np.uint64, ET.UINT_64), #
(np.float16, ET.FLOAT_16), #
(np.float32, ET.FLOAT_32), #
(np.float64, ET.FLOAT_64)):
(np.float64, ET.FLOAT_64), #
(np.complex64, ET.COMPLEX_64), #
(np.complex128, ET.COMPLEX_128)):
lst = rt.VmVariantList(5)
ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
bv1 = device.allocator.allocate_buffer_copy(
Expand Down

0 comments on commit 04d8555

Please sign in to comment.