Skip to content

Commit

Permalink
[mlir][emitc] Add ArrayType (llvm#83386)
Browse files Browse the repository at this point in the history
This models a one or multi-dimensional C/C++ array.

The type implements the `ShapedTypeInterface` and prints similar to
memref/tensor:
```
  %arg0: !emitc.array<1xf32>,
  %arg1: !emitc.array<10x20x30xi32>,
  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
  %arg3: !emitc.array<30x!emitc.opaque<"int">>
```

It can be translated to a C array type when used as function parameter
or as `emitc.variable` type.
  • Loading branch information
mgehre-amd authored Mar 11, 2024
1 parent d99bb01 commit 818af71
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 8 deletions.
52 changes: 50 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,64 @@

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"

//===----------------------------------------------------------------------===//
// EmitC type definitions
//===----------------------------------------------------------------------===//

class EmitC_Type<string name, string typeMnemonic>
: TypeDef<EmitC_Dialect, name> {
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<EmitC_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
let summary = "EmitC array type";

let description = [{
An array data type.

Example:

```mlir
// Array emitted as `int32_t[10]`
!emitc.array<10xi32>
// Array emitted as `float[10][20]`
!emitc.ptr<10x20xf32>
```
}];

let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType
);

let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape,
"Type":$elementType
), [{
return $_get(elementType.getContext(), shape, elementType);
}]>
];
let extraClassDeclaration = [{
/// Returns if this type is ranked (always true).
bool hasRank() const { return true; }

/// Clone this array type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

static bool isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
llvm::isa<PointerType, OpaqueType>(type);
}
}];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
let summary = "EmitC opaque type";

Expand Down
73 changes: 73 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ LogicalResult emitc::AssignOp::verify() {
return emitOpError() << "requires value's type (" << value.getType()
<< ") to match variable's type (" << variable.getType()
<< ")";
if (isa<ArrayType>(variable.getType()))
return emitOpError() << "cannot assign to array type";
return success();
}

Expand Down Expand Up @@ -192,6 +194,11 @@ LogicalResult emitc::CallOpaqueOp::verify() {
}
}

if (llvm::any_of(getResultTypes(),
[](Type type) { return isa<ArrayType>(type); })) {
return emitOpError() << "cannot return array type";
}

return success();
}

Expand Down Expand Up @@ -456,6 +463,9 @@ LogicalResult FuncOp::verify() {
return emitOpError("requires zero or exactly one result, but has ")
<< getNumResults();

if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
return emitOpError("cannot return array type");

return success();
}

Expand Down Expand Up @@ -763,6 +773,69 @@ LogicalResult emitc::YieldOp::verify() {
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//

Type emitc::ArrayType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();

SmallVector<int64_t, 4> dimensions;
if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
/*withTrailingX=*/true))
return Type();
// Parse the element type.
auto typeLoc = parser.getCurrentLocation();
Type elementType;
if (parser.parseType(elementType))
return Type();

// Check that array is formed from allowed types.
if (!isValidElementType(elementType))
return parser.emitError(typeLoc, "invalid array element type"), Type();
if (parser.parseGreater())
return Type();
return parser.getChecked<ArrayType>(dimensions, elementType);
}

void emitc::ArrayType::print(AsmPrinter &printer) const {
printer << "<";
for (int64_t dim : getShape()) {
printer << dim << 'x';
}
printer.printType(getElementType());
printer << ">";
}

LogicalResult emitc::ArrayType::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::llvm::ArrayRef<int64_t> shape, Type elementType) {
if (shape.empty())
return emitError() << "shape must not be empty";

for (int64_t dim : shape) {
if (dim <= 0)
return emitError() << "dimensions must have positive size";
}

if (!elementType)
return emitError() << "element type must not be none";

if (!isValidElementType(elementType))
return emitError() << "invalid array element type";

return success();
}

emitc::ArrayType
emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (!shape)
return emitc::ArrayType::get(getShape(), elementType);
return emitc::ArrayType::get(*shape, elementType);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 48 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ struct CppEmitter {
LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon);

/// Emits a declaration of a variable with the given type and name.
LogicalResult emitVariableDeclaration(Location loc, Type type,
StringRef name);

/// Emits the variable declaration and assignment prefix for 'op'.
/// - emits separate variable followed by std::tie for multi-valued operation;
/// - emits single type followed by variable for single result;
Expand Down Expand Up @@ -870,10 +874,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,

return (interleaveCommaWithError(
arguments, os, [&](BlockArgument arg) -> LogicalResult {
if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
return failure();
os << " " << emitter.getOrCreateName(arg);
return success();
return emitter.emitVariableDeclaration(
functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
}));
}

Expand Down Expand Up @@ -917,6 +919,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
if (emitter.hasValueInScope(arg))
return functionOp->emitOpError(" block argument #")
<< arg.getArgNumber() << " is out of scope";
if (isa<ArrayType>(arg.getType()))
return functionOp->emitOpError("cannot emit block argument #")
<< arg.getArgNumber() << " with array type";
if (failed(
emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
return failure();
Expand Down Expand Up @@ -960,6 +965,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}

if (llvm::any_of(functionOp.getResultTypes(),
[](Type type) { return isa<ArrayType>(type); })) {
return functionOp.emitOpError() << "cannot emit array type as result type";
}

CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
if (failed(emitter.emitTypes(functionOp.getLoc(),
Expand Down Expand Up @@ -1306,9 +1316,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
}
if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
result.getType(),
getOrCreateName(result))))
return failure();
os << " " << getOrCreateName(result);
if (trailingSemicolon)
os << ";\n";
return success();
Expand Down Expand Up @@ -1403,6 +1414,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
return success();
}

LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
StringRef name) {
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, arrType.getElementType())))
return failure();
os << " " << name;
for (auto dim : arrType.getShape()) {
os << "[" << dim << "]";
}
return success();
}
if (failed(emitType(loc, type)))
return failure();
os << " " << name;
return success();
}

LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
Expand Down Expand Up @@ -1438,6 +1466,8 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (!tType.hasStaticShape())
return emitError(loc, "cannot emit tensor type with non static shape");
os << "Tensor<";
if (isa<ArrayType>(tType.getElementType()))
return emitError(loc, "cannot emit tensor of array type ") << type;
if (failed(emitType(loc, tType.getElementType())))
return failure();
auto shape = tType.getShape();
Expand All @@ -1454,7 +1484,16 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
os << oType.getValue();
return success();
}
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
return failure();
for (auto dim : aType.getShape())
os << "[" << dim << "]";
return success();
}
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
if (isa<ArrayType>(pType.getPointee()))
return emitError(loc, "cannot emit pointer to array type ") << type;
if (failed(emitType(loc, pType.getPointee())))
return failure();
os << "*";
Expand All @@ -1476,6 +1515,9 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
}

LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
return emitError(loc, "cannot emit tuple of array type");
}
os << "std::tuple<";
if (failed(interleaveCommaWithError(
types, os, [&](Type type) { return emitType(loc, type); })))
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ func.func @dense_template_argument(%arg : i32) {

// -----

func.func @array_result() {
// expected-error @+1 {{'emitc.call_opaque' op cannot return array type}}
emitc.call_opaque "array_result"() : () -> !emitc.array<4xi32>
return
}

// -----

func.func @empty_operator(%arg : i32) {
// expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
%2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
Expand Down Expand Up @@ -129,6 +137,14 @@ func.func @cast_tensor(%arg : tensor<f32>) {

// -----

func.func @cast_array(%arg : !emitc.array<4xf32>) {
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
return
}

// -----

func.func @add_two_pointers(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
// expected-error @+1 {{'emitc.add' op requires that at most one operand is a pointer}}
%1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
Expand Down Expand Up @@ -235,6 +251,15 @@ func.func @test_assign_type_mismatch(%arg1: f32) {

// -----

func.func @test_assign_to_array(%arg1: !emitc.array<4xi32>) {
%v = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi32>
// expected-error @+1 {{'emitc.assign' op cannot assign to array type}}
emitc.assign %arg1 : !emitc.array<4xi32> to %v : !emitc.array<4xi32>
return
}

// -----

func.func @test_expression_no_yield() -> i32 {
// expected-error @+1 {{'emitc.expression' op must yield a value at termination}}
%r = emitc.expression : i32 {
Expand Down Expand Up @@ -313,6 +338,13 @@ emitc.func @return_type_mismatch() -> i32 {

// -----

// expected-error@+1 {{'emitc.func' op cannot return array type}}
emitc.func @return_type_array(%arg : !emitc.array<4xi32>) -> !emitc.array<4xi32> {
emitc.return %arg : !emitc.array<4xi32>
}

// -----

func.func @return_inside_func.func(%0: i32) -> (i32) {
// expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}}
emitc.return %0 : i32
Expand Down
Loading

0 comments on commit 818af71

Please sign in to comment.