Skip to content

Commit

Permalink
[mlir][LLVM] Add support for constant struct with multiple fields (ll…
Browse files Browse the repository at this point in the history
…vm#102752)

Currently `mlir.llvm.constant` of structure types restricts that the
structure type effectively represents a complex type -- it must have
exactly two fields of the same type and the field type must be either an
integer type or a float type.

This PR relaxes this restriction and it allows the structure type to
have an arbitrary number of fields.
  • Loading branch information
Lancern authored and dmpolukhin committed Sep 2, 2024
1 parent 284536f commit 50b262a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 55 deletions.
37 changes: 24 additions & 13 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
let description = [{
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
all constants must be created as SSA values before being used in other
operations. `llvm.mlir.constant` creates such values for scalars and
vectors. It has a mandatory `value` attribute, which may be an integer,
floating point attribute; dense or sparse attribute containing integers or
floats. The type of the attribute is one of the corresponding MLIR builtin
types. It may be omitted for `i64` and `f64` types that are implied.

The operation produces a new SSA value of the specified LLVM IR dialect
type. Certain builtin types such as integer, float and vector types are
also allowed. The result type _must_ correspond to the attribute type
converted to LLVM IR. In particular, the number of elements of a container
type must match the number of elements in the attribute. If the type is or
contains a scalable vector type, the attribute must be a splat elements
attribute.
operations. `llvm.mlir.constant` creates such values for scalars, vectors,
strings, and structs. It has a mandatory `value` attribute whose type
depends on the type of the constant value. The type of the constant value
must correspond to the attribute type converted to LLVM IR type.

When creating constant scalars, the `value` attribute must be either an
integer attribute or a floating point attribute. The type of the attribute
may be omitted for `i64` and `f64` types that are implied.

When creating constant vectors, the `value` attribute must be either an
array attribute, a dense attribute, or a sparse attribute that contains
integers or floats. The number of elements in the result vector must match
the number of elements in the attribute.

When creating constant strings, the `value` attribute must be a string
attribute. The type of the constant must be an LLVM array of `i8`s, and the
length of the array must match the length of the attribute.

When creating constant structs, the `value` attribute must be an array
attribute that contains integers or floats. The type of the constant must be
an LLVM struct type. The number of fields in the struct must match the
number of elements in the attribute, and the type of each LLVM struct field
must correspond to the type of the corresponding attribute element converted
to LLVM IR.

Examples:

Expand Down
47 changes: 27 additions & 20 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() {
}
return success();
}
if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
return emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr) {
return emitOpError() << "expected array attribute for a struct constant";
}

auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr || arrayAttr.size() != 2) {
return emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
ArrayRef<Type> elementTypes = structType.getBody();
if (arrayAttr.size() != elementTypes.size()) {
return emitOpError() << "expected array attribute of size "
<< elementTypes.size();
}
auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (!re || !im || re.getType() != im.getType()) {
return emitOpError()
<< "expected array attribute with two elements of the same type";
for (auto elementTy : elementTypes) {
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
return emitOpError() << "expected struct element types to be floating "
"point type or integer type";
}
}

Type elementType = structType.getBody()[0];
if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
elementType)) {
return emitError()
<< "expected struct element types to be floating point type or "
"integer type";
for (size_t i = 0; i < elementTypes.size(); ++i) {
Attribute element = arrayAttr[i];
if (!isa<IntegerAttr, FloatAttr>(element)) {
return emitOpError()
<< "expected struct element attribute types to be floating "
"point type or integer type";
}
auto elementType = cast<TypedAttr>(element).getType();
if (elementType != elementTypes[i]) {
return emitOpError()
<< "struct element at index " << i << " is of wrong type";
}
}

return success();
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
Expand Down
25 changes: 13 additions & 12 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,20 +557,21 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
if (!arrayAttr || arrayAttr.size() != 2) {
emitError(loc, "expected struct type to be a complex number");
if (!arrayAttr) {
emitError(loc, "expected an array attribute for a struct constant");
return nullptr;
}
llvm::Type *elementType = structType->getElementType(0);
llvm::Constant *real =
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
if (!real)
return nullptr;
llvm::Constant *imag =
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
if (!imag)
return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag});
SmallVector<llvm::Constant *> structElements;
structElements.reserve(structType->getNumElements());
for (auto [elemType, elemAttr] :
zip_equal(structType->elements(), arrayAttr)) {
llvm::Constant *element =
getLLVMConstant(elemType, elemAttr, loc, moduleTranslation);
if (!element)
return nullptr;
structElements.push_back(element);
}
return llvm::ConstantStruct::get(structType, structElements);
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -367,39 +367,39 @@ func.func @constant_wrong_type_string() {
// -----

llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
// expected-error @+1 {{expected array attribute of size 2}}
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements of the same type}}
// expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
// expected-error @+1 {{expected array attribute}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
// expected-error @+1 {{expected array attribute of size 1}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
llvm.return %0 : !llvm.struct<(f64)>
}

// -----

llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
// expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
llvm.return %0 : !llvm.struct<(f64, f32)>
}
Expand Down
26 changes: 21 additions & 5 deletions mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,40 @@ llvm.func @vector_with_non_vector_type() -> f32 {

// -----

llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected struct type to be a complex number}}
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}

// -----

llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected struct type to be a complex number}}
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}

// -----

llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
// expected-error @below{{expected struct element types to be floating point type or integer type}}
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
}

// -----

llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{FloatAttr does not match expected type of the constant}}
// expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}

llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
%1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
// CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
llvm.return %1 : !llvm.struct<(i32, f32)>
}

// CHECK-LABEL: @indexconstantsplat
llvm.func @indexconstantsplat() -> vector<3xi32> {
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>
Expand Down

0 comments on commit 50b262a

Please sign in to comment.