From 8362127faeecdbb527f2a9692677503ecb0c1b04 Mon Sep 17 00:00:00 2001 From: Gil Rapaport Date: Wed, 20 Dec 2023 15:04:46 +0200 Subject: [PATCH 01/21] [mlir][emitc] Add op modelling C expressions (#71631) Add an emitc.expression operation that models C expressions, and provide transforms to form and fold expressions. The translator emits the body of emitc.expression ops as a single C expression. This expression is emitted by default as the RHS of an EmitC SSA value, but if possible, expressions with a single use that is not another expression are instead inlined. Specific expression's inlining can be fine tuned by lowering passes and transforms. --- .../include/mlir/Dialect/EmitC/CMakeLists.txt | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 96 ++++++- .../Dialect/EmitC/Transforms/CMakeLists.txt | 5 + .../mlir/Dialect/EmitC/Transforms/Passes.h | 35 +++ .../mlir/Dialect/EmitC/Transforms/Passes.td | 24 ++ .../Dialect/EmitC/Transforms/Transforms.h | 34 +++ mlir/include/mlir/InitAllPasses.h | 2 + mlir/lib/Dialect/EmitC/CMakeLists.txt | 1 + mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 61 +++++ .../Dialect/EmitC/Transforms/CMakeLists.txt | 16 ++ .../EmitC/Transforms/FormExpressions.cpp | 60 +++++ .../Dialect/EmitC/Transforms/Transforms.cpp | 114 ++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 247 ++++++++++++++++-- mlir/test/Dialect/EmitC/invalid_ops.mlir | 59 ++++- mlir/test/Dialect/EmitC/ops.mlir | 17 ++ mlir/test/Dialect/EmitC/transforms.mlir | 109 ++++++++ mlir/test/Target/Cpp/expressions.mlir | 212 +++++++++++++++ mlir/test/Target/Cpp/for.mlir | 22 +- 18 files changed, 1078 insertions(+), 37 deletions(-) create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h create mode 100644 mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt create mode 100644 mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp create mode 100644 mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp create mode 100644 mlir/test/Dialect/EmitC/transforms.mlir create mode 100644 mlir/test/Target/Cpp/expressions.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 644a6ed2566e5..4ece9471a67c2 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -19,6 +19,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/RegionKindInterface.td" //===----------------------------------------------------------------------===// // EmitC op definitions @@ -247,6 +248,83 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> { let results = (outs FloatIntegerIndexOrOpaqueType); } +def EmitC_ExpressionOp : EmitC_Op<"expression", + [HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">, + NoRegionArguments]> { + let summary = "Expression operation"; + let description = [{ + The `expression` operation returns a single SSA value which is yielded by + its single-basic-block region. The operation doesn't take any arguments. + + As the operation is to be emitted as a C expression, the operations within + its body must form a single Def-Use tree of emitc ops whose result is + yielded by a terminating `yield`. + + Example: + + ```mlir + %r = emitc.expression : () -> i32 { + %0 = emitc.add %a, %b : (i32, i32) -> i32 + %1 = emitc.call "foo"(%0) : () -> i32 + %2 = emitc.add %c, %d : (i32, i32) -> i32 + %3 = emitc.mul %1, %2 : (i32, i32) -> i32 + yield %3 + } + ``` + + May be emitted as + + ```c++ + int32_t v7 = foo(v1 + v2) * (v3 + v4); + ``` + + The operations allowed within expression body are emitc.add, emitc.apply, + emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and + emitc.sub. + + When specified, the optional `do_not_inline` indicates that the expression is + to be emitted as seen above, i.e. as the rhs of an EmitC SSA value + definition. Otherwise, the expression may be emitted inline, i.e. directly + at its use. + }]; + + let arguments = (ins UnitAttr:$do_not_inline); + let results = (outs AnyType:$result); + let regions = (region SizedRegion<1>:$region); + + let hasVerifier = 1; + let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region"; + + let extraClassDeclaration = [{ + static bool isCExpression(Operation &op) { + return isa(op); + } + bool hasSideEffects() { + auto predicate = [](Operation &op) { + assert(isCExpression(op) && "Expected a C expression"); + // Conservatively assume calls to read and write memory. + if (isa(op)) + return true; + // De-referencing reads modifiable memory, address-taking has no + // side-effect. + auto applyOp = dyn_cast(op); + if (applyOp) + return applyOp.getApplicableOperator() == "*"; + // Any operation using variables is assumed to have a side effect of + // reading memory mutable by emitc::assign ops. + return llvm::any_of(op.getOperands(), [](Value operand) { + Operation *def = operand.getDefiningOp(); + return def && isa(def); + }); + }; + return llvm::any_of(getRegion().front().without_terminator(), predicate); + }; + Operation *getRootOp(); + }]; +} + def EmitC_ForOp : EmitC_Op<"for", [AllTypesMatch<["lowerBound", "upperBound", "step"]>, SingleBlockImplicitTerminator<"emitc::YieldOp">, @@ -494,18 +572,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { } def EmitC_YieldOp : EmitC_Op<"yield", - [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> { + [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> { let summary = "block termination operation"; let description = [{ - "yield" terminates blocks within EmitC control-flow operations. Since - control-flow constructs in C do not return values, this operation doesn't - take any arguments. + "yield" terminates its parent EmitC op's region, optionally yielding + an SSA value. The semantics of how the values are yielded is defined by the + parent operation. + If "yield" has an operand, the operand must match the parent operation's + result. If the parent operation defines no values, then the "emitc.yield" + may be left out in the custom syntax and the builders will insert one + implicitly. Otherwise, it has to be present in the syntax to indicate which + value is yielded. }]; - let arguments = (ins); + let arguments = (ins Optional:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = [{ attr-dict }]; + let hasVerifier = 1; + let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }]; } def EmitC_IfOp : EmitC_Op<"if", diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..0b507d75fa07a --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC) +add_public_tablegen_target(MLIREmitCTransformsIncGen) + +add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h new file mode 100644 index 0000000000000..5cd27149d366e --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an instance of the C-style expressions forming pass. +std::unique_ptr createFormExpressionsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td new file mode 100644 index 0000000000000..fd083abc95715 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -0,0 +1,24 @@ +//===-- Passes.td - pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FormExpressions : Pass<"form-expressions"> { + let summary = "Form C-style expressions from C-operator ops"; + let description = [{ + The pass wraps emitc ops modelling C operators in emitc.expression ops and + then folds single-use expressions into their users where possible. + }]; + let constructor = "mlir::emitc::createFormExpressionsPass()"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h new file mode 100644 index 0000000000000..2574acd7d48e0 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -0,0 +1,34 @@ +//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Expression transforms +//===----------------------------------------------------------------------===// + +ExpressionOp createExpression(Operation *op, OpBuilder &builder); + +//===----------------------------------------------------------------------===// +// Populate functions +//===----------------------------------------------------------------------===// + +/// Populates `patterns` with expression-related patterns. +void populateExpressionPatterns(RewritePatternSet &patterns); + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index f22980036ffcf..5207559f36250 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -86,6 +87,7 @@ inline void registerAllPasses() { vector::registerVectorPasses(); arm_sme::registerArmSMEPasses(); arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); // Dialect pipelines bufferization::registerBufferizationPipelines(); diff --git a/mlir/lib/Dialect/EmitC/CMakeLists.txt b/mlir/lib/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/lib/Dialect/EmitC/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 2d578d47aa4a8..c5d07b1d39994 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -190,6 +190,50 @@ LogicalResult emitc::ConstantOp::verify() { OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } +//===----------------------------------------------------------------------===// +// ExpressionOp +//===----------------------------------------------------------------------===// + +Operation *ExpressionOp::getRootOp() { + auto yieldOp = cast(getBody()->getTerminator()); + Value yieldedValue = yieldOp.getResult(); + Operation *rootOp = yieldedValue.getDefiningOp(); + assert(rootOp && "Yielded value not defined within expression"); + return rootOp; +} + +LogicalResult ExpressionOp::verify() { + Type resultType = getResult().getType(); + Region ®ion = getRegion(); + + Block &body = region.front(); + + if (!body.mightHaveTerminator()) + return emitOpError("must yield a value at termination"); + + auto yield = cast(body.getTerminator()); + Value yieldResult = yield.getResult(); + + if (!yieldResult) + return emitOpError("must yield a value at termination"); + + Type yieldType = yieldResult.getType(); + + if (resultType != yieldType) + return emitOpError("requires yielded type to match return type"); + + for (Operation &op : region.front().without_terminator()) { + if (!isCExpression(op)) + return emitOpError("contains an unsupported operation"); + if (op.getNumResults() != 1) + return emitOpError("requires exactly one result for each operation"); + if (!op.getResult(0).hasOneUse()) + return emitOpError("requires exactly one use for each operation"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// @@ -545,6 +589,23 @@ LogicalResult emitc::SubscriptOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult emitc::YieldOp::verify() { + Value result = getResult(); + Operation *containingOp = getOperation()->getParentOp(); + + if (result && containingOp->getNumResults() != 1) + return emitOpError() << "yields a value not returned by parent"; + + if (!result && containingOp->getNumResults() != 0) + return emitOpError() << "does not yield a value to be returned by parent"; + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..bfcc14523f137 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIREmitCTransforms + Transforms.cpp + FormExpressions.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms + + DEPENDS + MLIREmitCTransformsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIREmitCDialect + MLIRTransforms +) diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp new file mode 100644 index 0000000000000..21212155ffb22 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -0,0 +1,60 @@ +//===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that forms EmitC operations modeling C operators +// into C-style expressions using the emitc.expression op. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace emitc { +#define GEN_PASS_DEF_FORMEXPRESSIONS +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace emitc; + +namespace { +struct FormExpressionsPass + : public emitc::impl::FormExpressionsBase { + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = rootOp->getContext(); + + // Wrap each C operator op with an expression op. + OpBuilder builder(context); + auto matchFun = [&](Operation *op) { + if (emitc::ExpressionOp::isCExpression(*op)) + createExpression(op, builder); + }; + rootOp->walk(matchFun); + + // Fold expressions where possible. + RewritePatternSet patterns(context); + populateExpressionPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns)))) + return signalPassFailure(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // namespace + +std::unique_ptr mlir::emitc::createFormExpressionsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp new file mode 100644 index 0000000000000..593d774cac73b --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -0,0 +1,114 @@ +//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace emitc { + +ExpressionOp createExpression(Operation *op, OpBuilder &builder) { + assert(ExpressionOp::isCExpression(*op) && "Expected a C expression"); + + // Create an expression yielding the value returned by op. + assert(op->getNumResults() == 1 && "Expected exactly one result"); + Value result = op->getResult(0); + Type resultType = result.getType(); + Location loc = op->getLoc(); + + builder.setInsertionPointAfter(op); + auto expressionOp = builder.create(loc, resultType); + + // Replace all op's uses with the new expression's result. + result.replaceAllUsesWith(expressionOp.getResult()); + + // Create an op to yield op's value. + Region ®ion = expressionOp.getRegion(); + Block &block = region.emplaceBlock(); + builder.setInsertionPointToEnd(&block); + auto yieldOp = builder.create(loc, result); + + // Move op into the new expression. + op->moveBefore(yieldOp); + + return expressionOp; +} + +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace mlir::emitc; + +namespace { + +struct FoldExpressionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ExpressionOp expressionOp, + PatternRewriter &rewriter) const override { + bool anythingFolded = false; + for (Operation &op : llvm::make_early_inc_range( + expressionOp.getBody()->without_terminator())) { + // Don't fold expressions whose result value has its address taken. + auto applyOp = dyn_cast(op); + if (applyOp && applyOp.getApplicableOperator() == "&") + continue; + + for (Value operand : op.getOperands()) { + auto usedExpression = + dyn_cast_if_present(operand.getDefiningOp()); + + if (!usedExpression) + continue; + + // Don't fold expressions with multiple users: assume any + // re-materialization was done separately. + if (!usedExpression.getResult().hasOneUse()) + continue; + + // Don't fold expressions with side effects. + if (usedExpression.hasSideEffects()) + continue; + + // Fold the used expression into this expression by cloning all + // instructions in the used expression just before the operation using + // its value. + rewriter.setInsertionPoint(&op); + IRMapping mapper; + for (Operation &opToClone : + usedExpression.getBody()->without_terminator()) { + Operation *clone = rewriter.clone(opToClone, mapper); + mapper.map(&opToClone, clone); + } + + Operation *expressionRoot = usedExpression.getRootOp(); + Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); + assert(clonedExpressionRootOp && + "Expected cloned expression root to be in mapper"); + assert(clonedExpressionRootOp->getNumResults() == 1 && + "Expected cloned root to have a single result"); + + Value clonedExpressionResult = clonedExpressionRootOp->getResult(0); + + usedExpression.getResult().replaceAllUsesWith(clonedExpressionResult); + rewriter.eraseOp(usedExpression); + anythingFolded = true; + } + } + return anythingFolded ? success() : failure(); + } +}; + +} // namespace + +void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 1edf679390d7d..02a54f792567f 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/Support/IndentedOstream.h" +#include "mlir/Support/LLVM.h" #include "mlir/Target/Cpp/CppEmitter.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -65,6 +66,35 @@ inline LogicalResult interleaveCommaWithError(const Container &c, return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); } +/// Return the precedence of a operator as an integer, higher values +/// imply higher precedence. +static int getOperatorPrecedence(Operation *operation) { + return llvm::TypeSwitch(operation) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { + switch (op.getPredicate()) { + case emitc::CmpPredicate::eq: + case emitc::CmpPredicate::ne: + return 8; + case emitc::CmpPredicate::lt: + case emitc::CmpPredicate::le: + case emitc::CmpPredicate::gt: + case emitc::CmpPredicate::ge: + return 9; + case emitc::CmpPredicate::three_way: + return 10; + } + }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 14; }); + llvm_unreachable("Unsupported operator"); +} + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -119,6 +149,12 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); + /// Emits value as an operands of an operation + LogicalResult emitOperand(Value value); + + /// Emit an expression as a C expression. + LogicalResult emitExpression(ExpressionOp expressionOp); + /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); @@ -163,6 +199,21 @@ struct CppEmitter { /// be declared at the beginning of a function. bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + /// Get expression currently being emitted. + ExpressionOp getEmittedExpression() { return emittedExpression; } + + /// Determine whether given value is part of the expression potentially being + /// emitted. + bool isPartOfCurrentExpression(Value value) { + if (!emittedExpression) + return false; + Operation *def = value.getDefiningOp(); + if (!def) + return false; + auto operandExpression = dyn_cast(def->getParentOp()); + return operandExpression == emittedExpression; + }; + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -185,9 +236,50 @@ struct CppEmitter { /// names of values in a scope. std::stack valueInScopeCount; std::stack labelInScopeCount; + + /// State of the current expression being emitted. + ExpressionOp emittedExpression; + SmallVector emittedExpressionPrecedence; + + void pushExpressionPrecedence(int precedence) { + emittedExpressionPrecedence.push_back(precedence); + } + void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); } + static int lowestPrecedence() { return 0; } + int getExpressionPrecedence() { + if (emittedExpressionPrecedence.empty()) + return lowestPrecedence(); + return emittedExpressionPrecedence.back(); + } }; } // namespace +/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// as part of its user. This function recommends inlining of any expressions +/// that can be inlined unless it is used by another expression, under the +/// assumption that any expression fusion/re-materialization was taken care of +/// by transformations run by the backend. +static bool shouldBeInlined(ExpressionOp expressionOp) { + // Do not inline if expression is marked as such. + if (expressionOp.getDoNotInline()) + return false; + + // Do not inline expressions with side effects to prevent side-effect + // reordering. + if (expressionOp.hasSideEffects()) + return false; + + // Do not inline expressions with multiple uses. + Value result = expressionOp.getResult(); + if (!result.hasOneUse()) + return false; + + // Do not inline expressions used by other expressions, as any desired + // expression folding was taken care of by transformations. + Operation *user = *result.getUsers().begin(); + return !user->getParentOfType(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -259,9 +351,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (failed(emitter.emitVariableAssignment(result))) return failure(); - emitter.ostream() << emitter.getOrCreateName(assignOp.getValue()); - - return success(); + return emitter.emitOperand(assignOp.getValue()); } static LogicalResult printOperation(CppEmitter &emitter, @@ -278,9 +368,14 @@ static LogicalResult printBinaryOperation(CppEmitter &emitter, if (failed(emitter.emitAssignPrefix(*operation))) return failure(); - os << emitter.getOrCreateName(operation->getOperand(0)); - os << " " << binaryOperator; - os << " " << emitter.getOrCreateName(operation->getOperand(1)); + + if (failed(emitter.emitOperand(operation->getOperand(0)))) + return failure(); + + os << " " << binaryOperator << " "; + + if (failed(emitter.emitOperand(operation->getOperand(1)))) + return failure(); return success(); } @@ -498,9 +593,20 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) return failure(); os << ") "; - os << emitter.getOrCreateName(castOp.getOperand()); + return emitter.emitOperand(castOp.getOperand()); +} - return success(); +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ExpressionOp expressionOp) { + if (shouldBeInlined(expressionOp)) + return success(); + + Operation &op = *expressionOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + + return emitter.emitExpression(expressionOp); } static LogicalResult printOperation(CppEmitter &emitter, @@ -520,6 +626,17 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { raw_indented_ostream &os = emitter.ostream(); + // Utility function to determine whether a value is an expression that will be + // inlined, and as such should be wrapped in parentheses in order to guarantee + // its precedence and associativity. + auto requiresParentheses = [&](Value value) { + auto expressionOp = + dyn_cast_if_present(value.getDefiningOp()); + if (!expressionOp) + return false; + return shouldBeInlined(expressionOp); + }; + os << "for ("; if (failed( emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) @@ -527,15 +644,24 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { os << " "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " = "; - os << emitter.getOrCreateName(forOp.getLowerBound()); + if (failed(emitter.emitOperand(forOp.getLowerBound()))) + return failure(); os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " < "; - os << emitter.getOrCreateName(forOp.getUpperBound()); + Value upperBound = forOp.getUpperBound(); + bool upperBoundRequiresParentheses = requiresParentheses(upperBound); + if (upperBoundRequiresParentheses) + os << "("; + if (failed(emitter.emitOperand(upperBound))) + return failure(); + if (upperBoundRequiresParentheses) + os << ")"; os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " += "; - os << emitter.getOrCreateName(forOp.getStep()); + if (failed(emitter.emitOperand(forOp.getStep()))) + return failure(); os << ") {\n"; os.indent(); @@ -570,7 +696,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) { }; os << "if ("; - if (failed(emitter.emitOperands(*ifOp.getOperation()))) + if (failed(emitter.emitOperand(ifOp.getCondition()))) return failure(); os << ") {\n"; os.indent(); @@ -598,8 +724,10 @@ static LogicalResult printOperation(CppEmitter &emitter, case 0: return success(); case 1: - os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); - return success(emitter.hasValueInScope(returnOp.getOperand(0))); + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand(0)))) + return failure(); + return success(); default: os << " std::make_tuple("; if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) @@ -650,7 +778,10 @@ static LogicalResult printOperation(CppEmitter &emitter, // regions. WalkResult result = functionOp.walk([&](Operation *op) -> WalkResult { - if (isa(op)) + if (isa(op) || + isa(op->getParentOp()) || + (isa(op) && + shouldBeInlined(cast(op)))) return WalkResult::skip(); for (OpResult result : op->getResults()) { if (failed(emitter.emitVariableDeclaration( @@ -868,15 +999,70 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return emitError(loc, "cannot emit attribute: ") << attr; } +LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + Operation *rootOp = expressionOp.getRootOp(); + + emittedExpression = expressionOp; + pushExpressionPrecedence(getOperatorPrecedence(rootOp)); + + if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) + return failure(); + + popExpressionPrecedence(); + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + emittedExpression = nullptr; + + return success(); +} + +LogicalResult CppEmitter::emitOperand(Value value) { + if (isPartOfCurrentExpression(value)) { + Operation *def = value.getDefiningOp(); + assert(def && "Expected operand to be defined by an operation"); + int precedence = getOperatorPrecedence(def); + bool encloseInParenthesis = precedence < getExpressionPrecedence(); + if (encloseInParenthesis) { + os << "("; + pushExpressionPrecedence(lowestPrecedence()); + } else + pushExpressionPrecedence(precedence); + + if (failed(emitOperation(*def, /*trailingSemicolon=*/false))) + return failure(); + + if (encloseInParenthesis) + os << ")"; + + popExpressionPrecedence(); + return success(); + } + + auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); + if (expressionOp && shouldBeInlined(expressionOp)) + return emitExpression(expressionOp); + + auto literalOp = dyn_cast_if_present(value.getDefiningOp()); + if (!literalOp && !hasValueInScope(value)) + return failure(); + os << getOrCreateName(value); + return success(); +} + LogicalResult CppEmitter::emitOperands(Operation &op) { - auto emitOperandName = [&](Value result) -> LogicalResult { - auto literalDef = dyn_cast_if_present(result.getDefiningOp()); - if (!literalDef && !hasValueInScope(result)) - return op.emitOpError() << "operand value not in scope"; - os << getOrCreateName(result); + return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { + // If an expression is being emitted, push lowest precedence as these + // operands are either wrapped by parenthesis. + if (getEmittedExpression()) + pushExpressionPrecedence(lowestPrecedence()); + if (failed(emitOperand(operand))) + return failure(); + if (getEmittedExpression()) + popExpressionPrecedence(); return success(); - }; - return interleaveCommaWithError(op.getOperands(), os, emitOperandName); + }); } LogicalResult @@ -932,6 +1118,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, } LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { + // If op is being emitted as part of an expression, bail out. + if (getEmittedExpression()) + return success(); + switch (op.getNumResults()) { case 0: break; @@ -982,9 +1172,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // EmitC ops. .Case( + emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp, + emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp, + emitc::RemOp, emitc::SubOp, emitc::SubscriptOp, + emitc::VariableOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( @@ -1003,7 +1194,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (isa(op)) return success(); + if (getEmittedExpression() || + (isa(op) && + shouldBeInlined(cast(op)))) + return success(); + os << (trailingSemicolon ? ";\n" : "\n"); + return success(); } diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index fd79bbd8a1d30..bf217be32ef0c 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr, %arg1: !emitc.ptr) { // ----- func.func @test_misplaced_yield() { - // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}} + // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.expression, emitc.if, emitc.for'}} emitc.yield return } @@ -232,3 +232,60 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: %0 = emitc.subscript %arg0[%arg2] : <4x8xf32> return } + +// ----- + +func.func @test_expression_no_yield() -> i32 { + // expected-error @+1 {{'emitc.expression' op must yield a value at termination}} + %r = emitc.expression : i32 { + %c7 = "emitc.constant"(){value = 7 : i32} : () -> i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_illegal_op(%arg0 : i1) -> i32 { + // expected-error @+1 {{'emitc.expression' op contains an unsupported operation}} + %r = emitc.expression : i32 { + %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 + emitc.yield %x : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.add %a, %arg0 : (i32, i32) -> i32 + %c = emitc.mul %arg1, %a : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_results(%arg0: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one result for each operation}} + %r = emitc.expression : i32 { + %a:2 = emitc.call_opaque "bar" (%arg0) : (i32) -> (i32, i32) + emitc.yield %a : i32 + } + return %r : i32 +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index d280f12b78516..5af90f8749beb 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -128,6 +128,23 @@ func.func @test_assign(%arg1: f32) { return } +func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> i32 { + %c7 = "emitc.constant"() {value = 7 : i32} : () -> i32 + %q = emitc.expression : i32 { + %a = emitc.rem %arg1, %c7 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %r = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32) + %c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32 + %d = emitc.cast %c : f32 to i32 + %e = emitc.sub %b, %d : (i32, i32) -> i32 + emitc.yield %e : i32 + } + return %r : i32 +} + func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { emitc.for %i0 = %arg0 to %arg1 step %arg2 { %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32 diff --git a/mlir/test/Dialect/EmitC/transforms.mlir b/mlir/test/Dialect/EmitC/transforms.mlir new file mode 100644 index 0000000000000..ad167fa455a1a --- /dev/null +++ b/mlir/test/Dialect/EmitC/transforms.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s --form-expressions --verify-diagnostics --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @single_expression( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 + %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.cmp lt, %b, %arg3 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @multiple_expressions( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.sub %[[VAL_5]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_8:.*]] = emitc.add %[[VAL_1]], %[[VAL_3]] : (i32, i32) -> i32 +// CHECK: %[[VAL_9:.*]] = emitc.div %[[VAL_8]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return %[[VAL_4]], %[[VAL_7]] : i32, i32 +// CHECK: } + +func.func @multiple_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> (i32, i32) { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.add %arg1, %arg3 : (i32, i32) -> i32 + %d = emitc.div %c, %arg2 : (i32, i32) -> i32 + return %b, %d : i32, i32 +} + +// CHECK-LABEL: func.func @expression_with_call( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.call_opaque "foo"(%[[VAL_5]], %[[VAL_2]]) : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_7]] : i1 +// CHECK: } + +func.func @expression_with_call(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "foo" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_dereference( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.apply "*"(%[[VAL_2]]) : (!emitc.ptr) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.cmp lt, %[[VAL_6]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_7]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_dereference(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.apply "*"(%arg2) : (!emitc.ptr) -> (i32) + %c = emitc.cmp lt, %a, %b :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_address_taken( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.rem %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.apply "&"(%[[VAL_3]]) : (i32) -> !emitc.ptr +// CHECK: %[[VAL_7:.*]] = emitc.add %[[VAL_6]], %[[VAL_1]] : (!emitc.ptr, i32) -> !emitc.ptr +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_2]] : (!emitc.ptr, !emitc.ptr) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> (i32) + %b = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %c = emitc.add %b, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %d = emitc.cmp lt, %c, %arg2 :(!emitc.ptr, !emitc.ptr) -> i1 + return %d : i1 +} diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir new file mode 100644 index 0000000000000..9ec9dcc3c6a84 --- /dev/null +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -0,0 +1,212 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + +// CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %p0 = emitc.literal "M_PI" : i32 + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %p0 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DEFAULT-NEXT: return [[VAL_4]]; +// CPP-DEFAULT-NEXT:} + +// CPP-DECLTOP: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DECLTOP-NEXT: return [[VAL_4]]; +// CPP-DECLTOP-NEXT:} + +func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 { + %e = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + return %e : i32 +} + +// CPP-DEFAULT: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DECLTOP-NEXT: } + +func.func @paranthesis_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 { + %e = emitc.expression : f32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + %d = emitc.cast %b : i32 to f32 + emitc.yield %d : f32 + } + return %e : f32 +} + +// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + %q = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i1 + emitc.assign %e : i1 to %q : i1 + return %v : i32 +} + +// CPP-DEFAULT: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_7]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_7]]; +// CPP-DECLTOP-NEXT: } + +func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e1 = emitc.expression : i32 { + %a = emitc.rem %arg2, %arg3 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %e2 = emitc.expression : i32 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%e1, %a) : (i32, i32) -> (i32) + emitc.yield %b : i32 + } + %e3 = emitc.expression : i1 { + %c = emitc.sub %e2, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e3 { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.expression : i32 { + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + %c = emitc.expression : i1 { + %d = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %e = emitc.sub %d, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr, !emitc.ptr) -> i1 + emitc.yield %f : i1 + } + return %c : i1 +} diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index 90504b1347bb4..b9bd3d98465a2 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -2,20 +2,32 @@ // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { - emitc.for %i0 = %arg0 to %arg1 step %arg2 { + %lb = emitc.expression : index { + %a = emitc.add %arg0, %arg1 : (index, index) -> index + emitc.yield %a : index + } + %ub = emitc.expression : index { + %a = emitc.mul %arg1, %arg2 : (index, index) -> index + emitc.yield %a : index + } + %step = emitc.expression : index { + %a = emitc.div %arg0, %arg2 : (index, index) -> index + emitc.yield %a : index + } + emitc.for %i0 = %lb to %ub step %step { %0 = emitc.call_opaque "f"() : () -> i32 } return } -// CPP-DEFAULT: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { -// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { +// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f(); // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; -// CPP-DECLTOP: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { +// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DECLTOP-NEXT: [[V4]] = f(); // CPP-DECLTOP-NEXT: } // CPP-DECLTOP-NEXT: return; From cfd54ec04e900ca5c81e5984d0da370bec5f119b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 21 Dec 2023 16:00:18 +0900 Subject: [PATCH 02/21] [mlir][EmitC] Fix invalid rewriter API usage (#76124) When operations are modified in-place, the rewriter must be notified. This commit fixes `mlir/test/Dialect/EmitC/transforms.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled. --- mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index 593d774cac73b..88b691b50f325 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -96,10 +96,7 @@ struct FoldExpressionOp : public OpRewritePattern { assert(clonedExpressionRootOp->getNumResults() == 1 && "Expected cloned root to have a single result"); - Value clonedExpressionResult = clonedExpressionRootOp->getResult(0); - - usedExpression.getResult().replaceAllUsesWith(clonedExpressionResult); - rewriter.eraseOp(usedExpression); + rewriter.replaceOp(usedExpression, clonedExpressionRootOp); anythingFolded = true; } } From 864aa1d35ab8a71fbb9e125969efe3283f67b5fd Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Tue, 2 Jan 2024 16:53:36 +0100 Subject: [PATCH 03/21] [mlir][EmitC] Disallow string attributes as initial values (#75310) --- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 58 ++++++++++++++---------- mlir/test/Dialect/EmitC/invalid_ops.mlir | 22 ++++++--- mlir/test/Target/Cpp/const.mlir | 6 +-- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c5d07b1d39994..bdba2de4b073d 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -50,6 +50,32 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create(loc); } +/// Check that the type of the initial value is compatible with the operations +/// result type. +static LogicalResult verifyInitializationAttribute(Operation *op, + Attribute value) { + assert(op->getNumResults() == 1 && "operation must have 1 result"); + + if (llvm::isa(value)) + return success(); + + if (llvm::isa(value)) + return op->emitOpError() + << "string attributes are not supported, use #emitc.opaque instead"; + + Type resultType = op->getResult(0).getType(); + Type attrType = cast(value).getType(); + + if (resultType != attrType) + return op->emitOpError() + << "requires attribute to either be an #emitc.opaque attribute or " + "it's type (" + << attrType << ") to match the op's result type (" << resultType + << ")"; + + return success(); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -170,21 +196,14 @@ LogicalResult emitc::CallOpaqueOp::verify() { // ConstantOp //===----------------------------------------------------------------------===// -/// The constant op requires that the attribute's type matches the return type. LogicalResult emitc::ConstantOp::verify() { - if (llvm::isa(getValueAttr())) - return success(); - - // Value must not be empty - StringAttr strAttr = llvm::dyn_cast(getValueAttr()); - if (strAttr && strAttr.empty()) - return emitOpError() << "value must not be empty"; - - auto value = cast(getValueAttr()); - Type type = getType(); - if (!llvm::isa(value.getType()) && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + Attribute value = getValueAttr(); + if (failed(verifyInitializationAttribute(getOperation(), value))) + return failure(); + if (auto opaqueValue = llvm::dyn_cast(value)) { + if (opaqueValue.getValue().empty()) + return emitOpError() << "value must not be empty"; + } return success(); } @@ -562,17 +581,8 @@ LogicalResult SubOp::verify() { // VariableOp //===----------------------------------------------------------------------===// -/// The variable op requires that the attribute's type matches the return type. LogicalResult emitc::VariableOp::verify() { - if (llvm::isa(getValueAttr())) - return success(); - - auto value = cast(getValueAttr()); - Type type = getType(); - if (!llvm::isa(value.getType()) && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - return success(); + return verifyInitializationAttribute(getOperation(), getValueAttr()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index bf217be32ef0c..1357b81fd6f58 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -1,7 +1,15 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +func.func @const_attribute_str() { + // expected-error @+1 {{'emitc.constant' op string attributes are not supported, use #emitc.opaque instead}} + %c0 = "emitc.constant"(){value = "NULL"} : () -> !emitc.ptr + return +} + +// ----- + func.func @const_attribute_return_type_1() { - // expected-error @+1 {{'emitc.constant' op requires attribute's type ('i64') to match op's return type ('i32')}} + // expected-error @+1 {{'emitc.constant' op requires attribute to either be an #emitc.opaque attribute or it's type ('i64') to match the op's result type ('i32')}} %c0 = "emitc.constant"(){value = 42: i64} : () -> i32 return } @@ -9,8 +17,8 @@ func.func @const_attribute_return_type_1() { // ----- func.func @const_attribute_return_type_2() { - // expected-error @+1 {{'emitc.constant' op requires attribute's type ('!emitc.opaque<"char">') to match op's return type ('!emitc.opaque<"mychar">')}} - %c0 = "emitc.constant"(){value = "CHAR_MIN" : !emitc.opaque<"char">} : () -> !emitc.opaque<"mychar"> + // expected-error @+1 {{'emitc.constant' op attribute 'value' failed to satisfy constraint: An opaque attribute or TypedAttr instance}} + %c0 = "emitc.constant"(){value = unit} : () -> i32 return } @@ -18,7 +26,7 @@ func.func @const_attribute_return_type_2() { func.func @empty_constant() { // expected-error @+1 {{'emitc.constant' op value must not be empty}} - %c0 = "emitc.constant"(){value = ""} : () -> i32 + %c0 = "emitc.constant"(){value = #emitc.opaque<"">} : () -> i32 return } @@ -98,7 +106,7 @@ func.func @illegal_operand() { // ----- func.func @var_attribute_return_type_1() { - // expected-error @+1 {{'emitc.variable' op requires attribute's type ('i64') to match op's return type ('i32')}} + // expected-error @+1 {{'emitc.variable' op requires attribute to either be an #emitc.opaque attribute or it's type ('i64') to match the op's result type ('i32')}} %c0 = "emitc.variable"(){value = 42: i64} : () -> i32 return } @@ -106,8 +114,8 @@ func.func @var_attribute_return_type_1() { // ----- func.func @var_attribute_return_type_2() { - // expected-error @+1 {{'emitc.variable' op requires attribute's type ('!emitc.ptr') to match op's return type ('!emitc.ptr')}} - %c0 = "emitc.variable"(){value = "nullptr" : !emitc.ptr} : () -> !emitc.ptr + // expected-error @+1 {{'emitc.variable' op attribute 'value' failed to satisfy constraint: An opaque attribute or TypedAttr instance}} + %c0 = "emitc.variable"(){value = unit} : () -> i32 return } diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir index e6c94732e9f6b..28a547909a0ac 100644 --- a/mlir/test/Target/Cpp/const.mlir +++ b/mlir/test/Target/Cpp/const.mlir @@ -2,7 +2,7 @@ // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP func.func @emitc_constant() { - %c0 = "emitc.constant"(){value = #emitc.opaque<"">} : () -> i32 + %c0 = "emitc.constant"(){value = #emitc.opaque<"INT_MAX">} : () -> i32 %c1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %c2 = "emitc.constant"(){value = -1 : i32} : () -> i32 %c3 = "emitc.constant"(){value = -1 : si8} : () -> si8 @@ -11,7 +11,7 @@ func.func @emitc_constant() { return } // CPP-DEFAULT: void emitc_constant() { -// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = INT_MAX; // CPP-DEFAULT-NEXT: int32_t [[V1:[^ ]*]] = 42; // CPP-DEFAULT-NEXT: int32_t [[V2:[^ ]*]] = -1; // CPP-DEFAULT-NEXT: int8_t [[V3:[^ ]*]] = -1; @@ -25,7 +25,7 @@ func.func @emitc_constant() { // CPP-DECLTOP-NEXT: int8_t [[V3:[^ ]*]]; // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]]; // CPP-DECLTOP-NEXT: char [[V5:[^ ]*]]; -// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX; // CPP-DECLTOP-NEXT: [[V1]] = 42; // CPP-DECLTOP-NEXT: [[V2]] = -1; // CPP-DECLTOP-NEXT: [[V3]] = -1; From 241a2baeb58c761616185d3e63f08f8e14c3c1cc Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Thu, 4 Jan 2024 15:43:33 +0100 Subject: [PATCH 04/21] [mlir][EmitC] Use declarative assembly format for opaque types and attributes (#76066) The parser and printer of string attributes were changed to handle escape sequences. Therefore, we no longer require a custom parser and printer. Verification is moved from the parser to the verifier accordingly. --- .../mlir/Dialect/EmitC/IR/EmitCAttributes.td | 3 +- .../mlir/Dialect/EmitC/IR/EmitCTypes.td | 3 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 49 +++---------------- 3 files changed, 11 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td index ae843e49c6c5b..ea5e9efd5fa0b 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td @@ -57,8 +57,7 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> { }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); - - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "`<` $value `>`"; } def EmitC_OpaqueOrTypedAttr : AnyAttrOf<[EmitC_OpaqueAttr, TypedAttrInterface]>; diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 8dfda3be99d5f..5ab729df67882 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -90,7 +90,8 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "`<` $value `>`"; + let genVerifyDecl = 1; } def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index bdba2de4b073d..d9a495ae94ea0 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -636,27 +636,6 @@ LogicalResult emitc::YieldOp::verify() { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" -Attribute emitc::OpaqueAttr::parse(AsmParser &parser, Type type) { - if (parser.parseLess()) - return Attribute(); - std::string value; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseOptionalString(&value)) { - parser.emitError(loc) << "expected string"; - return Attribute(); - } - if (parser.parseGreater()) - return Attribute(); - - return get(parser.getContext(), value); -} - -void emitc::OpaqueAttr::print(AsmPrinter &printer) const { - printer << "<\""; - llvm::printEscapedString(getValue(), printer.getStream()); - printer << "\">"; -} - //===----------------------------------------------------------------------===// // EmitC Types //===----------------------------------------------------------------------===// @@ -731,27 +710,15 @@ emitc::ArrayType::cloneWith(std::optional> shape, // OpaqueType //===----------------------------------------------------------------------===// -Type emitc::OpaqueType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - std::string value; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseOptionalString(&value) || value.empty()) { - parser.emitError(loc) << "expected non empty string in !emitc.opaque type"; - return Type(); +LogicalResult mlir::emitc::OpaqueType::verify( + llvm::function_ref emitError, + llvm::StringRef value) { + if (value.empty()) { + return emitError() << "expected non empty string in !emitc.opaque type"; } if (value.back() == '*') { - parser.emitError(loc) << "pointer not allowed as outer type with " - "!emitc.opaque, use !emitc.ptr instead"; - return Type(); + return emitError() << "pointer not allowed as outer type with " + "!emitc.opaque, use !emitc.ptr instead"; } - if (parser.parseGreater()) - return Type(); - return get(parser.getContext(), value); -} - -void emitc::OpaqueType::print(AsmPrinter &printer) const { - printer << "<\""; - llvm::printEscapedString(getValue(), printer.getStream()); - printer << "\">"; + return success(); } From 9e3863c4a9633da6d6d102016a1d947f1d1119a3 Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Wed, 31 Jan 2024 11:56:16 +0100 Subject: [PATCH 05/21] [mlir][EmitC] Add `verbatim` op (#79584) The `verbatim` operation produces no results and the value is emitted as is followed by a line break ('\n' character) during translation. Note: Use with caution. This operation can have arbitrary effects on the semantics of the emitted code. Use semantically more meaningful operations whenever possible. Additionally this op is *NOT* intended to be used to inject large snippets of code. This operation can be used in situations where a more suitable operation is not yet implemented in the dialect or where preprocessor directives interfere with the structure of the code. Co-authored-by: Marius Brehler --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 37 +++++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 20 +++++++---- mlir/test/Dialect/EmitC/ops.mlir | 11 ++++++ mlir/test/Target/Cpp/verbatim.mlir | 21 ++++++++++++ 4 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Target/Cpp/verbatim.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 4ece9471a67c2..8dc157b5517c3 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -544,6 +544,43 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { let hasVerifier = 1; } +def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { + let summary = "Verbatim operation"; + let description = [{ + The `verbatim` operation produces no results and the value is emitted as is + followed by a line break ('\n' character) during translation. + + Note: Use with caution. This operation can have arbitrary effects on the + semantics of the emitted code. Use semantically more meaningful operations + whenever possible. Additionally this op is *NOT* intended to be used to + inject large snippets of code. + + This operation can be used in situations where a more suitable operation is + not yet implemented in the dialect or where preprocessor directives + interfere with the structure of the code. One example of this is to declare + the linkage of external symbols to make the generated code usable in both C + and C++ contexts: + + ```c++ + #ifdef __cplusplus + extern "C" { + #endif + + ... + + #ifdef __cplusplus + } + #endif + ``` + }]; + + let arguments = (ins + StrAttr:$value, + UnitAttr:$trailing_semicolon + ); + let assemblyFormat = "$value attr-dict"; +} + def EmitC_AssignOp : EmitC_Op<"assign", []> { let summary = "Assign operation"; let description = [{ diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 02a54f792567f..973c19b8c9b5f 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -442,6 +442,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { return printBinaryOperation(emitter, operation, binaryOperator); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::VerbatimOp verbatimOp) { + raw_ostream &os = emitter.ostream(); + + os << verbatimOp.getValue(); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, cf::BranchOp branchOp) { raw_ostream &os = emitter.ostream(); @@ -825,11 +834,10 @@ static LogicalResult printOperation(CppEmitter &emitter, for (Operation &op : block.getOperations()) { // When generating code for an emitc.if or cf.cond_br op no semicolon // needs to be printed after the closing brace. - // When generating code for an emitc.for op, printing a trailing semicolon - // is handled within the printOperation function. - bool trailingSemicolon = - !isa( - op); + // When generating code for an emitc.for and emitc.verbatim op, printing a + // trailing semicolon is handled within the printOperation function. + bool trailingSemicolon = !isa(op); if (failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) @@ -1175,7 +1183,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::SubscriptOp, - emitc::VariableOp>( + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 5af90f8749beb..645501fb058c9 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -174,3 +174,14 @@ func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5x emitc.assign %0 : f32 to %1 : f32 return } + +emitc.verbatim "#ifdef __cplusplus" +emitc.verbatim "extern \"C\" {" +emitc.verbatim "#endif // __cplusplus" + +emitc.verbatim "#ifdef __cplusplus" +emitc.verbatim "} // extern \"C\"" +emitc.verbatim "#endif // __cplusplus" + +emitc.verbatim "typedef int32_t i32;" +emitc.verbatim "typedef float f32;" diff --git a/mlir/test/Target/Cpp/verbatim.mlir b/mlir/test/Target/Cpp/verbatim.mlir new file mode 100644 index 0000000000000..10465dd781a81 --- /dev/null +++ b/mlir/test/Target/Cpp/verbatim.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s + + +emitc.verbatim "#ifdef __cplusplus" +// CHECK: #ifdef __cplusplus +emitc.verbatim "extern \"C\" {" +// CHECK-NEXT: extern "C" { +emitc.verbatim "#endif // __cplusplus" +// CHECK-NEXT: #endif // __cplusplus +emitc.verbatim "#ifdef __cplusplus" +// CHECK-NEXT: #ifdef __cplusplus +emitc.verbatim "} // extern \"C\"" +// CHECK-NEXT: } // extern "C" +emitc.verbatim "#endif // __cplusplus" +// CHECK-NEXT: #endif // __cplusplus + +emitc.verbatim "typedef int32_t i32;" +// CHECK-NEXT: typedef int32_t i32; +emitc.verbatim "typedef float f32;" +// CHECK-NEXT: typedef float f32; From 763ebf902f110ebb04e3a9f6fb508b57773bb767 Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Wed, 31 Jan 2024 14:42:40 +0100 Subject: [PATCH 06/21] [mlir][EmitC] Remove unused attribute from verbatim op (#80142) The uses of the attribute were removed in code review of #79584, but it's definition was inadvertently kept. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 8dc157b5517c3..8c5384bdd48cc 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -574,10 +574,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { ``` }]; - let arguments = (ins - StrAttr:$value, - UnitAttr:$trailing_semicolon - ); + let arguments = (ins StrAttr:$value); let assemblyFormat = "$value attr-dict"; } From 0bcc7827a9e3e27238392216bb2fa032ab78c543 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 1 Feb 2024 10:04:36 +0100 Subject: [PATCH 07/21] [mlir][EmitC] Add func, call and return operations and conversions (#79612) This adds a `func`, `call` and `return` operation to the EmitC dialect, closely related to the corresponding operations of the Func dialect. In contrast to the operations of the Func dialect, the EmitC operations do not support multiple results. The `emitc.func` op features a `specifiers` argument that for example allows, with corresponding support in the emitter, to emit `inline static` functions. Furthermore, this adds patterns and a pass to convert the Func dialect to EmitC. A `func.func` op that is `private` is converted to `emitc.func` with a `"static"` specifier. --- .../mlir/Conversion/FuncToEmitC/FuncToEmitC.h | 18 ++ .../Conversion/FuncToEmitC/FuncToEmitCPass.h | 21 +++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 9 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 170 ++++++++++++++++++ mlir/lib/Conversion/CMakeLists.txt | 1 + .../lib/Conversion/FuncToEmitC/CMakeLists.txt | 16 ++ .../Conversion/FuncToEmitC/FuncToEmitC.cpp | 116 ++++++++++++ .../FuncToEmitC/FuncToEmitCPass.cpp | 47 +++++ mlir/lib/Dialect/EmitC/IR/CMakeLists.txt | 2 + mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 119 ++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 147 +++++++++++---- .../Conversion/FuncToEmitC/func-to-emitc.mlir | 55 ++++++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 37 ++++ mlir/test/Dialect/EmitC/ops.mlir | 15 ++ mlir/test/Target/Cpp/func.mlir | 39 ++++ .../llvm-project-overlay/mlir/BUILD.bazel | 31 ++++ 18 files changed, 815 insertions(+), 30 deletions(-) create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h create mode 100644 mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h create mode 100644 mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp create mode 100644 mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp create mode 100644 mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir create mode 100644 mlir/test/Target/Cpp/func.mlir diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h new file mode 100644 index 0000000000000..5c7f87e470306 --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h @@ -0,0 +1,18 @@ +//===- FuncToEmitC.h - Func to EmitC Patterns -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H + +namespace mlir { +class RewritePatternSet; + +void populateFuncToEmitCPatterns(RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h new file mode 100644 index 0000000000000..65936703ee13e --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h @@ -0,0 +1,21 @@ +//===- FuncToEmitCPass.h - Func to EmitC Pass -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c91cc39829215..f334ec7a592f8 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -29,6 +29,7 @@ #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bd0be40c04af3..2abd4b4b94f9d 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -356,6 +356,15 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> { ]; } +//===----------------------------------------------------------------------===// +// FuncToEmitC +//===----------------------------------------------------------------------===// + +def ConvertFuncToEmitC : Pass<"convert-func-to-emitc", "ModuleOp"> { + let summary = "Convert Func dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // FuncToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 4dff26e23c428..3d38744527d59 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -20,6 +20,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 8c5384bdd48cc..4e316ab41b078 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -16,8 +16,10 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td" include "mlir/Dialect/EmitC/IR/EmitCTypes.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/RegionKindInterface.td" @@ -386,6 +388,174 @@ def EmitC_ForOp : EmitC_Op<"for", let hasRegionVerifier = 1; } +def EmitC_CallOp : EmitC_Op<"call", + [CallOpInterface, + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `emitc.call` operation represents a direct call to an `emitc.func` + that is within the same symbol scope as the call. The operands and result type + of the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def EmitC_FuncOp : EmitC_Op<"func", [ + AutomaticAllocationScope, + FunctionOpInterface, IsolatedFromAbove +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). While the MLIR textual form provides a nice + inline syntax for function arguments, they are internally represented as + “block arguments” to the first block in the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // A function with no results: + emitc.func @foo(%arg0 : i32) { + emitc.call_opaque "bar" (%arg0) : (i32) -> () + emitc.return + } + + // A function with its argument as single result: + emitc.func @foo(%arg0 : i32) -> i32 { + emitc.return %arg0 : i32 + } + + // A function with specifiers attribute: + emitc.func @example_specifiers_fn_attr() -> i32 + attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 + } + + ``` + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$specifiers, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, + ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `emitc.return` operation represents a return operation within a function. + The operation takes zero or exactly one operand and produces no results. + The operand number and type must match the signature of the function + that contains the operation. + + Example: + + ```mlir + emitc.func @foo() : (i32) { + ... + emitc.return %0 : i32 + } + ``` + }]; + let arguments = (ins Optional:$operand); + + let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; + let hasVerifier = 1; +} + def EmitC_IncludeOp : EmitC_Op<"include", [HasParent<"ModuleOp">]> { let summary = "Include operation"; diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 7e9369b14b401..b70c26effe2b6 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSCF) add_subdirectory(ControlFlowToSPIRV) add_subdirectory(ConvertToLLVM) +add_subdirectory(FuncToEmitC) add_subdirectory(FuncToLLVM) add_subdirectory(FuncToSPIRV) add_subdirectory(GPUCommon) diff --git a/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt new file mode 100644 index 0000000000000..97752205bbcb4 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRFuncToEmitC + FuncToEmitC.cpp + FuncToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/FuncToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRFuncDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp new file mode 100644 index 0000000000000..ac3d8297953f3 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -0,0 +1,116 @@ +//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Func dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +class CallOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Multiple results func was not converted to `emitc.func`. + if (callOp.getNumResults() > 1) + return rewriter.notifyMatchFailure( + callOp, "only functions with zero or one result can be converted"); + + rewriter.replaceOpWithNewOp( + callOp, + callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr, + adaptor.getOperands(), callOp->getAttrs()); + + return success(); + } +}; + +class FuncOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (funcOp.getFunctionType().getNumResults() > 1) + return rewriter.notifyMatchFailure( + funcOp, "only functions with zero or one result can be converted"); + + if (funcOp.isDeclaration()) + return rewriter.notifyMatchFailure(funcOp, + "declarations cannot be converted"); + + // Create the converted `emitc.func` op. + emitc::FuncOp newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Add `static` to specifiers if `func.func` is private. + if (funcOp.isPrivate()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"}); + newFuncOp.setSpecifiersAttr(specifiers); + } + + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + rewriter.eraseOp(funcOp); + + return success(); + } +}; + +class ReturnOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (returnOp.getNumOperands() > 1) + return rewriter.notifyMatchFailure( + returnOp, "only zero or one operand is supported"); + + rewriter.replaceOpWithNewOp( + returnOp, + returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + patterns.add(ctx); +} diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp new file mode 100644 index 0000000000000..26d32e29bef8c --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp @@ -0,0 +1,47 @@ +//===- FuncToEmitC.cpp - Func to EmitC Pass ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Func dialect to the EmitC dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertFuncToEmitC + : public impl::ConvertFuncToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertFuncToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + populateFuncToEmitCPatterns(patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt index 4665c41a62e80..4cc54201d2745 100644 --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -9,8 +9,10 @@ add_mlir_dialect_library(MLIREmitCDialect MLIREmitCAttributesIncGen LINK_LIBS PUBLIC + MLIRCallInterfaces MLIRCastInterfaces MLIRControlFlowInterfaces + MLIRFunctionInterfaces MLIRIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index d9a495ae94ea0..4310237cdc412 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -8,7 +8,10 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -348,6 +351,122 @@ LogicalResult ForOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +FunctionType CallOp::getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +LogicalResult FuncOp::verify() { + if (getNumResults() > 1) + return emitOpError("requires zero or exactly one result, but has ") + << getNumResults(); + + if (isExternal()) + return emitOpError("does not support empty function bodies"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + if (getNumOperands() != function.getNumResults()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << function.getNumResults(); + + if (function.getNumResults() == 1) + if (getOperand().getType() != function.getResultTypes()[0]) + return emitError() << "type of the return operand (" + << getOperand().getType() + << ") doesn't match function result type (" + << function.getResultTypes()[0] << ")" + << " in function @" << function.getName(); + return success(); +} + //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 973c19b8c9b5f..4474dd811226a 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -517,18 +517,33 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } -static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { - if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) +static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, + StringRef callee) { + if (failed(emitter.emitAssignPrefix(*callOp))) return failure(); raw_ostream &os = emitter.ostream(); - os << callOp.getCallee() << "("; - if (failed(emitter.emitOperands(*callOp.getOperation()))) + os << callee << "("; + if (failed(emitter.emitOperands(*callOp))) return failure(); os << ")"; return success(); } +static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOpaqueOp callOpaqueOp) { raw_ostream &os = emitter.ostream(); @@ -746,6 +761,19 @@ static LogicalResult printOperation(CppEmitter &emitter, } } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ReturnOp returnOp) { + raw_ostream &os = emitter.ostream(); + os << "return"; + if (returnOp.getNumOperands() == 0) + return success(); + + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand()))) + return failure(); + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { CppEmitter::Scope scope(emitter); @@ -756,37 +784,33 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } -static LogicalResult printOperation(CppEmitter &emitter, - func::FuncOp functionOp) { - // We need to declare variables at top if the function has multiple blocks. - if (!emitter.shouldDeclareVariablesAtTop() && - functionOp.getBlocks().size() > 1) { - return functionOp.emitOpError( - "with multiple blocks needs variables declared at top"); - } - - CppEmitter::Scope scope(emitter); +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + Region::BlockArgListType arguments) { raw_indented_ostream &os = emitter.ostream(); - if (failed(emitter.emitTypes(functionOp.getLoc(), - functionOp.getFunctionType().getResults()))) - return failure(); - os << " " << functionOp.getName(); - os << "("; - if (failed(interleaveCommaWithError(functionOp.getArguments(), os, + if (failed(interleaveCommaWithError(arguments, os, [&](BlockArgument arg) -> LogicalResult { return emitter.emitVariableDeclaration( - functionOp.getLoc(), arg.getType(), + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg)); }))) return failure(); - os << ") {\n"; + + return success(); +} + +static LogicalResult printFunctionBody(CppEmitter &emitter, + Operation *functionOp, + Region::BlockListType &blocks) { + raw_indented_ostream &os = emitter.ostream(); os.indent(); + if (emitter.shouldDeclareVariablesAtTop()) { // Declare all variables that hold op results including those from nested // regions. WalkResult result = - functionOp.walk([&](Operation *op) -> WalkResult { + functionOp->walk([&](Operation *op) -> WalkResult { if (isa(op) || isa(op->getParentOp()) || (isa(op) && @@ -805,7 +829,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); } - Region::BlockListType &blocks = functionOp.getBlocks(); // Create label names for basic blocks. for (Block &block : blocks) { emitter.getOrCreateName(block); @@ -815,7 +838,7 @@ static LogicalResult printOperation(CppEmitter &emitter, for (Block &block : llvm::drop_begin(blocks)) { for (BlockArgument &arg : block.getArguments()) { if (emitter.hasValueInScope(arg)) - return functionOp.emitOpError(" block argument #") + return functionOp->emitOpError(" block argument #") << arg.getArgNumber() << " is out of scope"; if (failed( emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { @@ -844,7 +867,71 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); } } - os.unindent() << "}\n"; + + os.unindent(); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + func::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (functionOp.getSpecifiers()) { + for (Attribute specifier : functionOp.getSpecifiersAttr()) { + os << cast(specifier).str() << " "; + } + } + + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + return success(); } @@ -1178,12 +1265,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case( + emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp, + emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::SubscriptOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir new file mode 100644 index 0000000000000..a1c8af2587aa0 --- /dev/null +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s + +// CHECK-LABEL: emitc.func @foo() +// CHECK-NEXT: emitc.return +func.func @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func private @foo() attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return +func.func private @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) +func.func @foo(%arg0: i32) { + emitc.call_opaque "bar"(%arg0) : (i32) -> () + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) -> i32 +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func @foo(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32, %arg1: i32) -> i32 +func.func @foo(%arg0: i32, %arg1: i32) -> i32 { + %0 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func private @return_i32(%arg0: i32) -> i32 attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func private @return_i32(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// CHECK-LABEL: emitc.func @call(%arg0: i32) -> i32 +// CHECK-NEXT: %0 = emitc.call @return_i32(%arg0) : (i32) -> i32 +// CHECK-NEXT: emitc.return %0 : i32 +func.func @call(%arg0: i32) -> i32 { + %0 = call @return_i32(%arg0) : (i32) -> (i32) + return %0 : i32 +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 1357b81fd6f58..b1e5139fa0b05 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -297,3 +297,40 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 { } return %r : i32 } + +// ----- + +// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} +emitc.func @multiple_results(%0: i32) -> (i32, i32) { + emitc.return %0 : i32 +} + +// ----- + +emitc.func @resulterror() -> i32 { +^bb42: + emitc.return // expected-error {{'emitc.return' op has 0 operands, but enclosing function (@resulterror) returns 1}} +} + +// ----- + +emitc.func @return_type_mismatch() -> i32 { + %0 = emitc.call_opaque "foo()"(): () -> f32 + emitc.return %0 : f32 // expected-error {{type of the return operand ('f32') doesn't match function result type ('i32') in function @return_type_mismatch}} +} + +// ----- + +func.func @return_inside_func.func(%0: i32) -> (i32) { + // expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}} + emitc.return %0 : i32 +} +// ----- + +// expected-error@+1 {{expected non-function type}} +emitc.func @func_variadic(...) + +// ----- + +// expected-error@+1 {{'emitc.func' op does not support empty function bodies}} +emitc.func private @empty() diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 645501fb058c9..426cd66e12031 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -15,6 +15,21 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) { return } +emitc.func @func(%arg0 : i32) { + emitc.call_opaque "foo"(%arg0) : (i32) -> () + emitc.return +} + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo"(): () -> i32 + emitc.return %0 : i32 +} + +emitc.func @call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} + func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 return diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir new file mode 100644 index 0000000000000..d2e14a9e5a7ae --- /dev/null +++ b/mlir/test/Target/Cpp/func.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + + +emitc.func @emitc_func(%arg0 : i32) { + emitc.call_opaque "foo" (%arg0) : (i32) -> () + emitc.return +} +// CPP-DEFAULT: void emitc_func(int32_t [[V0:[^ ]*]]) { +// CPP-DEFAULT-NEXT: foo([[V0:[^ ]*]]); +// CPP-DEFAULT-NEXT: return; + + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 +} +// CPP-DEFAULT: static inline int32_t return_i32() { +// CPP-DEFAULT-NEXT: [[V0:[^ ]*]] = foo(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: static inline int32_t return_i32() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:]] = foo(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; + + +emitc.func @emitc_call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} +// CPP-DEFAULT: int32_t emitc_call() { +// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = return_i32(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: int32_t emitc_call() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:[^ ]*]] = return_i32(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index cdb35d87992ed..1ecca1df0ea47 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1560,8 +1560,10 @@ td_library( includes = ["include"], deps = [ ":BuiltinDialectTdFiles", + ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":FunctionInterfacesTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -3633,10 +3635,12 @@ cc_library( ]), includes = ["include"], deps = [ + ":CallOpInterfaces", ":CastInterfaces", ":ControlFlowInterfaces", ":EmitCAttributesIncGen", ":EmitCOpsIncGen", + ":FunctionInterfaces", ":IR", ":SideEffectInterfaces", "//llvm:Support", @@ -3853,6 +3857,7 @@ cc_library( ":ControlFlowToSPIRV", ":ConversionPassIncGen", ":ConvertToLLVM", + ":FuncToEmitC", ":FuncToLLVM", ":FuncToSPIRV", ":GPUToGPURuntimeTransforms", @@ -6751,6 +6756,32 @@ cc_library( ], ) +cc_library( + name = "FuncToEmitC", + srcs = glob([ + "lib/Conversion/FuncToEmitC*.cpp", + "lib/Conversion/FuncToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/FuncToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/FuncToEmitC", + ], + deps = [ + ":ConversionPassIncGen", + ":FuncDialect", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "FuncToSPIRV", srcs = glob([ From ed20cea4d9f6f1bc49d117c87fc188309a310d1d Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 5 Feb 2024 16:58:10 +0100 Subject: [PATCH 08/21] [mlir][EmitC] Add support for external functions (#80547) This adds a conversion from an externaly defined `func.func`, a `func.func` without function body, to an `emitc.func` with an `extern` specifier. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +++ .../Conversion/FuncToEmitC/FuncToEmitC.cpp | 20 +++++++++++-------- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 3 --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 18 +++++++++++++++++ .../Conversion/FuncToEmitC/func-to-emitc.mlir | 5 +++++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 5 ----- mlir/test/Dialect/EmitC/ops.mlir | 2 ++ mlir/test/Target/Cpp/func.mlir | 3 +++ 8 files changed, 43 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 4e316ab41b078..7b1d590987058 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -498,6 +498,9 @@ def EmitC_FuncOp : EmitC_Op<"func", [ emitc.return %0 : i32 } + // An external function definition: + emitc.func private @extern_func(i32) + attributes {specifiers = ["extern"]} ``` }]; let arguments = (ins SymbolNameAttr:$sym_name, diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index ac3d8297953f3..6a8ecb7b00473 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -57,10 +57,6 @@ class FuncOpConversion final : public OpConversionPattern { return rewriter.notifyMatchFailure( funcOp, "only functions with zero or one result can be converted"); - if (funcOp.isDeclaration()) - return rewriter.notifyMatchFailure(funcOp, - "declarations cannot be converted"); - // Create the converted `emitc.func` op. emitc::FuncOp newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); @@ -72,14 +68,22 @@ class FuncOpConversion final : public OpConversionPattern { newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } - // Add `static` to specifiers if `func.func` is private. - if (funcOp.isPrivate()) { + // Add `extern` to specifiers if `func.func` is declaration only. + if (funcOp.isDeclaration()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"}); + newFuncOp.setSpecifiersAttr(specifiers); + } + + // Add `static` to specifiers if `func.func` is private but not a + // declaration. + if (funcOp.isPrivate() && !funcOp.isDeclaration()) { ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"}); newFuncOp.setSpecifiersAttr(specifiers); } - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); + if (!funcOp.isDeclaration()) + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); rewriter.eraseOp(funcOp); return success(); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4310237cdc412..5e78fe9fabd2b 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -438,9 +438,6 @@ LogicalResult FuncOp::verify() { return emitOpError("requires zero or exactly one result, but has ") << getNumResults(); - if (isExternal()) - return emitOpError("does not support empty function bodies"); - return success(); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 4474dd811226a..8eaf3cb86c0f6 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -784,6 +784,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + ArrayRef arguments) { + raw_indented_ostream &os = emitter.ostream(); + + return ( + interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult { + return emitter.emitType(functionOp->getLoc(), arg); + })); +} + static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, Region::BlockArgListType arguments) { @@ -925,6 +936,13 @@ static LogicalResult printOperation(CppEmitter &emitter, os << "("; Operation *operation = functionOp.getOperation(); + if (functionOp.isExternal()) { + if (failed(printFunctionArgs(emitter, operation, + functionOp.getArgumentTypes()))) + return failure(); + os << ");"; + return success(); + } if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) return failure(); os << ") {\n"; diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir index a1c8af2587aa0..5c96cf1ce0d34 100644 --- a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -53,3 +53,8 @@ func.func @call(%arg0: i32) -> i32 { %0 = call @return_i32(%arg0) : (i32) -> (i32) return %0 : i32 } + +// ----- + +// CHECK-LABEL: emitc.func private @return_i32(i32) -> i32 attributes {specifiers = ["extern"]} +func.func private @return_i32(%arg0: i32) -> i32 diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index b1e5139fa0b05..44e1897069c72 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -329,8 +329,3 @@ func.func @return_inside_func.func(%0: i32) -> (i32) { // expected-error@+1 {{expected non-function type}} emitc.func @func_variadic(...) - -// ----- - -// expected-error@+1 {{'emitc.func' op does not support empty function bodies}} -emitc.func private @empty() diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 426cd66e12031..080bddeb6a51f 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -30,6 +30,8 @@ emitc.func @call() -> i32 { emitc.return %0 : i32 } +emitc.func private @extern(i32) attributes {specifiers = ["extern"]} + func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 return diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir index d2e14a9e5a7ae..a639cae6f623c 100644 --- a/mlir/test/Target/Cpp/func.mlir +++ b/mlir/test/Target/Cpp/func.mlir @@ -37,3 +37,6 @@ emitc.func @emitc_call() -> i32 { // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; // CPP-DECLTOP-NEXT: [[V0:[^ ]*]] = return_i32(); // CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; + +emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void extern_func(int32_t); From 9c6c868dd3b2e2b34ac80edfdd14f1c674d33ad1 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 6 Feb 2024 08:49:10 +0100 Subject: [PATCH 09/21] [mlir][emitc] Add a `declare_func` operation (#80297) This adds the `emitc.declare_func` operation that allows to emit the declaration of an `emitc.func` at a specific location. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 42 +++++++++++++++++++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 18 ++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 46 ++++++++++++++++++--- mlir/test/Dialect/EmitC/invalid_ops.mlir | 10 +++++ mlir/test/Dialect/EmitC/ops.mlir | 2 + mlir/test/Target/Cpp/declare_func.mlir | 16 +++++++ 6 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Target/Cpp/declare_func.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 7b1d590987058..bad9ef4cdada3 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call", }]; } +def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [ + DeclareOpInterfaceMethods +]> { + let summary = "An operation to declare a function"; + let description = [{ + The `declare_func` operation allows to insert a function declaration for an + `emitc.func` at a specific position. The operation only requires the `callee` + of the `emitc.func` to be specified as an attribute. + + Example: + + ```mlir + emitc.declare_func @bar + emitc.func @foo(%arg0: i32) -> i32 { + %0 = emitc.call @bar(%arg0) : (i32) -> (i32) + emitc.return %0 : i32 + } + + emitc.func @bar(%arg0: i32) -> i32 { + emitc.return %arg0 : i32 + } + ``` + + ```c++ + // Code emitted for the operations above. + int32_t bar(int32_t v1); + int32_t foo(int32_t v1) { + int32_t v2 = bar(v1); + return v2; + } + + int32_t bar(int32_t v1) { + return v1; + } + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$sym_name); + let assemblyFormat = [{ + $sym_name attr-dict + }]; +} + def EmitC_FuncOp : EmitC_Op<"func", [ AutomaticAllocationScope, FunctionOpInterface, IsolatedFromAbove diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5e78fe9fabd2b..803f9047b5c16 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -394,6 +394,24 @@ FunctionType CallOp::getCalleeType() { return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } +//===----------------------------------------------------------------------===// +// DeclareFuncOp +//===----------------------------------------------------------------------===// + +LogicalResult +DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the sym_name attribute was specified. + auto fnAttr = getSymNameAttr(); + if (!fnAttr) + return emitOpError("requires a 'sym_name' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + return success(); +} + //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 8eaf3cb86c0f6..84c3744b38028 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/IndentedOstream.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/Cpp/CppEmitter.h" @@ -870,8 +871,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter, // needs to be printed after the closing brace. // When generating code for an emitc.for and emitc.verbatim op, printing a // trailing semicolon is handled within the printOperation function. - bool trailingSemicolon = !isa(op); + bool trailingSemicolon = + !isa(op); if (failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) @@ -953,6 +955,37 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + DeclareFuncOp declareFuncOp) { + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + + auto functionOp = SymbolTable::lookupNearestSymbolFrom( + declareFuncOp, declareFuncOp.getSymNameAttr()); + + if (!functionOp) + return failure(); + + if (functionOp.getSpecifiers()) { + for (Attribute specifier : functionOp.getSpecifiersAttr()) { + os << cast(specifier).str() << " "; + } + } + + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ");"; + + return success(); +} + CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) : os(os), declareVariablesAtTop(declareVariablesAtTop) { valueInScopeCount.push(0); @@ -1285,10 +1318,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // EmitC ops. .Case( + emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, + emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, + emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, + emitc::SubOp, emitc::SubscriptOp, emitc::VariableOp, + emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 44e1897069c72..2c763a982569e 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -329,3 +329,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) { // expected-error@+1 {{expected non-function type}} emitc.func @func_variadic(...) + +// ----- + +// expected-error@+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}} +emitc.declare_func @bar + +// ----- + +// expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}} +"emitc.declare_func"() : () -> () diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 080bddeb6a51f..5c0ecf998741f 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) { return } +emitc.declare_func @func + emitc.func @func(%arg0 : i32) { emitc.call_opaque "foo"(%arg0) : (i32) -> () emitc.return diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir new file mode 100644 index 0000000000000..72c087a3388e2 --- /dev/null +++ b/mlir/test/Target/Cpp/declare_func.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]); +emitc.declare_func @bar +// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) { +emitc.func @bar(%arg0: i32) -> i32 { + emitc.return %arg0 : i32 +} + + +// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]); +emitc.declare_func @foo +// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) { +emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} { + emitc.return %arg0 : i32 +} From 56efedd4d9070e7aec4c45a0d6edda2e52243bd9 Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Tue, 6 Feb 2024 12:04:13 +0100 Subject: [PATCH 10/21] [mlir][EmitC] Remove unreachable code and fix Windows build warning (#80677) --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 84c3744b38028..d17789a877d20 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -69,12 +69,12 @@ inline LogicalResult interleaveCommaWithError(const Container &c, /// Return the precedence of a operator as an integer, higher values /// imply higher precedence. -static int getOperatorPrecedence(Operation *operation) { - return llvm::TypeSwitch(operation) +static FailureOr getOperatorPrecedence(Operation *operation) { + return llvm::TypeSwitch>(operation) .Case([&](auto op) { return 11; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 13; }) - .Case([&](auto op) { + .Case([&](auto op) -> FailureOr { switch (op.getPredicate()) { case emitc::CmpPredicate::eq: case emitc::CmpPredicate::ne: @@ -87,13 +87,14 @@ static int getOperatorPrecedence(Operation *operation) { case emitc::CmpPredicate::three_way: return 10; } + return op->emitError("unsupported cmp predicate"); }) .Case([&](auto op) { return 12; }) .Case([&](auto op) { return 12; }) .Case([&](auto op) { return 12; }) .Case([&](auto op) { return 11; }) - .Case([&](auto op) { return 14; }); - llvm_unreachable("Unsupported operator"); + .Case([&](auto op) { return 14; }) + .Default([](auto op) { return op->emitError("unsupported operation"); }); } namespace { @@ -1151,7 +1152,10 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { Operation *rootOp = expressionOp.getRootOp(); emittedExpression = expressionOp; - pushExpressionPrecedence(getOperatorPrecedence(rootOp)); + FailureOr precedence = getOperatorPrecedence(rootOp); + if (failed(precedence)) + return failure(); + pushExpressionPrecedence(precedence.value()); if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) return failure(); @@ -1168,13 +1172,15 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (isPartOfCurrentExpression(value)) { Operation *def = value.getDefiningOp(); assert(def && "Expected operand to be defined by an operation"); - int precedence = getOperatorPrecedence(def); - bool encloseInParenthesis = precedence < getExpressionPrecedence(); + FailureOr precedence = getOperatorPrecedence(def); + if (failed(precedence)) + return failure(); + bool encloseInParenthesis = precedence.value() < getExpressionPrecedence(); if (encloseInParenthesis) { os << "("; pushExpressionPrecedence(lowestPrecedence()); } else - pushExpressionPrecedence(precedence); + pushExpressionPrecedence(precedence.value()); if (failed(emitOperation(*def, /*trailingSemicolon=*/false))) return failure(); From f5ea1033b17f5f2e20ddbafcfe916e270d2f1d5b Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Thu, 8 Feb 2024 11:27:08 +0100 Subject: [PATCH 11/21] [mlir][EmitC] Add builders for call_opaque op (#80879) This allows to omit the default valued attributes and therefore write more compact code. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index bad9ef4cdada3..f06bfe44121f6 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -122,6 +122,19 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { Variadic:$operands ); let results = (outs Variadic); + let builders = [ + OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::llvm::StringRef":$callee, + "::mlir::ValueRange":$operands, + CArg<"::mlir::ArrayAttr", "{}">:$args, + CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{ + build($_builder, $_state, resultTypes, callee, args, template_args, + operands); + }] + > + ]; + let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; From a96e62b667e1d1961b00f53160e8cd6f193a72a9 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 20 Feb 2024 14:16:34 +0100 Subject: [PATCH 12/21] [mlir][EmitC] Remove `func.constant` from emitter (#82342) As part of the renaming the Standard dialect to Func dialect, *support* for the `func.constant` operation was added to the emitter. However, the emitter cannot emit function types. Hence the emission for a snippet like ``` %0 = func.constant @myfn : (f32) -> f32 func.func private @myfn(%arg0: f32) -> f32 { return %arg0 : f32 } ``` failes with `func.mlir:1:6: error: cannot emit type '(f32) -> f32'`. This removes `func.constant` from the emitter. --- mlir/docs/Dialects/emitc.md | 1 - mlir/lib/Target/Cpp/TranslateToCpp.cpp | 10 +--------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index 809a04660336b..b227a8c4599a8 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -28,7 +28,6 @@ translating the following operations: * `cf.cond_br` * 'func' Dialect * `func.call` - * `func.constant` * `func.func` * `func.return` * 'arith' Dialect diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index d17789a877d20..9b9ce817c1639 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -338,14 +338,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return printConstantOp(emitter, operation, value); } -static LogicalResult printOperation(CppEmitter &emitter, - func::ConstantOp constantOp) { - Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValueAttr(); - - return printConstantOp(emitter, operation, value); -} - static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); @@ -1331,7 +1323,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. - .Case( + .Case( [&](auto op) { return printOperation(*this, op); }) // Arithmetic ops. .Case( From 65d0767729f47153b6cca0a5e0125b4b36bd5e57 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 24 Feb 2024 09:10:07 +0100 Subject: [PATCH 13/21] [mlir] Use `OpBuilder::createBlock` in op builders and patterns (#82770) When creating a new block in (conversion) rewrite patterns, `OpBuilder::createBlock` must be used. Otherwise, no `notifyBlockInserted` notification is sent to the listener. Note: The dialect conversion relies on listener notifications to keep track of IR modifications. Creating blocks without the builder API can lead to memory leaks during rollback. --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 2 +- .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td | 4 ++-- .../mlir/Interfaces/FunctionInterfaces.td | 4 ++-- .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 2 +- .../ControlFlowToSCF/ControlFlowToSCF.cpp | 4 +--- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 4 ++-- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 4 +--- mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 11 ++++----- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 24 +++++++++++-------- mlir/lib/Dialect/Async/IR/Async.cpp | 17 +++++-------- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 10 ++++---- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 7 +++--- .../Linalg/Transforms/DropUnitDims.cpp | 6 ++--- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 ++--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 9 ++++--- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 3 ++- mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 18 +++++++------- mlir/lib/Dialect/Shape/IR/Shape.cpp | 16 ++++++------- .../Transforms/SparseTensorRewriting.cpp | 4 +--- .../SPIRV/Deserialization/Deserializer.cpp | 4 ++-- 20 files changed, 71 insertions(+), 88 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9e65898154bd6..789450903afe7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1439,7 +1439,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ let extraClassDeclaration = [{ // Add an entry block to an empty function, and set up the block arguments // to match the signature of the function. - Block *addEntryBlock(); + Block *addEntryBlock(OpBuilder &builder); bool isVarArg() { return getFunctionType().isVarArg(); } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 36ad6755cab25..991e753d1b359 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -285,7 +285,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> { // Adds an empty entry block and loop merge block containing one // spirv.mlir.merge op. - void addEntryAndMergeBlock(); + void addEntryAndMergeBlock(OpBuilder &builder); }]; let hasOpcode = 0; @@ -427,7 +427,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> { Block *getMergeBlock(); /// Adds a selection merge block containing one spirv.mlir.merge op. - void addMergeBlock(); + void addMergeBlock(OpBuilder &builder); /// Creates a spirv.mlir.selection op for `if () then { }` /// with `builder`. `builder`'s insertion point will remain at after the diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td index 98e002565cf19..be71063272d80 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td @@ -131,6 +131,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ static void buildWithEntryBlock( OpBuilder &builder, OperationState &state, StringRef name, Type type, ArrayRef attrs, TypeRange inputTypes) { + OpBuilder::InsertionGuard g(builder); state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), @@ -139,8 +140,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ // Add the function body. Region *bodyRegion = state.addRegion(); - Block *body = new Block(); - bodyRegion->push_back(body); + Block *body = builder.createBlock(bodyRegion); for (Type input : inputTypes) body->addArgument(input, state.location); } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 0ab53ce7e3327..7760373913761 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -259,7 +259,7 @@ static void addResumeFunction(ModuleOp module) { kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); - auto *block = resumeOp.addEntryBlock(); + auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); blockBuilder.create(resumeOp.getArgument(0)); diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index 363e5f9b8cefe..d3ee89743da9d 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -98,12 +98,10 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( loc, builder.create(loc, builder.getI1Type(), condition), loopVariablesNextIter); - auto *afterBlock = new Block; - whileOp.getAfter().push_back(afterBlock); + Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector(loopVariablesInit.size(), loc)); - builder.setInsertionPointToEnd(afterBlock); builder.create(loc, afterBlock->getArguments()); return whileOp.getOperation(); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index bd50c67fb8795..53b44aa3241bb 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -135,7 +135,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); + rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter)); SmallVector args; size_t argOffset = resultStructType ? 1 : 0; @@ -203,7 +203,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, // The wrapper that we synthetize here should only be visible in this module. newFuncOp.setLinkage(LLVM::Linkage::Private); - builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); + builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder)); // Get a ValueRange containing arguments. FunctionType type = cast(funcOp.getFunctionType()); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 2bfca303b5fd4..2dc42f0a85e66 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -520,9 +520,7 @@ struct GlobalMemrefOpLowering global, arrayTy, global.getConstant(), linkage, global.getSymName(), initialValue, alignment, *addressSpace); if (!global.isExternal() && global.isUninitialized()) { - Block *blk = new Block(); - newGlobal.getInitializerRegion().push_back(blk); - rewriter.setInsertionPointToStart(blk); + rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { rewriter.create(global.getLoc(), arrayTy)}; rewriter.create(global.getLoc(), undef); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index febfe97f6c0a9..d90cf931385fc 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -138,14 +138,13 @@ struct ForOpConversion final : SCFToSPIRVPattern { // from header to merge. auto loc = forOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); // Create the block for the header. - auto *header = new Block(); - // Insert the header. - loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), - header); + Block *header = rewriter.createBlock(&loopOp.getBody(), + getBlockIt(loopOp.getBody(), 1)); + rewriter.setInsertionPointAfter(loopOp); // Create the new induction variable to use. Value adapLowerBound = adaptor.getLowerBound(); @@ -342,7 +341,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern { ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); Region &afterRegion = whileOp.getAfter(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index d5be2e906989f..0ccb5a9f658da 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1809,6 +1809,8 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); + OpBuilder::InsertionGuard guard(builder); + // Set variadic segment sizes. result.addAttribute( getOperandSegmentSizeAttr(), @@ -1837,12 +1839,11 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, // Create a region and a block for the body. The argument of the region is // the loop induction variable. Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + Block *bodyBlock = builder.createBlock(bodyRegion); Value inductionVar = - bodyBlock.addArgument(builder.getIndexType(), result.location); + bodyBlock->addArgument(builder.getIndexType(), result.location); for (Value val : iterArgs) - bodyBlock.addArgument(val.getType(), val.getLoc()); + bodyBlock->addArgument(val.getType(), val.getLoc()); // Create the default terminator if the builder is not provided and if the // iteration arguments are not provided. Otherwise, leave this to the caller @@ -1851,9 +1852,9 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, ensureTerminator(*bodyRegion, builder, result.location); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); + builder.setInsertionPointToStart(bodyBlock); bodyBuilder(builder, result.location, inductionVar, - bodyBlock.getArguments().drop_front()); + bodyBlock->getArguments().drop_front()); } } @@ -2890,18 +2891,20 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, IntegerSet set, ValueRange args, bool withElseRegion) { assert(resultTypes.empty() || withElseRegion); + OpBuilder::InsertionGuard guard(builder); + result.addTypes(resultTypes); result.addOperands(args); result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set)); Region *thenRegion = result.addRegion(); - thenRegion->push_back(new Block()); + builder.createBlock(thenRegion); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); Region *elseRegion = result.addRegion(); if (withElseRegion) { - elseRegion->push_back(new Block()); + builder.createBlock(elseRegion); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); } @@ -3688,6 +3691,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, "expected upper bound maps to have as many inputs as upper bound " "operands"); + OpBuilder::InsertionGuard guard(builder); result.addTypes(resultTypes); // Convert the reductions to integer attributes. @@ -3733,11 +3737,11 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, // Create a region and a block for the body. auto *bodyRegion = result.addRegion(); - auto *body = new Block(); + Block *body = builder.createBlock(bodyRegion); + // Add all the block arguments. for (unsigned i = 0, e = steps.size(); i < e; ++i) body->addArgument(IndexType::get(builder.getContext()), result.location); - bodyRegion->push_back(body); if (resultTypes.empty()) ensureTerminator(*bodyRegion, builder, result.location); } diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 5f583f36cd2cb..a3e3f80954efc 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -68,7 +68,7 @@ void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, void ExecuteOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange dependencies, ValueRange operands, BodyBuilderFn bodyBuilder) { - + OpBuilder::InsertionGuard guard(builder); result.addOperands(dependencies); result.addOperands(operands); @@ -87,26 +87,21 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result, // Add a body region with block arguments as unwrapped async value operands. Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + Block *bodyBlock = builder.createBlock(bodyRegion); for (Value operand : operands) { auto valueType = llvm::dyn_cast(operand.getType()); - bodyBlock.addArgument(valueType ? valueType.getValueType() - : operand.getType(), - operand.getLoc()); + bodyBlock->addArgument(valueType ? valueType.getValueType() + : operand.getType(), + operand.getLoc()); } // Create the default terminator if the builder is not provided and if the // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. if (resultTypes.empty() && !bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); builder.create(result.location, ValueRange()); } else if (bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, result.location, bodyBlock.getArguments()); + bodyBuilder(builder, result.location, bodyBlock->getArguments()); } } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 803f9047b5c16..5772089e5eedf 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -262,20 +262,20 @@ LogicalResult ExpressionOp::verify() { void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, BodyBuilderFn bodyBuilder) { + OpBuilder::InsertionGuard g(builder); result.addOperands({lb, ub, step}); Type t = lb.getType(); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(t, result.location); + Block *bodyBlock = builder.createBlock(bodyRegion); + bodyBlock->addArgument(t, result.location); // Create the default terminator if the builder is not provided. if (!bodyBuilder) { ForOp::ensureTerminator(*bodyRegion, builder, result.location); } else { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, result.location, bodyBlock.getArgument(0)); + builder.setInsertionPointToStart(bodyBlock); + bodyBuilder(builder, result.location, bodyBlock->getArgument(0)); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 458bf83eac17f..95b0d6ef1ae2a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2166,11 +2166,10 @@ LogicalResult ShuffleVectorOp::verify() { //===----------------------------------------------------------------------===// // Add the entry block to the function. -Block *LLVMFuncOp::addEntryBlock() { +Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) { assert(empty() && "function already has an entry block"); - - auto *entry = new Block; - push_back(entry); + OpBuilder::InsertionGuard g(builder); + Block *entry = builder.createBlock(&getBody()); // FIXME: Allow passing in proper locations for the entry arguments. LLVMFunctionType type = getFunctionType(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 6fbf351455787..370dee4448eb4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -132,12 +132,10 @@ struct MoveInitOperandsToInput : public OpRewritePattern { newIndexingMaps, genericOp.getIteratorTypesArray(), /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + OpBuilder::InsertionGuard guard(rewriter); Region ®ion = newOp.getRegion(); - Block *block = new Block(); - region.push_back(block); + Block *block = rewriter.createBlock(®ion); IRMapping mapper; - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(block); for (auto bbarg : genericOp.getRegionInputArgs()) mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3eb91190751ef..55ff33792de61 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -178,11 +178,9 @@ static void generateFusedElementwiseOpRegion( // Build the region of the fused op. Block &producerBlock = producer->getRegion(0).front(); Block &consumerBlock = consumer->getRegion(0).front(); - Block *fusedBlock = new Block(); - fusedOp.getRegion().push_back(fusedBlock); - IRMapping mapper; OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(fusedBlock); + Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); + IRMapping mapper; // 2. Add an index operation for every fused loop dimension and use the // `consumerToProducerLoopsMap` to map the producer indices. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 75c8cd3e1d95a..6bdbc89608921 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -275,14 +275,13 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, auto transposeOp = b.create(loc, resultTensorType, inputTensor, outputTensor, indexingMaps, iteratorTypes); - Region &body = transposeOp.getRegion(); - body.push_back(new Block()); - body.front().addArguments({elementType, elementType}, {loc, loc}); // Create the body of the transpose operation. OpBuilder::InsertionGuard g(b); - b.setInsertionPointToEnd(&body.front()); - b.create(loc, transposeOp.getRegion().front().getArgument(0)); + Region &body = transposeOp.getRegion(); + Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{}, + {elementType, elementType}, {loc, loc}); + b.create(loc, bodyBlock->getArgument(0)); return transposeOp; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 93327a28234ea..03017afe95dbd 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1420,6 +1420,7 @@ OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange ivs) { + OpBuilder::InsertionGuard g(builder); result.addOperands(memref); result.addOperands(ivs); @@ -1428,7 +1429,7 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, result.addTypes(elementType); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block()); + builder.createBlock(bodyRegion); bodyRegion->addArgument(elementType, memref.getLoc()); } } diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 580782043c81b..7170a899069ee 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -365,12 +365,11 @@ Block *LoopOp::getMergeBlock() { return &getBody().back(); } -void LoopOp::addEntryAndMergeBlock() { +void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { assert(getBody().empty() && "entry and merge block already exist"); - getBody().push_back(new Block()); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); + OpBuilder::InsertionGuard g(builder); + builder.createBlock(&getBody()); + builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. builder.create(getLoc()); @@ -525,11 +524,10 @@ Block *SelectionOp::getMergeBlock() { return &getBody().back(); } -void SelectionOp::addMergeBlock() { +void SelectionOp::addMergeBlock(OpBuilder &builder) { assert(getBody().empty() && "entry and merge block already exist"); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); + OpBuilder::InsertionGuard guard(builder); + builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. builder.create(getLoc()); @@ -542,7 +540,7 @@ SelectionOp::createIfThen(Location loc, Value condition, auto selectionOp = builder.create(loc, spirv::SelectionControl::None); - selectionOp.addMergeBlock(); + selectionOp.addMergeBlock(builder); Block *mergeBlock = selectionOp.getMergeBlock(); Block *thenBlock = nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 4f829db1305c8..d9ee39a4e8dd3 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -375,15 +375,13 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op, void AssumingOp::build( OpBuilder &builder, OperationState &result, Value witness, function_ref(OpBuilder &, Location)> bodyBuilder) { + OpBuilder::InsertionGuard g(builder); result.addOperands(witness); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + builder.createBlock(bodyRegion); // Build body. - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); SmallVector yieldValues = bodyBuilder(builder, result.location); builder.create(result.location, yieldValues); @@ -1904,23 +1902,23 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, ValueRange initVals) { + OpBuilder::InsertionGuard g(builder); result.addOperands(shape); result.addOperands(initVals); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(builder.getIndexType(), result.location); + Block *bodyBlock = builder.createBlock( + bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location); Type elementType; if (auto tensorType = llvm::dyn_cast(shape.getType())) elementType = tensorType.getElementType(); else elementType = SizeType::get(builder.getContext()); - bodyBlock.addArgument(elementType, shape.getLoc()); + bodyBlock->addArgument(elementType, shape.getLoc()); for (Value initVal : initVals) { - bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); + bodyBlock->addArgument(initVal.getType(), initVal.getLoc()); result.addTypes(initVal.getType()); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 3b9685b8ae1e0..9a483078a4a44 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -299,8 +299,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { Block &prodBlock = prod.getRegion().front(); Block &consBlock = op.getRegion().front(); IRMapping mapper; - Block *fusedBlock = new Block(); - fusedOp.getRegion().push_back(fusedBlock); + Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); unsigned num = prodBlock.getNumArguments(); for (unsigned i = 0; i < num - 1; i++) addArg(mapper, fusedBlock, prodBlock.getArgument(i)); @@ -309,7 +308,6 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { // Clone bodies of the producer and consumer in new evaluation order. auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp(); auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp(); - rewriter.setInsertionPointToStart(fusedBlock); Value last; for (auto &op : prodBlock.without_terminator()) if (&op != acc) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 89e2e7ad52fa7..b9455ea41e64b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1799,7 +1799,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { auto control = static_cast(selectionControl); auto selectionOp = builder.create(location, control); - selectionOp.addMergeBlock(); + selectionOp.addMergeBlock(builder); return selectionOp; } @@ -1811,7 +1811,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { auto control = static_cast(loopControl); auto loopOp = builder.create(location, control); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(builder); return loopOp; } From f760caa46a3209a7e9db653a3b7981c99d192f63 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 28 Feb 2024 20:41:05 +0100 Subject: [PATCH 14/21] [mlir][EmitC] Add logical operators (#83123) This adds operations for the logical operators AND, NOT and OR. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 64 +++++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 30 +++++++++- mlir/test/Dialect/EmitC/invalid_ops.mlir | 24 ++++++++ mlir/test/Dialect/EmitC/ops.mlir | 7 +++ mlir/test/Target/Cpp/logical_operators.mlir | 14 +++++ 5 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Target/Cpp/logical_operators.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index f06bfe44121f6..8b2c775350ffc 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -658,6 +658,70 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { let assemblyFormat = "$value attr-dict `:` type($result)"; } +def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { + let summary = "Logical and operation"; + let description = [{ + With the `logical_and` operation the logical operator && (and) can + be applied. + + Example: + + ```mlir + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v3 = v1 && v2; + ``` + }]; + + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def EmitC_LogicalNotOp : EmitC_Op<"logical_not", []> { + let summary = "Logical not operation"; + let description = [{ + With the `logical_not` operation the logical operator ! (negation) can + be applied. + + Example: + + ```mlir + %0 = emitc.logical_not %arg0 : i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v2 = !v1; + ``` + }]; + + let arguments = (ins AnyType); + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> { + let summary = "Logical or operation"; + let description = [{ + With the `logical_or` operation the logical operator || (inclusive or) + can be applied. + + Example: + + ```mlir + %0 = emitc.logical_or %arg0, %arg1 : i32, i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v3 = v1 || v2; + ``` + }]; + + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { let summary = "Multiplication operation"; let description = [{ diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 9b9ce817c1639..ccd5d105b0985 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -640,6 +640,33 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalAndOp logicalAndOp) { + Operation *operation = logicalAndOp.getOperation(); + return printBinaryOperation(emitter, operation, "&&"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalNotOp logicalNotOp) { + raw_ostream &os = emitter.ostream(); + + if (failed(emitter.emitAssignPrefix(*logicalNotOp.getOperation()))) + return failure(); + + os << "!"; + + if (failed(emitter.emitOperand(logicalNotOp.getOperand()))) + return failure(); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalOrOp logicalOrOp) { + Operation *operation = logicalOrOp.getOperation(); + return printBinaryOperation(emitter, operation, "||"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { raw_indented_ostream &os = emitter.ostream(); @@ -1318,7 +1345,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, - emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, + emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, + emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 2c763a982569e..51b68eecbbd56 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -339,3 +339,27 @@ emitc.declare_func @bar // expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}} "emitc.declare_func"() : () -> () + +// ----- + +func.func @logical_and_resulterror(%arg0: i32, %arg1: i32) { + // expected-error @+1 {{'emitc.logical_and' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_and"(%arg0, %arg1) : (i32, i32) -> i32 + return +} + +// ----- + +func.func @logical_not_resulterror(%arg0: i32) { + // expected-error @+1 {{'emitc.logical_not' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_not"(%arg0) : (i32) -> i32 + return +} + +// ----- + +func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) { + // expected-error @+1 {{'emitc.logical_or' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32 + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 5c0ecf998741f..8c7cc994c5b02 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -117,6 +117,13 @@ func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emit return } +func.func @logical(%arg0: i32, %arg1: i32) { + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + %1 = emitc.logical_not %arg0 : i32 + %2 = emitc.logical_or %arg0, %arg1 : i32, i32 + return +} + func.func @test_if(%arg0: i1, %arg1: f32) { emitc.if %arg0 { %0 = emitc.call_opaque "func_const"(%arg1) : (f32) -> i32 diff --git a/mlir/test/Target/Cpp/logical_operators.mlir b/mlir/test/Target/Cpp/logical_operators.mlir new file mode 100644 index 0000000000000..7083dc218fca9 --- /dev/null +++ b/mlir/test/Target/Cpp/logical_operators.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @logical(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + %1 = emitc.logical_not %arg0 : i32 + %2 = emitc.logical_or %arg0, %arg1 : i32, i32 + + return +} + +// CHECK-LABEL: void logical +// CHECK-NEXT: bool [[V2:[^ ]*]] = [[V0:[^ ]*]] && [[V1:[^ ]*]]; +// CHECK-NEXT: bool [[V3:[^ ]*]] = ![[V0]]; +// CHECK-NEXT: bool [[V4:[^ ]*]] = [[V0]] || [[V1]]; From 6aa8cda28a7df6e20f9c1c486760ce715b4f9767 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 1 Mar 2024 13:21:11 +0100 Subject: [PATCH 15/21] [mlir][EmitC] Add bitwise operators (#83387) This adds operations for bitwise operators. Furthermore, an UnaryOp class and a helper to print unary operations are introduced. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 119 +++++++++++++++++++- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 72 ++++++++++-- mlir/test/Dialect/EmitC/ops.mlir | 10 ++ mlir/test/Target/Cpp/bitwise_operators.mlir | 20 ++++ 4 files changed, 207 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Target/Cpp/bitwise_operators.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 8b2c775350ffc..3e4dcf6b28251 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -31,6 +31,14 @@ include "mlir/IR/RegionKindInterface.td" class EmitC_Op traits = []> : Op; +// Base class for unary operations. +class EmitC_UnaryOp traits = []> : + EmitC_Op { + let arguments = (ins AnyType); + let results = (outs AnyType); + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + // Base class for binary operations. class EmitC_BinaryOp traits = []> : EmitC_Op { @@ -95,6 +103,114 @@ def EmitC_ApplyOp : EmitC_Op<"apply", []> { let hasVerifier = 1; } +def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> { + let summary = "Bitwise and operation"; + let description = [{ + With the `bitwise_and` operation the bitwise operator & (and) can + be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 & v2; + ``` + }]; +} + +def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> { + let summary = "Bitwise left shift operation"; + let description = [{ + With the `bitwise_left_shift` operation the bitwise operator << + (left shift) can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 << v2; + ``` + }]; +} + +def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> { + let summary = "Bitwise not operation"; + let description = [{ + With the `bitwise_not` operation the bitwise operator ~ (not) can + be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_not %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = ~v1; + ``` + }]; +} + +def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> { + let summary = "Bitwise or operation"; + let description = [{ + With the `bitwise_or` operation the bitwise operator | (or) + can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 | v2; + ``` + }]; +} + +def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> { + let summary = "Bitwise right shift operation"; + let description = [{ + With the `bitwise_right_shift` operation the bitwise operator >> + (right shift) can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 >> v2; + ``` + }]; +} + +def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> { + let summary = "Bitwise xor operation"; + let description = [{ + With the `bitwise_xor` operation the bitwise operator ^ (xor) + can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 ^ v2; + ``` + }]; +} + def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { let summary = "Opaque call operation"; let description = [{ @@ -679,7 +795,7 @@ def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { let assemblyFormat = "operands attr-dict `:` type(operands)"; } -def EmitC_LogicalNotOp : EmitC_Op<"logical_not", []> { +def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> { let summary = "Logical not operation"; let description = [{ With the `logical_not` operation the logical operator ! (negation) can @@ -696,7 +812,6 @@ def EmitC_LogicalNotOp : EmitC_Op<"logical_not", []> { ``` }]; - let arguments = (ins AnyType); let results = (outs I1); let assemblyFormat = "operands attr-dict `:` type(operands)"; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index ccd5d105b0985..97a7f556821ed 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -374,6 +374,22 @@ static LogicalResult printBinaryOperation(CppEmitter &emitter, return success(); } +static LogicalResult printUnaryOperation(CppEmitter &emitter, + Operation *operation, + StringRef unaryOperator) { + raw_ostream &os = emitter.ostream(); + + if (failed(emitter.emitAssignPrefix(*operation))) + return failure(); + + os << unaryOperator; + + if (failed(emitter.emitOperand(operation->getOperand(0)))) + return failure(); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) { Operation *operation = addOp.getOperation(); @@ -601,6 +617,44 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseAndOp bitwiseAndOp) { + Operation *operation = bitwiseAndOp.getOperation(); + return printBinaryOperation(emitter, operation, "&"); +} + +static LogicalResult +printOperation(CppEmitter &emitter, + emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) { + Operation *operation = bitwiseLeftShiftOp.getOperation(); + return printBinaryOperation(emitter, operation, "<<"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseNotOp bitwiseNotOp) { + Operation *operation = bitwiseNotOp.getOperation(); + return printUnaryOperation(emitter, operation, "~"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseOrOp bitwiseOrOp) { + Operation *operation = bitwiseOrOp.getOperation(); + return printBinaryOperation(emitter, operation, "|"); +} + +static LogicalResult +printOperation(CppEmitter &emitter, + emitc::BitwiseRightShiftOp bitwiseRightShiftOp) { + Operation *operation = bitwiseRightShiftOp.getOperation(); + return printBinaryOperation(emitter, operation, ">>"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseXorOp bitwiseXorOp) { + Operation *operation = bitwiseXorOp.getOperation(); + return printBinaryOperation(emitter, operation, "^"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { raw_ostream &os = emitter.ostream(); Operation &op = *castOp.getOperation(); @@ -648,17 +702,8 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, emitc::LogicalNotOp logicalNotOp) { - raw_ostream &os = emitter.ostream(); - - if (failed(emitter.emitAssignPrefix(*logicalNotOp.getOperation()))) - return failure(); - - os << "!"; - - if (failed(emitter.emitOperand(logicalNotOp.getOperand()))) - return failure(); - - return success(); + Operation *operation = logicalNotOp.getOperation(); + return printUnaryOperation(emitter, operation, "!"); } static LogicalResult printOperation(CppEmitter &emitter, @@ -1341,7 +1386,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case, %arg1: i32, %arg2: !emitc.opaque< return } +func.func @bitwise(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + %2 = emitc.bitwise_not %arg0 : (i32) -> i32 + %3 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + %4 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + %5 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + return +} + func.func @div_int(%arg0: i32, %arg1: i32) { %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32 return diff --git a/mlir/test/Target/Cpp/bitwise_operators.mlir b/mlir/test/Target/Cpp/bitwise_operators.mlir new file mode 100644 index 0000000000000..e666359fc82c9 --- /dev/null +++ b/mlir/test/Target/Cpp/bitwise_operators.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @bitwise(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + %2 = emitc.bitwise_not %arg0 : (i32) -> i32 + %3 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + %4 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + %5 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + + return +} + +// CHECK-LABEL: void bitwise +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] & [[V1:[^ ]*]]; +// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0]] << [[V1]]; +// CHECK-NEXT: int32_t [[V4:[^ ]*]] = ~[[V0]]; +// CHECK-NEXT: int32_t [[V5:[^ ]*]] = [[V0]] | [[V1]]; +// CHECK-NEXT: int32_t [[V6:[^ ]*]] = [[V0]] >> [[V1]]; +// CHECK-NEXT: int32_t [[V7:[^ ]*]] = [[V0]] ^ [[V1]]; From 138bf9cc000b22e547090e605f16e92931a1c92d Mon Sep 17 00:00:00 2001 From: Kirill Chibisov Date: Wed, 6 Mar 2024 15:41:20 +0400 Subject: [PATCH 16/21] [mlir][emitc] Fix `emitc.expression` example (#84060) Make it use and refer to `emitc.yield` and also fix type issues. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 3e4dcf6b28251..aa75f64d90fd2 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -389,17 +389,17 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", As the operation is to be emitted as a C expression, the operations within its body must form a single Def-Use tree of emitc ops whose result is - yielded by a terminating `yield`. + yielded by a terminating `emitc.yield`. Example: ```mlir - %r = emitc.expression : () -> i32 { + %r = emitc.expression : i32 { %0 = emitc.add %a, %b : (i32, i32) -> i32 - %1 = emitc.call "foo"(%0) : () -> i32 + %1 = emitc.call_opaque "foo"(%0) : (i32) -> i32 %2 = emitc.add %c, %d : (i32, i32) -> i32 %3 = emitc.mul %1, %2 : (i32, i32) -> i32 - yield %3 + emitc.yield %3 : i32 } ``` @@ -409,9 +409,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", int32_t v7 = foo(v1 + v2) * (v3 + v4); ``` - The operations allowed within expression body are emitc.add, emitc.apply, - emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and - emitc.sub. + The operations allowed within expression body are `emitc.add`, + `emitc.apply`, `emitc.call_opaque`, `emitc.cast`, `emitc.cmp`, `emitc.div`, + `emitc.mul`, `emitc.rem`, and `emitc.sub`. When specified, the optional `do_not_inline` indicates that the expression is to be emitted as seen above, i.e. as the rhs of an EmitC SSA value From 5f152558e00f57acbb8a856c9043c18d99b478a1 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 7 Mar 2024 08:37:47 +0100 Subject: [PATCH 17/21] [mlir][EmitC] Introduce a `CExpression` trait (#84177) This adds a `CExpression` trait and replaces the `isCExpression()` function. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 1 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 39 +++++++++---------- .../mlir/Dialect/EmitC/IR/EmitCTraits.h | 30 ++++++++++++++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 3 +- .../EmitC/Transforms/FormExpressions.cpp | 2 +- .../Dialect/EmitC/Transforms/Transforms.cpp | 3 +- 6 files changed, 54 insertions(+), 24 deletions(-) create mode 100644 mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 3d38744527d59..1f0df3cb336b1 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_EMITC_IR_EMITC_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index aa75f64d90fd2..3130e8f80d402 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -47,11 +47,14 @@ class EmitC_BinaryOp traits = []> : let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } +// EmitC OpTrait +def CExpression : NativeOpTrait<"emitc::CExpression">; + // Types only used in binary arithmetic operations. def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>; def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>; -def EmitC_AddOp : EmitC_BinaryOp<"add", []> { +def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> { let summary = "Addition operation"; let description = [{ With the `add` operation the arithmetic operator + (addition) can @@ -74,7 +77,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> { let hasVerifier = 1; } -def EmitC_ApplyOp : EmitC_Op<"apply", []> { +def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { let summary = "Apply operation"; let description = [{ With the `apply` operation the operators & (address of) and * (contents of) @@ -211,7 +214,7 @@ def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> { }]; } -def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { +def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { let summary = "Opaque call operation"; let description = [{ The `call_opaque` operation represents a C++ function call. The callee @@ -257,10 +260,10 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { let hasVerifier = 1; } -def EmitC_CastOp : EmitC_Op<"cast", [ - DeclareOpInterfaceMethods, - SameOperandsAndResultShape - ]> { +def EmitC_CastOp : EmitC_Op<"cast", + [CExpression, + DeclareOpInterfaceMethods, + SameOperandsAndResultShape]> { let summary = "Cast operation"; let description = [{ The `cast` operation performs an explicit type conversion and is emitted @@ -284,7 +287,7 @@ def EmitC_CastOp : EmitC_Op<"cast", [ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } -def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> { +def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { let summary = "Comparison operation"; let description = [{ With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> @@ -355,7 +358,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { let hasVerifier = 1; } -def EmitC_DivOp : EmitC_BinaryOp<"div", []> { +def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> { let summary = "Division operation"; let description = [{ With the `div` operation the arithmetic operator / (division) can @@ -409,9 +412,8 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", int32_t v7 = foo(v1 + v2) * (v3 + v4); ``` - The operations allowed within expression body are `emitc.add`, - `emitc.apply`, `emitc.call_opaque`, `emitc.cast`, `emitc.cmp`, `emitc.div`, - `emitc.mul`, `emitc.rem`, and `emitc.sub`. + The operations allowed within expression body are EmitC operations with the + CExpression trait. When specified, the optional `do_not_inline` indicates that the expression is to be emitted as seen above, i.e. as the rhs of an EmitC SSA value @@ -427,14 +429,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region"; let extraClassDeclaration = [{ - static bool isCExpression(Operation &op) { - return isa(op); - } bool hasSideEffects() { auto predicate = [](Operation &op) { - assert(isCExpression(op) && "Expected a C expression"); + assert(op.hasTrait() && "Expected a C expression"); // Conservatively assume calls to read and write memory. if (isa(op)) return true; @@ -837,7 +834,7 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> { let assemblyFormat = "operands attr-dict `:` type(operands)"; } -def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { +def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> { let summary = "Multiplication operation"; let description = [{ With the `mul` operation the arithmetic operator * (multiplication) can @@ -861,7 +858,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { let results = (outs FloatIntegerIndexOrOpaqueType); } -def EmitC_RemOp : EmitC_BinaryOp<"rem", []> { +def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> { let summary = "Remainder operation"; let description = [{ With the `rem` operation the arithmetic operator % (remainder) can @@ -883,7 +880,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", []> { let results = (outs IntegerIndexOrOpaqueType); } -def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { +def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> { let summary = "Subtraction operation"; let description = [{ With the `sub` operation the arithmetic operator - (subtraction) can diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h new file mode 100644 index 0000000000000..c1602dfce4b48 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h @@ -0,0 +1,30 @@ +//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares C++ classes for some of the traits used in the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H +#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { +namespace emitc { + +template +class CExpression : public TraitBase {}; + +} // namespace emitc +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5772089e5eedf..5db0777bc30ab 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" @@ -245,7 +246,7 @@ LogicalResult ExpressionOp::verify() { return emitOpError("requires yielded type to match return type"); for (Operation &op : region.front().without_terminator()) { - if (!isCExpression(op)) + if (!op.hasTrait()) return emitOpError("contains an unsupported operation"); if (op.getNumResults() != 1) return emitOpError("requires exactly one result for each operation"); diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp index 21212155ffb22..5b03f81b305fd 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -36,7 +36,7 @@ struct FormExpressionsPass // Wrap each C operator op with an expression op. OpBuilder builder(context); auto matchFun = [&](Operation *op) { - if (emitc::ExpressionOp::isCExpression(*op)) + if (op->hasTrait()) createExpression(op, builder); }; rootOp->walk(matchFun); diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index 88b691b50f325..87350ecdceaaa 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -16,7 +16,8 @@ namespace mlir { namespace emitc { ExpressionOp createExpression(Operation *op, OpBuilder &builder) { - assert(ExpressionOp::isCExpression(*op) && "Expected a C expression"); + assert(op->hasTrait() && + "Expected a C expression"); // Create an expression yielding the value returned by op. assert(op->getNumResults() == 1 && "Expected exactly one result"); From 4b41a63c6557caf2c8aeaa9fa111231e2b09a9d9 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 7 Mar 2024 15:48:11 +0100 Subject: [PATCH 18/21] [mlir][EmitC] Allow further ops within expressions (#84284) This adds the `CExpression` trait to additional ops to allow to use these ops within the expression operation. Furthermore, the operator precedence is defined for those ops. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 22 +++++++++-------- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 26 ++++++++++++++------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 3130e8f80d402..74eaa662780e8 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -106,7 +106,7 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { let hasVerifier = 1; } -def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> { +def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> { let summary = "Bitwise and operation"; let description = [{ With the `bitwise_and` operation the bitwise operator & (and) can @@ -124,7 +124,8 @@ def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> { }]; } -def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> { +def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", + [CExpression]> { let summary = "Bitwise left shift operation"; let description = [{ With the `bitwise_left_shift` operation the bitwise operator << @@ -142,7 +143,7 @@ def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> { }]; } -def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> { +def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> { let summary = "Bitwise not operation"; let description = [{ With the `bitwise_not` operation the bitwise operator ~ (not) can @@ -160,7 +161,7 @@ def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> { }]; } -def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> { +def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> { let summary = "Bitwise or operation"; let description = [{ With the `bitwise_or` operation the bitwise operator | (or) @@ -178,7 +179,8 @@ def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> { }]; } -def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> { +def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", + [CExpression]> { let summary = "Bitwise right shift operation"; let description = [{ With the `bitwise_right_shift` operation the bitwise operator >> @@ -196,7 +198,7 @@ def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> { }]; } -def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> { +def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> { let summary = "Bitwise xor operation"; let description = [{ With the `bitwise_xor` operation the bitwise operator ^ (xor) @@ -515,7 +517,7 @@ def EmitC_ForOp : EmitC_Op<"for", } def EmitC_CallOp : EmitC_Op<"call", - [CallOpInterface, + [CallOpInterface, CExpression, DeclareOpInterfaceMethods]> { let summary = "call operation"; let description = [{ @@ -771,7 +773,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { let assemblyFormat = "$value attr-dict `:` type($result)"; } -def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { +def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> { let summary = "Logical and operation"; let description = [{ With the `logical_and` operation the logical operator && (and) can @@ -792,7 +794,7 @@ def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { let assemblyFormat = "operands attr-dict `:` type(operands)"; } -def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> { +def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> { let summary = "Logical not operation"; let description = [{ With the `logical_not` operation the logical operator ! (negation) can @@ -813,7 +815,7 @@ def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> { let assemblyFormat = "operands attr-dict `:` type(operands)"; } -def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> { +def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> { let summary = "Logical or operation"; let description = [{ With the `logical_or` operation the logical operator || (inclusive or) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 97a7f556821ed..9d8e02586893a 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -71,9 +71,17 @@ inline LogicalResult interleaveCommaWithError(const Container &c, /// imply higher precedence. static FailureOr getOperatorPrecedence(Operation *operation) { return llvm::TypeSwitch>(operation) - .Case([&](auto op) { return 11; }) - .Case([&](auto op) { return 13; }) - .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 7; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 5; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 6; }) + .Case([&](auto op) { return 16; }) + .Case([&](auto op) { return 16; }) + .Case([&](auto op) { return 15; }) .Case([&](auto op) -> FailureOr { switch (op.getPredicate()) { case emitc::CmpPredicate::eq: @@ -89,11 +97,13 @@ static FailureOr getOperatorPrecedence(Operation *operation) { } return op->emitError("unsupported cmp predicate"); }) - .Case([&](auto op) { return 12; }) - .Case([&](auto op) { return 12; }) - .Case([&](auto op) { return 12; }) - .Case([&](auto op) { return 11; }) - .Case([&](auto op) { return 14; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 4; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 3; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 12; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); } From add43c2b9ab6104bd7ad8e76090b374fdd79afad Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 8 Mar 2024 08:34:56 +0100 Subject: [PATCH 19/21] [mlir][EmitC] Add `unary_{minus,plus}` operators (#84329) This adds operations for the unary minus and the unary plus operator. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 36 +++++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 18 +++++++++-- mlir/test/Dialect/EmitC/ops.mlir | 6 ++++ mlir/test/Target/Cpp/unary_operators.mlir | 12 +++++++ 4 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Target/Cpp/unary_operators.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 74eaa662780e8..bcdd001528c46 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -908,6 +908,42 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> { let hasVerifier = 1; } +def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> { + let summary = "Unary minus operation"; + let description = [{ + With the `unary_minus` operation the unary operator - (minus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_plus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = -v1; + ``` + }]; +} + +def EmitC_UnaryPlusOp : EmitC_UnaryOp<"unary_plus", [CExpression]> { + let summary = "Unary plus operation"; + let description = [{ + With the `unary_plus` operation the unary operator + (plus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_plus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = +v1; + ``` + }]; +} + def EmitC_VariableOp : EmitC_Op<"variable", []> { let summary = "Variable operation"; let description = [{ diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 9d8e02586893a..46bf563305152 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -104,6 +104,8 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 15; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); } @@ -665,6 +667,18 @@ static LogicalResult printOperation(CppEmitter &emitter, return printBinaryOperation(emitter, operation, "^"); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryPlusOp unaryPlusOp) { + Operation *operation = unaryPlusOp.getOperation(); + return printUnaryOperation(emitter, operation, "+"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryMinusOp unaryMinusOp) { + Operation *operation = unaryMinusOp.getOperation(); + return printUnaryOperation(emitter, operation, "-"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { raw_ostream &os = emitter.ostream(); Operation &op = *castOp.getOperation(); @@ -1405,8 +1419,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, - emitc::SubOp, emitc::SubscriptOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp, + emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 1b9a104ffc07e..02294d13cef76 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -134,6 +134,12 @@ func.func @logical(%arg0: i32, %arg1: i32) { return } +func.func @unary(%arg0: i32) { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + return +} + func.func @test_if(%arg0: i1, %arg1: f32) { emitc.if %arg0 { %0 = emitc.call_opaque "func_const"(%arg1) : (f32) -> i32 diff --git a/mlir/test/Target/Cpp/unary_operators.mlir b/mlir/test/Target/Cpp/unary_operators.mlir new file mode 100644 index 0000000000000..8a89437a41cc5 --- /dev/null +++ b/mlir/test/Target/Cpp/unary_operators.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @unary(%arg0: i32) -> () { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + + return +} + +// CHECK-LABEL: void unary +// CHECK-NEXT: int32_t [[V1:[^ ]*]] = -[[V0:[^ ]*]]; +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = +[[V0]]; From 5f6e7a5ffeaf9aaaf94df3c5c7b53c0514808f5b Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 7 Mar 2024 11:34:11 +0100 Subject: [PATCH 20/21] [mlir][EmitC] Add Arith to EmitC conversions (#84151) This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types. It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier. --- .../Conversion/ArithToEmitC/ArithToEmitC.h | 12 +- .../ArithToEmitC/ArithToEmitCPass.h | 21 ++++ mlir/include/mlir/Conversion/Passes.h | 2 +- mlir/include/mlir/Conversion/Passes.td | 13 ++- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 106 +++++------------- .../ArithToEmitC/ArithToEmitCPass.cpp | 53 +++++++++ .../Conversion/ArithToEmitC/CMakeLists.txt | 13 +-- .../ArithToEmitC/arith-to-emitc.mlir | 29 ++--- .../llvm-project-overlay/mlir/BUILD.bazel | 27 +++++ 9 files changed, 163 insertions(+), 113 deletions(-) create mode 100644 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h create mode 100644 mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h index 43322ac7f51f6..9cb43689d1ce6 100644 --- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -1,22 +1,20 @@ -//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===// +//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// + #ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H #define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H -#include "mlir/Pass/Pass.h" - namespace mlir { class RewritePatternSet; +class TypeConverter; -#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS -#include "mlir/Conversion/Passes.h.inc" - -void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns); +void populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); } // namespace mlir #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h new file mode 100644 index 0000000000000..6b98fed7185ea --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h @@ -0,0 +1,21 @@ +//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index f334ec7a592f8..716b59e3ebea5 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,7 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" -#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 2abd4b4b94f9d..b4693dbdf10b9 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -125,17 +125,20 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { }]; let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"]; + + let options = [ + Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool", + /*default=*/"false", + "Use saturating truncation for 8-bit float types">, + ]; } //===----------------------------------------------------------------------===// // ArithToEmitC //===----------------------------------------------------------------------===// -def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> { - let summary = "Convert Arith ops to EmitC ops"; - let description = [{ - Convert `arith` operations to operations in the `emitc` dialect. - }]; +def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> { + let summary = "Convert Arith dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 648fd2b4af0b7..6909534d4790f 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -1,4 +1,4 @@ -//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===// +//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file implements a pass to convert arith ops into emitc ops. +// This file implements patterns to convert the Arith dialect to the EmitC +// dialect. // //===----------------------------------------------------------------------===// @@ -14,91 +15,46 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -namespace mlir { -#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS -#include "mlir/Conversion/Passes.h.inc" -} // namespace mlir - using namespace mlir; -namespace { - -static bool isConvertibleToEmitC(Type type) { - Type baseType = type; - if (auto tensorType = dyn_cast(type)) { - if (!tensorType.hasRank() || !tensorType.hasStaticShape()) { - return false; - } - baseType = tensorType.getElementType(); - } - - if (isa(baseType)) { - return true; - } - - if (auto intType = dyn_cast(baseType)) { - switch (intType.getWidth()) { - case 1: - case 8: - case 16: - case 32: - case 64: - return true; - } - return false; - } - - if (auto floatType = dyn_cast(baseType)) { - return floatType.isF32() || floatType.isF64(); - } - - return false; -} +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// -class ArithConstantOpConversionPattern - : public OpRewritePattern { +namespace { +template +class ArithOpConversion final : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ConstantOp arithConst, - PatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { - auto constantType = arithConst.getType(); - if (!isConvertibleToEmitC(constantType)) { - return rewriter.notifyMatchFailure(arithConst.getLoc(), - "Type cannot be converted to emitc"); - } + rewriter.template replaceOpWithNewOp(arithOp, arithOp.getType(), + adaptor.getOperands()); - rewriter.replaceOpWithNewOp(arithConst, constantType, - arithConst.getValue()); return success(); } }; - -struct ConvertArithToEmitCPass - : public impl::ArithToEmitCConversionPassBase { -public: - void runOnOperation() override { - - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - RewritePatternSet patterns(&getContext()); - populateArithToEmitCConversionPatterns(patterns); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } - } -}; - } // namespace -void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + // clang-format off + patterns.add< + ArithOpConversion, + ArithOpConversion, + ArithOpConversion, + ArithOpConversion + >(typeConverter, ctx); + // clang-format on } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp new file mode 100644 index 0000000000000..b377c063a7aa0 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -0,0 +1,53 @@ +//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertArithToEmitC + : public impl::ConvertArithToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertArithToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + populateArithToEmitCPatterns(typeConverter, patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt index c1bb6d71310ed..a3784f47c3bc2 100644 --- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt @@ -1,5 +1,6 @@ -add_mlir_conversion_library(ArithToEmitC +add_mlir_conversion_library(MLIRArithToEmitC ArithToEmitC.cpp + ArithToEmitCPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC @@ -7,11 +8,9 @@ add_mlir_conversion_library(ArithToEmitC DEPENDS MLIRConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - MLIREmitCDialect MLIRArithDialect - MLIRTransforms -) + MLIREmitCDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 2583dd832c314..6a56474a5c48b 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -1,21 +1,14 @@ -// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s +// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s + +func.func @arith_ops(%arg0: f32, %arg1: f32) { + // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32 + %0 = arith.addf %arg0, %arg1 : f32 + // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32 + %1 = arith.divf %arg0, %arg1 : f32 + // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32 + %2 = arith.mulf %arg0, %arg1 : f32 + // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32 + %3 = arith.subf %arg0, %arg1 : f32 -// CHECK-LABEL: arith_constants -func.func @arith_constants() { - // CHECK: emitc.constant - // CHECK-SAME: value = 0 : index - %c_index = arith.constant 0 : index - // CHECK: emitc.constant - // CHECK-SAME: value = 0 : i32 - %c_signless_int_32 = arith.constant 0 : i32 - // CHECK: emitc.constant - // CHECK-SAME: value = 0.{{0+}}e+00 : f32 - %c_float_32 = arith.constant 0.0 : f32 - // CHECK: emitc.constant - // CHECK-SAME: value = dense<0> : tensor - %c_tensor_single_value = arith.constant dense<0> : tensor - // CHECK: emitc.constant - // CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64> - %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64> return } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 1ecca1df0ea47..417db2b3b33c2 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3841,6 +3841,7 @@ cc_library( ":AMDGPUToROCDL", ":AffineToStandard", ":ArithToAMDGPU", + ":ArithToEmitC", ":ArithToLLVM", ":ArithToSPIRV", ":ArmNeon2dToIntr", @@ -7967,6 +7968,32 @@ cc_library( ], ) +cc_library( + name = "ArithToEmitC", + srcs = glob([ + "lib/Conversion/ArithToEmitC/*.cpp", + "lib/Conversion/ArithToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/ArithToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/ArithToEmitC", + ], + deps = [ + ":ArithDialect", + ":ConversionPassIncGen", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "ArithToLLVM", srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]), From 86a976bcce6b81c5ef47592b73c18cd5805b1c33 Mon Sep 17 00:00:00 2001 From: Tina Jung <126699487+TinaAMD@users.noreply.github.com> Date: Fri, 8 Mar 2024 09:16:10 +0100 Subject: [PATCH 21/21] [mlir][emitc] Arith to EmitC conversion: constants (#83798) * Add a conversion from `arith.constant` to `emitc.constant`. * Drop the translation for `arith.constant`s. --- mlir/docs/Dialects/emitc.md | 2 - .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 16 ++++++++ .../ArithToEmitC/ArithToEmitCPass.cpp | 1 - mlir/lib/Target/Cpp/CMakeLists.txt | 1 - mlir/lib/Target/Cpp/TranslateRegistration.cpp | 4 +- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 12 ------ .../ArithToEmitC/arith-to-emit-c-failed.mlir | 15 ------- .../ArithToEmitC/arith-to-emitc.mlir | 24 ++++++++++- mlir/test/Target/Cpp/call.mlir | 2 +- mlir/test/Target/Cpp/const.mlir | 20 +++++++++ mlir/test/Target/Cpp/for.mlir | 10 ++--- mlir/test/Target/Cpp/if.mlir | 2 +- mlir/test/Target/Cpp/stdops.mlir | 41 ++----------------- .../llvm-project-overlay/mlir/BUILD.bazel | 1 - 14 files changed, 70 insertions(+), 81 deletions(-) delete mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index b227a8c4599a8..1158bc683af06 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -30,5 +30,3 @@ translating the following operations: * `func.call` * `func.func` * `func.return` -* 'arith' Dialect - * `arith.constant` diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 6909534d4790f..40dce001a3b22 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -24,6 +24,21 @@ using namespace mlir; //===----------------------------------------------------------------------===// namespace { +class ArithConstantOpConversionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp arithConst, + arith::ConstantOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + arithConst, arithConst.getType(), adaptor.getValue()); + return success(); + } +}; + template class ArithOpConversion final : public OpConversionPattern { public: @@ -51,6 +66,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, // clang-format off patterns.add< + ArithConstantOpConversionPattern, ArithOpConversion, ArithOpConversion, ArithOpConversion, diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp index b377c063a7aa0..45a088ed144f1 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -38,7 +38,6 @@ void ConvertArithToEmitC::runOnOperation() { target.addLegalDialect(); target.addIllegalDialect(); - target.addLegalOp(); RewritePatternSet patterns(&getContext()); diff --git a/mlir/lib/Target/Cpp/CMakeLists.txt b/mlir/lib/Target/Cpp/CMakeLists.txt index 5521e7909a8ab..d8f372cf16245 100644 --- a/mlir/lib/Target/Cpp/CMakeLists.txt +++ b/mlir/lib/Target/Cpp/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_translation_library(MLIRTargetCpp ${EMITC_MAIN_INCLUDE_DIR}/emitc/Target/Cpp LINK_LIBS PUBLIC - MLIRArithDialect MLIRControlFlowDialect MLIREmitCDialect MLIRFuncDialect diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index b486e5429ea6a..4104b177d7d9a 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -41,8 +40,7 @@ void registerToCppTranslation() { }, [](DialectRegistry ®istry) { // clang-format off - registry.insertgetResult(0); @@ -1425,9 +1416,6 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // Func ops. .Case( [&](auto op) { return printOperation(*this, op); }) - // Arithmetic ops. - .Case( - [&](auto op) { return printOperation(*this, op); }) .Case([&](auto op) { return success(); }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir deleted file mode 100644 index b13c6506787c5..0000000000000 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s - -func.func @arith_constant_complex_tensor() -> (tensor>) { - // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} - %c = arith.constant dense<(2, 2)> : tensor> - return %c : tensor> -} - -// ----- - -func.func @arith_constant_invalid_int_type() -> (i10) { - // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} - %c = arith.constant 0 : i10 - return %c : i10 -} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 6a56474a5c48b..2886810c01e91 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -1,4 +1,26 @@ -// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s + +// CHECK-LABEL: arith_constants +func.func @arith_constants() { + // CHECK: emitc.constant + // CHECK-SAME: value = 0 : index + %c_index = arith.constant 0 : index + // CHECK: emitc.constant + // CHECK-SAME: value = 0 : i32 + %c_signless_int_32 = arith.constant 0 : i32 + // CHECK: emitc.constant + // CHECK-SAME: value = 0.{{0+}}e+00 : f32 + %c_float_32 = arith.constant 0.0 : f32 + // CHECK: emitc.constant + // CHECK-SAME: value = dense<0> : tensor + %c_tensor_single_value = arith.constant dense<0> : tensor + // CHECK: emitc.constant + // CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64> + %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64> + return +} + +// ----- func.func @arith_ops(%arg0: f32, %arg1: f32) { // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32 diff --git a/mlir/test/Target/Cpp/call.mlir b/mlir/test/Target/Cpp/call.mlir index 2bcdc87205184..e3ac392f30b62 100644 --- a/mlir/test/Target/Cpp/call.mlir +++ b/mlir/test/Target/Cpp/call.mlir @@ -18,7 +18,7 @@ func.func @emitc_call_opaque() { func.func @emitc_call_opaque_two_results() { - %0 = arith.constant 0 : index + %0 = "emitc.constant"() <{value = 0 : index}> : () -> index %1:2 = emitc.call_opaque "two_results" () : () -> (i32, i32) return } diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir index 28a547909a0ac..524d564b3b943 100644 --- a/mlir/test/Target/Cpp/const.mlir +++ b/mlir/test/Target/Cpp/const.mlir @@ -8,6 +8,11 @@ func.func @emitc_constant() { %c3 = "emitc.constant"(){value = -1 : si8} : () -> si8 %c4 = "emitc.constant"(){value = 255 : ui8} : () -> ui8 %c5 = "emitc.constant"(){value = #emitc.opaque<"CHAR_MIN">} : () -> !emitc.opaque<"char"> + %c6 = "emitc.constant"(){value = 2 : index} : () -> index + %c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32 + %c8 = "emitc.constant"(){value = dense<0> : tensor} : () -> tensor + %c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex> + %c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> return } // CPP-DEFAULT: void emitc_constant() { @@ -17,6 +22,11 @@ func.func @emitc_constant() { // CPP-DEFAULT-NEXT: int8_t [[V3:[^ ]*]] = -1; // CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255; // CPP-DEFAULT-NEXT: char [[V5:[^ ]*]] = CHAR_MIN; +// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2; +// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = (float)2.000000000e+00; +// CPP-DEFAULT-NEXT: Tensor [[V8:[^ ]*]] = {0}; +// CPP-DEFAULT-NEXT: Tensor [[V9:[^ ]*]] = {0, 1}; +// CPP-DEFAULT-NEXT: Tensor [[V10:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; // CPP-DECLTOP: void emitc_constant() { // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; @@ -25,9 +35,19 @@ func.func @emitc_constant() { // CPP-DECLTOP-NEXT: int8_t [[V3:[^ ]*]]; // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]]; // CPP-DECLTOP-NEXT: char [[V5:[^ ]*]]; +// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]]; +// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V8:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V9:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V10:[^ ]*]]; // CPP-DECLTOP-NEXT: [[V0]] = INT_MAX; // CPP-DECLTOP-NEXT: [[V1]] = 42; // CPP-DECLTOP-NEXT: [[V2]] = -1; // CPP-DECLTOP-NEXT: [[V3]] = -1; // CPP-DECLTOP-NEXT: [[V4]] = 255; // CPP-DECLTOP-NEXT: [[V5]] = CHAR_MIN; +// CPP-DECLTOP-NEXT: [[V6]] = 2; +// CPP-DECLTOP-NEXT: [[V7]] = (float)2.000000000e+00; +// CPP-DECLTOP-NEXT: [[V8]] = {0}; +// CPP-DECLTOP-NEXT: [[V9]] = {0, 1}; +// CPP-DECLTOP-NEXT: [[V10]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index b9bd3d98465a2..5225f3ddaff25 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -33,12 +33,12 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { // CPP-DECLTOP-NEXT: return; func.func @test_for_yield() { - %start = arith.constant 0 : index - %stop = arith.constant 10 : index - %step = arith.constant 1 : index + %start = "emitc.constant"() <{value = 0 : index}> : () -> index + %stop = "emitc.constant"() <{value = 10 : index}> : () -> index + %step = "emitc.constant"() <{value = 1 : index}> : () -> index - %s0 = arith.constant 0 : i32 - %p0 = arith.constant 1.0 : f32 + %s0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 + %p0 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32 %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 diff --git a/mlir/test/Target/Cpp/if.mlir b/mlir/test/Target/Cpp/if.mlir index 743f8ad396882..7b0e2da85d0eb 100644 --- a/mlir/test/Target/Cpp/if.mlir +++ b/mlir/test/Target/Cpp/if.mlir @@ -49,7 +49,7 @@ func.func @test_if_else(%arg0: i1, %arg1: f32) { func.func @test_if_yield(%arg0: i1, %arg1: f32) { - %0 = arith.constant 0 : i8 + %0 = "emitc.constant"() <{value = 0 : i8}> : () -> i8 %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 %y = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64 emitc.if %arg0 { diff --git a/mlir/test/Target/Cpp/stdops.mlir b/mlir/test/Target/Cpp/stdops.mlir index 0723188a62c68..cc6bdbe376984 100644 --- a/mlir/test/Target/Cpp/stdops.mlir +++ b/mlir/test/Target/Cpp/stdops.mlir @@ -1,37 +1,6 @@ // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP -func.func @std_constant() { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 2 : index - %c2 = arith.constant 2.0 : f32 - %c3 = arith.constant dense<0> : tensor - %c4 = arith.constant dense<[0, 1]> : tensor<2xindex> - %c5 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> - return -} -// CPP-DEFAULT: void std_constant() { -// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 0; -// CPP-DEFAULT-NEXT: size_t [[V1:[^ ]*]] = 2; -// CPP-DEFAULT-NEXT: float [[V2:[^ ]*]] = (float)2.000000000e+00; -// CPP-DEFAULT-NEXT: Tensor [[V3:[^ ]*]] = {0}; -// CPP-DEFAULT-NEXT: Tensor [[V4:[^ ]*]] = {0, 1}; -// CPP-DEFAULT-NEXT: Tensor [[V5:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; - -// CPP-DECLTOP: void std_constant() { -// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; -// CPP-DECLTOP-NEXT: size_t [[V1:[^ ]*]]; -// CPP-DECLTOP-NEXT: float [[V2:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V3:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V5:[^ ]*]]; -// CPP-DECLTOP-NEXT: [[V0]] = 0; -// CPP-DECLTOP-NEXT: [[V1]] = 2; -// CPP-DECLTOP-NEXT: [[V2]] = (float)2.000000000e+00; -// CPP-DECLTOP-NEXT: [[V3]] = {0}; -// CPP-DECLTOP-NEXT: [[V4]] = {0, 1}; -// CPP-DECLTOP-NEXT: [[V5]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; - func.func @std_call() { %0 = call @one_result () : () -> i32 %1 = call @one_result () : () -> i32 @@ -49,13 +18,11 @@ func.func @std_call() { func.func @std_call_two_results() { - %c = arith.constant 0 : i8 %0:2 = call @two_results () : () -> (i32, f32) %1:2 = call @two_results () : () -> (i32, f32) return } // CPP-DEFAULT: void std_call_two_results() { -// CPP-DEFAULT-NEXT: int8_t [[V0:[^ ]*]] = 0; // CPP-DEFAULT-NEXT: int32_t [[V1:[^ ]*]]; // CPP-DEFAULT-NEXT: float [[V2:[^ ]*]]; // CPP-DEFAULT-NEXT: std::tie([[V1]], [[V2]]) = two_results(); @@ -64,18 +31,16 @@ func.func @std_call_two_results() { // CPP-DEFAULT-NEXT: std::tie([[V3]], [[V4]]) = two_results(); // CPP-DECLTOP: void std_call_two_results() { -// CPP-DECLTOP-NEXT: int8_t [[V0:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t [[V1:[^ ]*]]; // CPP-DECLTOP-NEXT: float [[V2:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t [[V3:[^ ]*]]; // CPP-DECLTOP-NEXT: float [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: [[V0]] = 0; // CPP-DECLTOP-NEXT: std::tie([[V1]], [[V2]]) = two_results(); // CPP-DECLTOP-NEXT: std::tie([[V3]], [[V4]]) = two_results(); func.func @one_result() -> i32 { - %0 = arith.constant 0 : i32 + %0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 return %0 : i32 } // CPP-DEFAULT: int32_t one_result() { @@ -89,8 +54,8 @@ func.func @one_result() -> i32 { func.func @two_results() -> (i32, f32) { - %0 = arith.constant 0 : i32 - %1 = arith.constant 1.0 : f32 + %0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 + %1 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32 return %0, %1 : i32, f32 } // CPP-DEFAULT: std::tuple two_results() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 417db2b3b33c2..f2d804477f1b3 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1635,7 +1635,6 @@ cc_library( ]), hdrs = glob(["include/mlir/Target/Cpp/*.h"]), deps = [ - ":ArithDialect", ":ControlFlowDialect", ":EmitCDialect", ":FuncDialect",