diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index f970a709931cf..bf43d93668e30 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -460,6 +460,12 @@ py::object MapElementTypeToDType(iree_hal_element_type_t element_type) { case IREE_HAL_ELEMENT_TYPE_FLOAT_64: dtype_string = "float64"; break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: + dtype_string = "complex64"; + break; + case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: + dtype_string = "complex128"; + break; default: throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping"); } diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py index 671002649aaac..3e0a951c6ff3a 100644 --- a/runtime/bindings/python/tests/vm_types_test.py +++ b/runtime/bindings/python/tests/vm_types_test.py @@ -60,7 +60,9 @@ def test_variant_list_buffers(self): (np.uint64, ET.UINT_64), # (np.float16, ET.FLOAT_16), # (np.float32, ET.FLOAT_32), # - (np.float64, ET.FLOAT_64)): + (np.float64, ET.FLOAT_64), # + (np.complex64, ET.COMPLEX_64), # + (np.complex128, ET.COMPLEX_128)): lst = rt.VmVariantList(5) ary1 = np.asarray([1, 2, 3, 4], dtype=dt) bv1 = device.allocator.allocate_buffer_copy(