-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][python] Add bindings for mlirDenseElementsAttrGet #91389
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir Author: None (pranavm-nvidia) ChangesThis change adds bindings for Full diff: https://github.com/llvm/llvm-project/pull/91389.diff 2 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba037..b7ad4f3a78b7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -72,6 +72,23 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";
+static const char kDenseElementsAttrGetFromListDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Args:
+ attrs: A list of attributes.
+ type: The desired shape and type of the resulting DenseElementsAttr.
+ If not provided, the element type is determined based on the types
+ of the attributes and the shape is `[len(attrs)]`.
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
+)";
+
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
@@ -647,6 +664,55 @@ class PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
+ static PyDenseElementsAttribute
+ getFromList(py::list attributes, std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper) {
+
+ if (py::len(attributes) == 0) {
+ throw py::value_error("Attributes list must be non-empty");
+ }
+
+ MlirType shapedType;
+ if (explicitType) {
+ if ((!mlirTypeIsAShaped(*explicitType) ||
+ !mlirShapedTypeHasStaticShape(*explicitType))) {
+ std::string message =
+ "Expected a static ShapedType for the shaped_type parameter: ";
+ message.append(py::repr(py::cast(*explicitType)));
+ throw py::value_error(message);
+ }
+ shapedType = *explicitType;
+ } else {
+ SmallVector<int64_t> shape{py::len(attributes)};
+ shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(),
+ mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+ mlirAttributeGetNull());
+ }
+
+ SmallVector<MlirAttribute> mlirAttributes;
+ mlirAttributes.reserve(py::len(attributes));
+ for (auto attribute : attributes) {
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+ MlirType attrType = mlirAttributeGetType(mlirAttribute);
+ mlirAttributes.push_back(mlirAttribute);
+
+ if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+ std::string message = "All attributes must be of the same type and "
+ "match the type parameter: expected=";
+ message.append(py::repr(py::cast(shapedType)));
+ message.append(", but got=");
+ message.append(py::repr(py::cast(attrType)));
+ throw py::value_error(message);
+ }
+ }
+
+ MlirAttribute elements = mlirDenseElementsAttrGet(
+ shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+ }
+
static PyDenseElementsAttribute
getFromBuffer(py::buffer array, bool signless,
std::optional<PyType> explicitType,
@@ -883,6 +949,10 @@ class PyDenseElementsAttribute
py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ py::arg("attrs"), py::arg("type") = py::none(),
+ py::arg("context") = py::none(),
+ kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -954,8 +1024,8 @@ class PyDenseElementsAttribute
}
}; // namespace
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
@@ -964,8 +1034,8 @@ class PyDenseIntElementsAttribute
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- /// Returns the element at the given linear position. Asserts if the index is
- /// out of range.
+ /// Returns the element at the given linear position. Asserts if the index
+ /// is out of range.
py::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw py::index_error("attempt to access out of bounds element");
@@ -1267,7 +1337,8 @@ class PyStridedLayoutAttribute
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
py::arg("rank"), py::arg("context") = py::none(),
- "Gets a strided layout attribute with dynamic offset and strides of a "
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
"given rank.");
c.def_property_readonly(
"offset",
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48..2bc403aace83 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
print(np.array(attr))
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+@run
+def testGetDenseElementsFromList():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ attr = DenseElementsAttr.get(attrs)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+@run
+def testGetDenseElementsFromListWithExplicitType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+@run
+def testGetDenseElementsFromListEmptyList():
+ with Context(), Location.unknown():
+ attrs = []
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: Attributes list must be non-empty
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+@run
+def testGetDenseElementsFromListNonAttributeType():
+ with Context(), Location.unknown():
+ attrs = [1.0]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except RuntimeError as e:
+ # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+@run
+def testGetDenseElementsFromListMismatchedType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+
+ try:
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+@run
+def testGetDenseElementsFromListMixedTypes():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
################################################################################
# Splats.
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
### float and double arrays.
+
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
|
Just curious, is the idea that you would do something like?
? |
The build failure seems legit:
|
Yes, exactly, and then you could have |
8293dde
to
63a8b1a
Compare
Strangely, I did not see it locally. I guess I missed whichever CMake variable enables Not sure if the error in the Windows build is caused by my change:
|
This is a fluke |
Why someone would want this? Creating attributes is rather expensive. Should we rather accept python |
That would work too, but I think the C API Out of curiosity, what was the intended purpose of |
ftynse is talking about the many element gets (the |
Right, but doesn't this API still require us to construct the constituent elements individually? To pass the MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
MlirType shapedType, intptr_t numElements, MlirAttribute const *elements); |
Yea sorry you're right of course (and now I see that you were exactly alluding/implying that). Sorry - it looks like it's indeed just an artifact of how the Python binding works: SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(py::len(attributes));
for (auto attribute : attributes) {
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
MlirType attrType = mlirAttributeGetType(mlirAttribute);
mlirAttributes.push_back(mlirAttribute);
if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
std::string message = "All attributes must be of the same type and "
"match the type parameter: expected=";
message.append(py::repr(py::cast(shapedType)));
message.append(", but got=");
message.append(py::repr(py::cast(attrType)));
throw py::value_error(message);
}
} The C++ API just takes an static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
const char *data = reinterpret_cast<const char *>(values.data());
return getRawIntOrFloat(
type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed);
} The more "fiducial" corresponding C API is MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
size_t rawBufferSize,
const void *rawBuffer) {
auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
rawBufferSize);
... which is how |
That makes sense. So then assuming we accept a Python list of ints/floats, would I just do something like this in the bindings (pseudocode mostly)? auto rawBuffer = ...; // Allocate enough memory to hold sizeInBytes(targetType) * numElements
for (auto number : numbers) {
switch (targetType) {
case bfloat16:
rawBuffer[index] = float2bfloat(number);
... // Similar for all other types
}
}
return mlirDenseElementsAttrRawBufferGet(..., rawBuffer); Or is there already infrastructure to convert Python numbers? (sorry if this is a dumb question; I'm still quite new to MLIR) Also, is there any value in allowing existing attributes to be packed into a |
Looks about right but the only issue here is the memory/ownership: what to do with the alloc after the call returns?
Don't think so (unless I missed someting over the last few months).
Yea I don't know 🤷 doesn't hurt to have more functionality but maybe C APIs should be concise/minimal? I'll let @ftynse make the call since he has more experience. |
+1 to the perils of individual elemental attributes for anything real. I'm not exactly sure what the goal is. When I'm doing this kind of stuff, I'm always working on things that produce packed in-memory buffers of the thing being created... Because anything else is impractical for real use. You may need to do some creative bit manipulation, but you could do what you want without deps in the python standard library just using the array and struct modules, both of which let you swizzle packed data. Of course, since python knows nothing of bf16 or the f8 or weird types, you're on your own unless if you have a library that lets you do anything there. Same as in any language. If you're just wanting to use the fact that mlir/llvm already knows how to do these conversions, you could go the other way: teach the FloatAttr type how to give you back an untyped binary buffer of the value (ie. Implement the buffer protocol). Then you could cast that to a bytes in Python and have something, I guess. Won't be very good, and will be ruinously expensive if trying to use it as a general numeric conversion library (since it is still uniqued elemental attributes). Another alternative might be to provide a convenience binding for numeric conversion based on APFloat directly. For reasons I won't go into here, it is legal to use that C++ type directly in the python bindings sans C wrapper. |
Yeah, let me elaborate on the use case a bit more. Basically, I need to construct small
Yes, this is exactly it. As you noted, it's not too hard for types that
Would this be any more beneficial than what I'm doing in this PR? I'm imagining something like this: buffer = FloatAttr.get(BF16Type.get(), 1.0).tobytes() + FloatAttr.get(BF16Type.get(), 2.0).tobytes()
attr = DenseElementsAttr.get_from_buffer(buffer, ...) And I guess I'm still curious about what the intended use case for the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, you're getting push back because the original authors deeply regret DenseElementsAttr for most cases, having switched most cases that matter to DenseArrayAttr (for "normal" inline values) or DenseResourceAttr for bulk storage. In general, the attribute elements approach has cost many dearly in terms of compile time and overhead. So it's a little confusing thinking about enabling that further.
With that said, I don't think what you are proposing hurts anything. I'd probably draw the line if it required C API changes as not being worth it.
That makes sense. It sounds like |
It exists and is still used. We can certainly add ergonomics like this to it if helpful. I just haven't had time to review the patch in detail with an eye to landing. But yeah, most mature implementations have been limiting their use of that attribute generally for a long time, due to its design flaws. |
@stellaraccident if we're in agreement it can be merged, I can review with a fine-toothed comb (while you're out...). |
Thanks/appreciated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a solid PR/impl to me 👍
e267483
to
0703e70
Compare
I guess you'll need me to merge this for you (I need to approve the rerun) so let me know when you're ready. |
To replicate the API existing in C++:
It is possible to construct these attributes from an object implementing the Python buffer protocol, e.g.,
It's okay to add if there is a use case. I would rather not add code nobody uses, that only increases the maintenance load.
I won't object to there being |
11c716f
to
bf82552
Compare
@makslevental I think it's ready to go unless there are any other comments |
One more round of tests for good luck and then I'll merge. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
bf82552
to
510942a
Compare
@pranavm-nvidia if Windows still isn't done by morning I'll merge anyway |
@pranavm-nvidia Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
This change adds bindings for
mlirDenseElementsAttrGet
which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creatingDenseElementsAttr
s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g.numpy
+ml-dtypes
).