Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][python] Add bindings for mlirDenseElementsAttrGet #91389

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "PybindUtils.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
Expand Down Expand Up @@ -72,6 +73,27 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";

static const char kDenseElementsAttrGetFromListDocstring[] =
R"(Gets a DenseElementsAttr from a Python list of attributes.
pranavm-nvidia marked this conversation as resolved.
Show resolved Hide resolved

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
attrs: A list of attributes.
type: The desired shape and type of the resulting DenseElementsAttr.
If not provided, the element type is determined based on the type
of the 0th attribute and the shape is `[len(attrs)]`.
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the attributes does not match the type
specified by `shaped_type`.
)";

static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

Expand Down Expand Up @@ -647,6 +669,57 @@ class PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;

static PyDenseElementsAttribute
getFromList(py::list attributes, std::optional<PyType> explicitType,
DefaultingPyMlirContext contextWrapper) {

const size_t numAttributes = py::len(attributes);
if (numAttributes == 0)
throw py::value_error("Attributes list must be non-empty.");

MlirType shapedType;
if (explicitType) {
if ((!mlirTypeIsAShaped(*explicitType) ||
!mlirShapedTypeHasStaticShape(*explicitType))) {

std::string message;
llvm::raw_string_ostream os(message);
os << "Expected a static ShapedType for the shaped_type parameter: "
<< py::repr(py::cast(*explicitType));
throw py::value_error(os.str());
}
shapedType = *explicitType;
} else {
SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
shapedType = mlirRankedTensorTypeGet(
shape.size(), shape.data(),
mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
mlirAttributeGetNull());
}

SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(numAttributes);
for (const py::handle &attribute : attributes) {
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
MlirType attrType = mlirAttributeGetType(mlirAttribute);
mlirAttributes.push_back(mlirAttribute);

if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
std::string message;
llvm::raw_string_ostream os(message);
os << "All attributes must be of the same type and match "
<< "the type parameter: expected=" << py::repr(py::cast(shapedType))
<< ", but got=" << py::repr(py::cast(attrType));
throw py::value_error(os.str());
}
}

MlirAttribute elements = mlirDenseElementsAttrGet(
shapedType, mlirAttributes.size(), mlirAttributes.data());

return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}

static PyDenseElementsAttribute
getFromBuffer(py::buffer array, bool signless,
std::optional<PyType> explicitType,
Expand Down Expand Up @@ -883,6 +956,10 @@ class PyDenseElementsAttribute
py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetDocstring)
.def_static("get", PyDenseElementsAttribute::getFromList,
py::arg("attrs"), py::arg("type") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
Expand Down
82 changes: 82 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
print(np.array(attr))


################################################################################
# Tests of the list of attributes .get() factory method
################################################################################


# CHECK-LABEL: TEST: testGetDenseElementsFromList
@run
def testGetDenseElementsFromList():
with Context(), Location.unknown():
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
attr = DenseElementsAttr.get(attrs)

# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
@run
def testGetDenseElementsFromListWithExplicitType():
with Context(), Location.unknown():
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
attr = DenseElementsAttr.get(attrs, shaped_type)

# CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
print(attr)


# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
@run
def testGetDenseElementsFromListEmptyList():
with Context(), Location.unknown():
attrs = []

try:
attr = DenseElementsAttr.get(attrs)
except ValueError as e:
# CHECK: Attributes list must be non-empty
print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
@run
def testGetDenseElementsFromListNonAttributeType():
with Context(), Location.unknown():
attrs = [1.0]

try:
attr = DenseElementsAttr.get(attrs)
except RuntimeError as e:
# CHECK: Invalid attribute when attempting to create an ArrayAttribute
print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
@run
def testGetDenseElementsFromListMismatchedType():
with Context(), Location.unknown():
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
shaped_type = ShapedType(Type.parse("tensor<2xf32>"))

try:
attr = DenseElementsAttr.get(attrs, shaped_type)
except ValueError as e:
# CHECK: All attributes must be of the same type and match the type parameter
print(e)


# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
@run
def testGetDenseElementsFromListMixedTypes():
with Context(), Location.unknown():
attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]

try:
attr = DenseElementsAttr.get(attrs)
except ValueError as e:
# CHECK: All attributes must be of the same type and match the type parameter
print(e)


################################################################################
# Splats.
################################################################################
Expand Down Expand Up @@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():

### float and double arrays.


# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
Expand Down
Loading