From 4e88d30a51f454862da6f6ea135901b1f57fbded Mon Sep 17 00:00:00 2001 From: Elias Joseph Date: Thu, 30 Mar 2023 23:03:53 +0000 Subject: [PATCH] added support for complex numbers in python bindings --- runtime/bindings/python/hal.cc | 6 ++++++ runtime/bindings/python/tests/vm_types_test.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index f970a709931c..bf43d93668e3 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -460,6 +460,12 @@ py::object MapElementTypeToDType(iree_hal_element_type_t element_type) { case IREE_HAL_ELEMENT_TYPE_FLOAT_64: dtype_string = "float64"; break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: + dtype_string = "complex64"; + break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: + dtype_string = "complex128"; + break; default: throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping"); } diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py index 671002649aaa..3e0a951c6ff3 100644 --- a/runtime/bindings/python/tests/vm_types_test.py +++ b/runtime/bindings/python/tests/vm_types_test.py @@ -60,7 +60,9 @@ def test_variant_list_buffers(self): (np.uint64, ET.UINT_64), # (np.float16, ET.FLOAT_16), # (np.float32, ET.FLOAT_32), # - (np.float64, ET.FLOAT_64)): + (np.float64, ET.FLOAT_64), # + (np.complex64, ET.COMPLEX_64), # + (np.complex128, ET.COMPLEX_128)): lst = rt.VmVariantList(5) ary1 = np.asarray([1, 2, 3, 4], dtype=dt) bv1 = device.allocator.allocate_buffer_copy(