From 04d85557969e12a5bf0efd7884a1f4aa5dcaf459 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Tue, 4 Apr 2023 13:02:24 -0700 Subject: [PATCH] added support for complex numbers in python bindings (#12872) added support for complex numbers in python bindings Co-authored-by: Elias Joseph --- runtime/bindings/python/hal.cc | 6 ++++++ runtime/bindings/python/tests/vm_types_test.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index f970a709931c..bf43d93668e3 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -460,6 +460,12 @@ py::object MapElementTypeToDType(iree_hal_element_type_t element_type) { case IREE_HAL_ELEMENT_TYPE_FLOAT_64: dtype_string = "float64"; break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: + dtype_string = "complex64"; + break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: + dtype_string = "complex128"; + break; default: throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping"); } diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py index 671002649aaa..3e0a951c6ff3 100644 --- a/runtime/bindings/python/tests/vm_types_test.py +++ b/runtime/bindings/python/tests/vm_types_test.py @@ -60,7 +60,9 @@ def test_variant_list_buffers(self): (np.uint64, ET.UINT_64), # (np.float16, ET.FLOAT_16), # (np.float32, ET.FLOAT_32), # - (np.float64, ET.FLOAT_64)): + (np.float64, ET.FLOAT_64), # + (np.complex64, ET.COMPLEX_64), # + (np.complex128, ET.COMPLEX_128)): lst = rt.VmVariantList(5) ary1 = np.asarray([1, 2, 3, 4], dtype=dt) bv1 = device.allocator.allocate_buffer_copy(