Skip to content

Commit

Permalink
[mlir][emitc] Add a declare_func operation (llvm#80297)
Browse files Browse the repository at this point in the history
This adds the `emitc.declare_func` operation that allows to emit the
declaration of an `emitc.func` at a specific location.
  • Loading branch information
marbre authored and mgehre-amd committed Mar 11, 2024
1 parent ed20cea commit 9c6c868
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 6 deletions.
42 changes: 42 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call",
}];
}

def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "An operation to declare a function";
let description = [{
The `declare_func` operation allows to insert a function declaration for an
`emitc.func` at a specific position. The operation only requires the `callee`
of the `emitc.func` to be specified as an attribute.

Example:

```mlir
emitc.declare_func @bar
emitc.func @foo(%arg0: i32) -> i32 {
%0 = emitc.call @bar(%arg0) : (i32) -> (i32)
emitc.return %0 : i32
}

emitc.func @bar(%arg0: i32) -> i32 {
emitc.return %arg0 : i32
}
```

```c++
// Code emitted for the operations above.
int32_t bar(int32_t v1);
int32_t foo(int32_t v1) {
int32_t v2 = bar(v1);
return v2;
}

int32_t bar(int32_t v1) {
return v1;
}
```
}];
let arguments = (ins FlatSymbolRefAttr:$sym_name);
let assemblyFormat = [{
$sym_name attr-dict
}];
}

def EmitC_FuncOp : EmitC_Op<"func", [
AutomaticAllocationScope,
FunctionOpInterface, IsolatedFromAbove
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,24 @@ FunctionType CallOp::getCalleeType() {
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}

//===----------------------------------------------------------------------===//
// DeclareFuncOp
//===----------------------------------------------------------------------===//

LogicalResult
DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the sym_name attribute was specified.
auto fnAttr = getSymNameAttr();
if (!fnAttr)
return emitOpError("requires a 'sym_name' symbol reference attribute");
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";

return success();
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
Expand Down
46 changes: 40 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
Expand Down Expand Up @@ -870,8 +871,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
// needs to be printed after the closing brace.
// When generating code for an emitc.for and emitc.verbatim op, printing a
// trailing semicolon is handled within the printOperation function.
bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp,
emitc::LiteralOp, emitc::VerbatimOp>(op);
bool trailingSemicolon =
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);

if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon)))
Expand Down Expand Up @@ -953,6 +955,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
DeclareFuncOp declareFuncOp) {
CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();

auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
declareFuncOp, declareFuncOp.getSymNameAttr());

if (!functionOp)
return failure();

if (functionOp.getSpecifiers()) {
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
os << cast<StringAttr>(specifier).str() << " ";
}
}

if (failed(emitter.emitTypes(functionOp.getLoc(),
functionOp.getFunctionType().getResults())))
return failure();
os << " " << functionOp.getName();

os << "(";
Operation *operation = functionOp.getOperation();
if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
os << ");";

return success();
}

CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
: os(os), declareVariablesAtTop(declareVariablesAtTop) {
valueInScopeCount.push(0);
Expand Down Expand Up @@ -1285,10 +1318,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
emitc::SubscriptOp, emitc::VariableOp, emitc::VerbatimOp>(
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::SubscriptOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) {

// expected-error@+1 {{expected non-function type}}
emitc.func @func_variadic(...)

// -----

// expected-error@+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}}
emitc.declare_func @bar

// -----

// expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}}
"emitc.declare_func"() : () -> ()
2 changes: 2 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
return
}

emitc.declare_func @func

emitc.func @func(%arg0 : i32) {
emitc.call_opaque "foo"(%arg0) : (i32) -> ()
emitc.return
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Target/Cpp/declare_func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s

// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]);
emitc.declare_func @bar
// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) {
emitc.func @bar(%arg0: i32) -> i32 {
emitc.return %arg0 : i32
}


// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]);
emitc.declare_func @foo
// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) {
emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
emitc.return %arg0 : i32
}

0 comments on commit 9c6c868

Please sign in to comment.