diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index dda2003ba0375a9..b5f31aa5dec54f3 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -15,6 +15,7 @@ #include "PybindUtils.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" @@ -72,6 +73,27 @@ or 255), then a splat will be created. type or if the buffer does not meet expectations. )"; +static const char kDenseElementsAttrGetFromListDocstring[] = + R"(Gets a DenseElementsAttr from a Python list of attributes. + +Note that it can be expensive to construct attributes individually. +For a large number of elements, consider using a Python buffer or array instead. + +Args: + attrs: A list of attributes. + type: The desired shape and type of the resulting DenseElementsAttr. + If not provided, the element type is determined based on the type + of the 0th attribute and the shape is `[len(attrs)]`. + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the attributes does not match the type + specified by `shaped_type`. +)"; + static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = R"(Gets a DenseResourceElementsAttr from a Python buffer or array. @@ -647,6 +669,57 @@ class PyDenseElementsAttribute static constexpr const char *pyClassName = "DenseElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static PyDenseElementsAttribute + getFromList(py::list attributes, std::optional explicitType, + DefaultingPyMlirContext contextWrapper) { + + const size_t numAttributes = py::len(attributes); + if (numAttributes == 0) + throw py::value_error("Attributes list must be non-empty."); + + MlirType shapedType; + if (explicitType) { + if ((!mlirTypeIsAShaped(*explicitType) || + !mlirShapedTypeHasStaticShape(*explicitType))) { + + std::string message; + llvm::raw_string_ostream os(message); + os << "Expected a static ShapedType for the shaped_type parameter: " + << py::repr(py::cast(*explicitType)); + throw py::value_error(os.str()); + } + shapedType = *explicitType; + } else { + SmallVector shape{static_cast(numAttributes)}; + shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), + mlirAttributeGetType(pyTryCast(attributes[0])), + mlirAttributeGetNull()); + } + + SmallVector mlirAttributes; + mlirAttributes.reserve(numAttributes); + for (const py::handle &attribute : attributes) { + MlirAttribute mlirAttribute = pyTryCast(attribute); + MlirType attrType = mlirAttributeGetType(mlirAttribute); + mlirAttributes.push_back(mlirAttribute); + + if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { + std::string message; + llvm::raw_string_ostream os(message); + os << "All attributes must be of the same type and match " + << "the type parameter: expected=" << py::repr(py::cast(shapedType)) + << ", but got=" << py::repr(py::cast(attrType)); + throw py::value_error(os.str()); + } + } + + MlirAttribute elements = mlirDenseElementsAttrGet( + shapedType, mlirAttributes.size(), mlirAttributes.data()); + + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + static PyDenseElementsAttribute getFromBuffer(py::buffer array, bool signless, std::optional explicitType, @@ -883,6 +956,10 @@ class PyDenseElementsAttribute py::arg("type") = py::none(), py::arg("shape") = py::none(), py::arg("context") = py::none(), kDenseElementsAttrGetDocstring) + .def_static("get", PyDenseElementsAttribute::getFromList, + py::arg("attrs"), py::arg("type") = py::none(), + py::arg("context") = py::none(), + kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index 9251588a4c48a6e..2bc403aace83487 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided(): print(np.array(attr)) +################################################################################ +# Tests of the list of attributes .get() factory method +################################################################################ + + +# CHECK-LABEL: TEST: testGetDenseElementsFromList +@run +def testGetDenseElementsFromList(): + with Context(), Location.unknown(): + attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)] + attr = DenseElementsAttr.get(attrs) + + # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64> + print(attr) + + +# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType +@run +def testGetDenseElementsFromListWithExplicitType(): + with Context(), Location.unknown(): + attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)] + shaped_type = ShapedType(Type.parse("tensor<2xf64>")) + attr = DenseElementsAttr.get(attrs, shaped_type) + + # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64> + print(attr) + + +# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList +@run +def testGetDenseElementsFromListEmptyList(): + with Context(), Location.unknown(): + attrs = [] + + try: + attr = DenseElementsAttr.get(attrs) + except ValueError as e: + # CHECK: Attributes list must be non-empty + print(e) + + +# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType +@run +def testGetDenseElementsFromListNonAttributeType(): + with Context(), Location.unknown(): + attrs = [1.0] + + try: + attr = DenseElementsAttr.get(attrs) + except RuntimeError as e: + # CHECK: Invalid attribute when attempting to create an ArrayAttribute + print(e) + + +# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType +@run +def testGetDenseElementsFromListMismatchedType(): + with Context(), Location.unknown(): + attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)] + shaped_type = ShapedType(Type.parse("tensor<2xf32>")) + + try: + attr = DenseElementsAttr.get(attrs, shaped_type) + except ValueError as e: + # CHECK: All attributes must be of the same type and match the type parameter + print(e) + + +# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes +@run +def testGetDenseElementsFromListMixedTypes(): + with Context(), Location.unknown(): + attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)] + + try: + attr = DenseElementsAttr.get(attrs) + except ValueError as e: + # CHECK: All attributes must be of the same type and match the type parameter + print(e) + + ################################################################################ # Splats. ################################################################################ @@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat(): ### float and double arrays. + # CHECK-LABEL: TEST: testGetDenseElementsF16 @run def testGetDenseElementsF16():