diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index 809a04660336b..1158bc683af06 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -28,8 +28,5 @@ translating the following operations: * `cf.cond_br` * 'func' Dialect * `func.call` - * `func.constant` * `func.func` * `func.return` -* 'arith' Dialect - * `arith.constant` diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h index 43322ac7f51f6..9cb43689d1ce6 100644 --- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -1,22 +1,20 @@ -//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===// +//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// + #ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H #define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H -#include "mlir/Pass/Pass.h" - namespace mlir { class RewritePatternSet; +class TypeConverter; -#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS -#include "mlir/Conversion/Passes.h.inc" - -void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns); +void populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); } // namespace mlir #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h new file mode 100644 index 0000000000000..6b98fed7185ea --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h @@ -0,0 +1,21 @@ +//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h new file mode 100644 index 0000000000000..5c7f87e470306 --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h @@ -0,0 +1,18 @@ +//===- FuncToEmitC.h - Func to EmitC Patterns -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H + +namespace mlir { +class RewritePatternSet; + +void populateFuncToEmitCPatterns(RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h new file mode 100644 index 0000000000000..65936703ee13e --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h @@ -0,0 +1,21 @@ +//===- FuncToEmitCPass.h - Func to EmitC Pass -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H +#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index c91cc39829215..716b59e3ebea5 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,7 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" -#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" @@ -29,6 +29,7 @@ #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bd0be40c04af3..b4693dbdf10b9 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -125,17 +125,20 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { }]; let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"]; + + let options = [ + Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool", + /*default=*/"false", + "Use saturating truncation for 8-bit float types">, + ]; } //===----------------------------------------------------------------------===// // ArithToEmitC //===----------------------------------------------------------------------===// -def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> { - let summary = "Convert Arith ops to EmitC ops"; - let description = [{ - Convert `arith` operations to operations in the `emitc` dialect. - }]; +def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> { + let summary = "Convert Arith dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; } @@ -356,6 +359,15 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> { ]; } +//===----------------------------------------------------------------------===// +// FuncToEmitC +//===----------------------------------------------------------------------===// + +def ConvertFuncToEmitC : Pass<"convert-func-to-emitc", "ModuleOp"> { + let summary = "Convert Func dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // FuncToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 4dff26e23c428..1f0df3cb336b1 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -14,12 +14,14 @@ #define MLIR_DIALECT_EMITC_IR_EMITC_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 644a6ed2566e5..bcdd001528c46 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -16,9 +16,12 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td" include "mlir/Dialect/EmitC/IR/EmitCTypes.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/RegionKindInterface.td" //===----------------------------------------------------------------------===// // EmitC op definitions @@ -28,6 +31,14 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class EmitC_Op traits = []> : Op; +// Base class for unary operations. +class EmitC_UnaryOp traits = []> : + EmitC_Op { + let arguments = (ins AnyType); + let results = (outs AnyType); + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; +} + // Base class for binary operations. class EmitC_BinaryOp traits = []> : EmitC_Op { @@ -36,11 +47,14 @@ class EmitC_BinaryOp traits = []> : let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } +// EmitC OpTrait +def CExpression : NativeOpTrait<"emitc::CExpression">; + // Types only used in binary arithmetic operations. def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>; def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>; -def EmitC_AddOp : EmitC_BinaryOp<"add", []> { +def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> { let summary = "Addition operation"; let description = [{ With the `add` operation the arithmetic operator + (addition) can @@ -63,7 +77,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> { let hasVerifier = 1; } -def EmitC_ApplyOp : EmitC_Op<"apply", []> { +def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { let summary = "Apply operation"; let description = [{ With the `apply` operation the operators & (address of) and * (contents of) @@ -92,7 +106,117 @@ def EmitC_ApplyOp : EmitC_Op<"apply", []> { let hasVerifier = 1; } -def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { +def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> { + let summary = "Bitwise and operation"; + let description = [{ + With the `bitwise_and` operation the bitwise operator & (and) can + be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 & v2; + ``` + }]; +} + +def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", + [CExpression]> { + let summary = "Bitwise left shift operation"; + let description = [{ + With the `bitwise_left_shift` operation the bitwise operator << + (left shift) can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 << v2; + ``` + }]; +} + +def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> { + let summary = "Bitwise not operation"; + let description = [{ + With the `bitwise_not` operation the bitwise operator ~ (not) can + be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_not %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = ~v1; + ``` + }]; +} + +def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> { + let summary = "Bitwise or operation"; + let description = [{ + With the `bitwise_or` operation the bitwise operator | (or) + can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 | v2; + ``` + }]; +} + +def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", + [CExpression]> { + let summary = "Bitwise right shift operation"; + let description = [{ + With the `bitwise_right_shift` operation the bitwise operator >> + (right shift) can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 >> v2; + ``` + }]; +} + +def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> { + let summary = "Bitwise xor operation"; + let description = [{ + With the `bitwise_xor` operation the bitwise operator ^ (xor) + can be applied. + + Example: + + ```mlir + %0 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v3 = v1 ^ v2; + ``` + }]; +} + +def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { let summary = "Opaque call operation"; let description = [{ The `call_opaque` operation represents a C++ function call. The callee @@ -119,16 +243,29 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> { Variadic:$operands ); let results = (outs Variadic); + let builders = [ + OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::llvm::StringRef":$callee, + "::mlir::ValueRange":$operands, + CArg<"::mlir::ArrayAttr", "{}">:$args, + CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{ + build($_builder, $_state, resultTypes, callee, args, template_args, + operands); + }] + > + ]; + let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; let hasVerifier = 1; } -def EmitC_CastOp : EmitC_Op<"cast", [ - DeclareOpInterfaceMethods, - SameOperandsAndResultShape - ]> { +def EmitC_CastOp : EmitC_Op<"cast", + [CExpression, + DeclareOpInterfaceMethods, + SameOperandsAndResultShape]> { let summary = "Cast operation"; let description = [{ The `cast` operation performs an explicit type conversion and is emitted @@ -152,7 +289,7 @@ def EmitC_CastOp : EmitC_Op<"cast", [ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } -def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> { +def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { let summary = "Comparison operation"; let description = [{ With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> @@ -223,7 +360,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { let hasVerifier = 1; } -def EmitC_DivOp : EmitC_BinaryOp<"div", []> { +def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> { let summary = "Division operation"; let description = [{ With the `div` operation the arithmetic operator / (division) can @@ -247,6 +384,77 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> { let results = (outs FloatIntegerIndexOrOpaqueType); } +def EmitC_ExpressionOp : EmitC_Op<"expression", + [HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">, + NoRegionArguments]> { + let summary = "Expression operation"; + let description = [{ + The `expression` operation returns a single SSA value which is yielded by + its single-basic-block region. The operation doesn't take any arguments. + + As the operation is to be emitted as a C expression, the operations within + its body must form a single Def-Use tree of emitc ops whose result is + yielded by a terminating `emitc.yield`. + + Example: + + ```mlir + %r = emitc.expression : i32 { + %0 = emitc.add %a, %b : (i32, i32) -> i32 + %1 = emitc.call_opaque "foo"(%0) : (i32) -> i32 + %2 = emitc.add %c, %d : (i32, i32) -> i32 + %3 = emitc.mul %1, %2 : (i32, i32) -> i32 + emitc.yield %3 : i32 + } + ``` + + May be emitted as + + ```c++ + int32_t v7 = foo(v1 + v2) * (v3 + v4); + ``` + + The operations allowed within expression body are EmitC operations with the + CExpression trait. + + When specified, the optional `do_not_inline` indicates that the expression is + to be emitted as seen above, i.e. as the rhs of an EmitC SSA value + definition. Otherwise, the expression may be emitted inline, i.e. directly + at its use. + }]; + + let arguments = (ins UnitAttr:$do_not_inline); + let results = (outs AnyType:$result); + let regions = (region SizedRegion<1>:$region); + + let hasVerifier = 1; + let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + auto predicate = [](Operation &op) { + assert(op.hasTrait() && "Expected a C expression"); + // Conservatively assume calls to read and write memory. + if (isa(op)) + return true; + // De-referencing reads modifiable memory, address-taking has no + // side-effect. + auto applyOp = dyn_cast(op); + if (applyOp) + return applyOp.getApplicableOperator() == "*"; + // Any operation using variables is assumed to have a side effect of + // reading memory mutable by emitc::assign ops. + return llvm::any_of(op.getOperands(), [](Value operand) { + Operation *def = operand.getDefiningOp(); + return def && isa(def); + }); + }; + return llvm::any_of(getRegion().front().without_terminator(), predicate); + }; + Operation *getRootOp(); + }]; +} + def EmitC_ForOp : EmitC_Op<"for", [AllTypesMatch<["lowerBound", "upperBound", "step"]>, SingleBlockImplicitTerminator<"emitc::YieldOp">, @@ -308,6 +516,219 @@ def EmitC_ForOp : EmitC_Op<"for", let hasRegionVerifier = 1; } +def EmitC_CallOp : EmitC_Op<"call", + [CallOpInterface, CExpression, + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `emitc.call` operation represents a direct call to an `emitc.func` + that is within the same symbol scope as the call. The operands and result type + of the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [ + DeclareOpInterfaceMethods +]> { + let summary = "An operation to declare a function"; + let description = [{ + The `declare_func` operation allows to insert a function declaration for an + `emitc.func` at a specific position. The operation only requires the `callee` + of the `emitc.func` to be specified as an attribute. + + Example: + + ```mlir + emitc.declare_func @bar + emitc.func @foo(%arg0: i32) -> i32 { + %0 = emitc.call @bar(%arg0) : (i32) -> (i32) + emitc.return %0 : i32 + } + + emitc.func @bar(%arg0: i32) -> i32 { + emitc.return %arg0 : i32 + } + ``` + + ```c++ + // Code emitted for the operations above. + int32_t bar(int32_t v1); + int32_t foo(int32_t v1) { + int32_t v2 = bar(v1); + return v2; + } + + int32_t bar(int32_t v1) { + return v1; + } + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$sym_name); + let assemblyFormat = [{ + $sym_name attr-dict + }]; +} + +def EmitC_FuncOp : EmitC_Op<"func", [ + AutomaticAllocationScope, + FunctionOpInterface, IsolatedFromAbove +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). While the MLIR textual form provides a nice + inline syntax for function arguments, they are internally represented as + “block arguments” to the first block in the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // A function with no results: + emitc.func @foo(%arg0 : i32) { + emitc.call_opaque "bar" (%arg0) : (i32) -> () + emitc.return + } + + // A function with its argument as single result: + emitc.func @foo(%arg0 : i32) -> i32 { + emitc.return %arg0 : i32 + } + + // A function with specifiers attribute: + emitc.func @example_specifiers_fn_attr() -> i32 + attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 + } + + // An external function definition: + emitc.func private @extern_func(i32) + attributes {specifiers = ["extern"]} + ``` + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$specifiers, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, + ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `emitc.return` operation represents a return operation within a function. + The operation takes zero or exactly one operand and produces no results. + The operand number and type must match the signature of the function + that contains the operation. + + Example: + + ```mlir + emitc.func @foo() : (i32) { + ... + emitc.return %0 : i32 + } + ``` + }]; + let arguments = (ins Optional:$operand); + + let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; + let hasVerifier = 1; +} + def EmitC_IncludeOp : EmitC_Op<"include", [HasParent<"ModuleOp">]> { let summary = "Include operation"; @@ -352,7 +773,70 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { let assemblyFormat = "$value attr-dict `:` type($result)"; } -def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { +def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> { + let summary = "Logical and operation"; + let description = [{ + With the `logical_and` operation the logical operator && (and) can + be applied. + + Example: + + ```mlir + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v3 = v1 && v2; + ``` + }]; + + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> { + let summary = "Logical not operation"; + let description = [{ + With the `logical_not` operation the logical operator ! (negation) can + be applied. + + Example: + + ```mlir + %0 = emitc.logical_not %arg0 : i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v2 = !v1; + ``` + }]; + + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> { + let summary = "Logical or operation"; + let description = [{ + With the `logical_or` operation the logical operator || (inclusive or) + can be applied. + + Example: + + ```mlir + %0 = emitc.logical_or %arg0, %arg1 : i32, i32 + ``` + ```c++ + // Code emitted for the operation above. + bool v3 = v1 || v2; + ``` + }]; + + let results = (outs I1); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> { let summary = "Multiplication operation"; let description = [{ With the `mul` operation the arithmetic operator * (multiplication) can @@ -376,7 +860,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", []> { let results = (outs FloatIntegerIndexOrOpaqueType); } -def EmitC_RemOp : EmitC_BinaryOp<"rem", []> { +def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> { let summary = "Remainder operation"; let description = [{ With the `rem` operation the arithmetic operator % (remainder) can @@ -398,7 +882,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", []> { let results = (outs IntegerIndexOrOpaqueType); } -def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { +def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> { let summary = "Subtraction operation"; let description = [{ With the `sub` operation the arithmetic operator - (subtraction) can @@ -424,6 +908,42 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { let hasVerifier = 1; } +def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> { + let summary = "Unary minus operation"; + let description = [{ + With the `unary_minus` operation the unary operator - (minus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_plus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = -v1; + ``` + }]; +} + +def EmitC_UnaryPlusOp : EmitC_UnaryOp<"unary_plus", [CExpression]> { + let summary = "Unary plus operation"; + let description = [{ + With the `unary_plus` operation the unary operator + (plus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_plus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = +v1; + ``` + }]; +} + def EmitC_VariableOp : EmitC_Op<"variable", []> { let summary = "Variable operation"; let description = [{ @@ -466,6 +986,40 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { let hasVerifier = 1; } +def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { + let summary = "Verbatim operation"; + let description = [{ + The `verbatim` operation produces no results and the value is emitted as is + followed by a line break ('\n' character) during translation. + + Note: Use with caution. This operation can have arbitrary effects on the + semantics of the emitted code. Use semantically more meaningful operations + whenever possible. Additionally this op is *NOT* intended to be used to + inject large snippets of code. + + This operation can be used in situations where a more suitable operation is + not yet implemented in the dialect or where preprocessor directives + interfere with the structure of the code. One example of this is to declare + the linkage of external symbols to make the generated code usable in both C + and C++ contexts: + + ```c++ + #ifdef __cplusplus + extern "C" { + #endif + + ... + + #ifdef __cplusplus + } + #endif + ``` + }]; + + let arguments = (ins StrAttr:$value); + let assemblyFormat = "$value attr-dict"; +} + def EmitC_AssignOp : EmitC_Op<"assign", []> { let summary = "Assign operation"; let description = [{ @@ -494,18 +1048,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { } def EmitC_YieldOp : EmitC_Op<"yield", - [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> { + [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> { let summary = "block termination operation"; let description = [{ - "yield" terminates blocks within EmitC control-flow operations. Since - control-flow constructs in C do not return values, this operation doesn't - take any arguments. + "yield" terminates its parent EmitC op's region, optionally yielding + an SSA value. The semantics of how the values are yielded is defined by the + parent operation. + If "yield" has an operand, the operand must match the parent operation's + result. If the parent operation defines no values, then the "emitc.yield" + may be left out in the custom syntax and the builders will insert one + implicitly. Otherwise, it has to be present in the syntax to indicate which + value is yielded. }]; - let arguments = (ins); + let arguments = (ins Optional:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = [{ attr-dict }]; + let hasVerifier = 1; + let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }]; } def EmitC_IfOp : EmitC_Op<"if", diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td index ae843e49c6c5b..ea5e9efd5fa0b 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td @@ -57,8 +57,7 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> { }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); - - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "`<` $value `>`"; } def EmitC_OpaqueOrTypedAttr : AnyAttrOf<[EmitC_OpaqueAttr, TypedAttrInterface]>; diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h new file mode 100644 index 0000000000000..c1602dfce4b48 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h @@ -0,0 +1,30 @@ +//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares C++ classes for some of the traits used in the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H +#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { +namespace emitc { + +template +class CExpression : public TraitBase {}; + +} // namespace emitc +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 8dfda3be99d5f..5ab729df67882 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -90,7 +90,8 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "`<` $value `>`"; + let genVerifyDecl = 1; } def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..0b507d75fa07a --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC) +add_public_tablegen_target(MLIREmitCTransformsIncGen) + +add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h new file mode 100644 index 0000000000000..5cd27149d366e --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an instance of the C-style expressions forming pass. +std::unique_ptr createFormExpressionsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td new file mode 100644 index 0000000000000..fd083abc95715 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -0,0 +1,24 @@ +//===-- Passes.td - pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FormExpressions : Pass<"form-expressions"> { + let summary = "Form C-style expressions from C-operator ops"; + let description = [{ + The pass wraps emitc ops modelling C operators in emitc.expression ops and + then folds single-use expressions into their users where possible. + }]; + let constructor = "mlir::emitc::createFormExpressionsPass()"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h new file mode 100644 index 0000000000000..2574acd7d48e0 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -0,0 +1,34 @@ +//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Expression transforms +//===----------------------------------------------------------------------===// + +ExpressionOp createExpression(Operation *op, OpBuilder &builder); + +//===----------------------------------------------------------------------===// +// Populate functions +//===----------------------------------------------------------------------===// + +/// Populates `patterns` with expression-related patterns. +void populateExpressionPatterns(RewritePatternSet &patterns); + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9e65898154bd6..789450903afe7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1439,7 +1439,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ let extraClassDeclaration = [{ // Add an entry block to an empty function, and set up the block arguments // to match the signature of the function. - Block *addEntryBlock(); + Block *addEntryBlock(OpBuilder &builder); bool isVarArg() { return getFunctionType().isVarArg(); } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 36ad6755cab25..991e753d1b359 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -285,7 +285,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> { // Adds an empty entry block and loop merge block containing one // spirv.mlir.merge op. - void addEntryAndMergeBlock(); + void addEntryAndMergeBlock(OpBuilder &builder); }]; let hasOpcode = 0; @@ -427,7 +427,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> { Block *getMergeBlock(); /// Adds a selection merge block containing one spirv.mlir.merge op. - void addMergeBlock(); + void addMergeBlock(OpBuilder &builder); /// Creates a spirv.mlir.selection op for `if () then { }` /// with `builder`. `builder`'s insertion point will remain at after the diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index f22980036ffcf..5207559f36250 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -86,6 +87,7 @@ inline void registerAllPasses() { vector::registerVectorPasses(); arm_sme::registerArmSMEPasses(); arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); // Dialect pipelines bufferization::registerBufferizationPipelines(); diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td index 98e002565cf19..be71063272d80 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td @@ -131,6 +131,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ static void buildWithEntryBlock( OpBuilder &builder, OperationState &state, StringRef name, Type type, ArrayRef attrs, TypeRange inputTypes) { + OpBuilder::InsertionGuard g(builder); state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), @@ -139,8 +140,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ // Add the function body. Region *bodyRegion = state.addRegion(); - Block *body = new Block(); - bodyRegion->push_back(body); + Block *body = builder.createBlock(bodyRegion); for (Type input : inputTypes) body->addArgument(input, state.location); } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 648fd2b4af0b7..40dce001a3b22 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -1,4 +1,4 @@ -//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===// +//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file implements a pass to convert arith ops into emitc ops. +// This file implements patterns to convert the Arith dialect to the EmitC +// dialect. // //===----------------------------------------------------------------------===// @@ -14,91 +15,62 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -namespace mlir { -#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS -#include "mlir/Conversion/Passes.h.inc" -} // namespace mlir - using namespace mlir; -namespace { - -static bool isConvertibleToEmitC(Type type) { - Type baseType = type; - if (auto tensorType = dyn_cast(type)) { - if (!tensorType.hasRank() || !tensorType.hasStaticShape()) { - return false; - } - baseType = tensorType.getElementType(); - } - - if (isa(baseType)) { - return true; - } - - if (auto intType = dyn_cast(baseType)) { - switch (intType.getWidth()) { - case 1: - case 8: - case 16: - case 32: - case 64: - return true; - } - return false; - } - - if (auto floatType = dyn_cast(baseType)) { - return floatType.isF32() || floatType.isF64(); - } - - return false; -} +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// +namespace { class ArithConstantOpConversionPattern - : public OpRewritePattern { + : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::ConstantOp arithConst, - PatternRewriter &rewriter) const override { - - auto constantType = arithConst.getType(); - if (!isConvertibleToEmitC(constantType)) { - return rewriter.notifyMatchFailure(arithConst.getLoc(), - "Type cannot be converted to emitc"); - } - - rewriter.replaceOpWithNewOp(arithConst, constantType, - arithConst.getValue()); + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp arithConst, + arith::ConstantOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + arithConst, arithConst.getType(), adaptor.getValue()); return success(); } }; -struct ConvertArithToEmitCPass - : public impl::ArithToEmitCConversionPassBase { +template +class ArithOpConversion final : public OpConversionPattern { public: - void runOnOperation() override { + using OpConversionPattern::OpConversionPattern; - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - RewritePatternSet patterns(&getContext()); - populateArithToEmitCConversionPatterns(patterns); + LogicalResult + matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } + rewriter.template replaceOpWithNewOp(arithOp, arithOp.getType(), + adaptor.getOperands()); + + return success(); } }; - } // namespace -void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + // clang-format off + patterns.add< + ArithConstantOpConversionPattern, + ArithOpConversion, + ArithOpConversion, + ArithOpConversion, + ArithOpConversion + >(typeConverter, ctx); + // clang-format on } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp new file mode 100644 index 0000000000000..45a088ed144f1 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -0,0 +1,52 @@ +//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertArithToEmitC + : public impl::ConvertArithToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertArithToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + populateArithToEmitCPatterns(typeConverter, patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt index c1bb6d71310ed..a3784f47c3bc2 100644 --- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt @@ -1,5 +1,6 @@ -add_mlir_conversion_library(ArithToEmitC +add_mlir_conversion_library(MLIRArithToEmitC ArithToEmitC.cpp + ArithToEmitCPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC @@ -7,11 +8,9 @@ add_mlir_conversion_library(ArithToEmitC DEPENDS MLIRConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - MLIREmitCDialect MLIRArithDialect - MLIRTransforms -) + MLIREmitCDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 0ab53ce7e3327..7760373913761 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -259,7 +259,7 @@ static void addResumeFunction(ModuleOp module) { kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); - auto *block = resumeOp.addEntryBlock(); + auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); blockBuilder.create(resumeOp.getArgument(0)); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 7e9369b14b401..b70c26effe2b6 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSCF) add_subdirectory(ControlFlowToSPIRV) add_subdirectory(ConvertToLLVM) +add_subdirectory(FuncToEmitC) add_subdirectory(FuncToLLVM) add_subdirectory(FuncToSPIRV) add_subdirectory(GPUCommon) diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index 363e5f9b8cefe..d3ee89743da9d 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -98,12 +98,10 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( loc, builder.create(loc, builder.getI1Type(), condition), loopVariablesNextIter); - auto *afterBlock = new Block; - whileOp.getAfter().push_back(afterBlock); + Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector(loopVariablesInit.size(), loc)); - builder.setInsertionPointToEnd(afterBlock); builder.create(loc, afterBlock->getArguments()); return whileOp.getOperation(); diff --git a/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt new file mode 100644 index 0000000000000..97752205bbcb4 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRFuncToEmitC + FuncToEmitC.cpp + FuncToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/FuncToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRFuncDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp new file mode 100644 index 0000000000000..6a8ecb7b00473 --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -0,0 +1,120 @@ +//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Func dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +class CallOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Multiple results func was not converted to `emitc.func`. + if (callOp.getNumResults() > 1) + return rewriter.notifyMatchFailure( + callOp, "only functions with zero or one result can be converted"); + + rewriter.replaceOpWithNewOp( + callOp, + callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr, + adaptor.getOperands(), callOp->getAttrs()); + + return success(); + } +}; + +class FuncOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (funcOp.getFunctionType().getNumResults() > 1) + return rewriter.notifyMatchFailure( + funcOp, "only functions with zero or one result can be converted"); + + // Create the converted `emitc.func` op. + emitc::FuncOp newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Add `extern` to specifiers if `func.func` is declaration only. + if (funcOp.isDeclaration()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"}); + newFuncOp.setSpecifiersAttr(specifiers); + } + + // Add `static` to specifiers if `func.func` is private but not a + // declaration. + if (funcOp.isPrivate() && !funcOp.isDeclaration()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"}); + newFuncOp.setSpecifiersAttr(specifiers); + } + + if (!funcOp.isDeclaration()) + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + rewriter.eraseOp(funcOp); + + return success(); + } +}; + +class ReturnOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (returnOp.getNumOperands() > 1) + return rewriter.notifyMatchFailure( + returnOp, "only zero or one operand is supported"); + + rewriter.replaceOpWithNewOp( + returnOp, + returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + patterns.add(ctx); +} diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp new file mode 100644 index 0000000000000..26d32e29bef8c --- /dev/null +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp @@ -0,0 +1,47 @@ +//===- FuncToEmitC.cpp - Func to EmitC Pass ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Func dialect to the EmitC dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" + +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTFUNCTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertFuncToEmitC + : public impl::ConvertFuncToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertFuncToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + populateFuncToEmitCPatterns(patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index bd50c67fb8795..53b44aa3241bb 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -135,7 +135,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); + rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter)); SmallVector args; size_t argOffset = resultStructType ? 1 : 0; @@ -203,7 +203,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, // The wrapper that we synthetize here should only be visible in this module. newFuncOp.setLinkage(LLVM::Linkage::Private); - builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); + builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder)); // Get a ValueRange containing arguments. FunctionType type = cast(funcOp.getFunctionType()); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 2bfca303b5fd4..2dc42f0a85e66 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -520,9 +520,7 @@ struct GlobalMemrefOpLowering global, arrayTy, global.getConstant(), linkage, global.getSymName(), initialValue, alignment, *addressSpace); if (!global.isExternal() && global.isUninitialized()) { - Block *blk = new Block(); - newGlobal.getInitializerRegion().push_back(blk); - rewriter.setInsertionPointToStart(blk); + rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { rewriter.create(global.getLoc(), arrayTy)}; rewriter.create(global.getLoc(), undef); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index febfe97f6c0a9..d90cf931385fc 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -138,14 +138,13 @@ struct ForOpConversion final : SCFToSPIRVPattern { // from header to merge. auto loc = forOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); // Create the block for the header. - auto *header = new Block(); - // Insert the header. - loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), - header); + Block *header = rewriter.createBlock(&loopOp.getBody(), + getBlockIt(loopOp.getBody(), 1)); + rewriter.setInsertionPointAfter(loopOp); // Create the new induction variable to use. Value adapLowerBound = adaptor.getLowerBound(); @@ -342,7 +341,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern { ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); Region &afterRegion = whileOp.getAfter(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index d5be2e906989f..0ccb5a9f658da 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1809,6 +1809,8 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); + OpBuilder::InsertionGuard guard(builder); + // Set variadic segment sizes. result.addAttribute( getOperandSegmentSizeAttr(), @@ -1837,12 +1839,11 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, // Create a region and a block for the body. The argument of the region is // the loop induction variable. Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + Block *bodyBlock = builder.createBlock(bodyRegion); Value inductionVar = - bodyBlock.addArgument(builder.getIndexType(), result.location); + bodyBlock->addArgument(builder.getIndexType(), result.location); for (Value val : iterArgs) - bodyBlock.addArgument(val.getType(), val.getLoc()); + bodyBlock->addArgument(val.getType(), val.getLoc()); // Create the default terminator if the builder is not provided and if the // iteration arguments are not provided. Otherwise, leave this to the caller @@ -1851,9 +1852,9 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result, ensureTerminator(*bodyRegion, builder, result.location); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); + builder.setInsertionPointToStart(bodyBlock); bodyBuilder(builder, result.location, inductionVar, - bodyBlock.getArguments().drop_front()); + bodyBlock->getArguments().drop_front()); } } @@ -2890,18 +2891,20 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, IntegerSet set, ValueRange args, bool withElseRegion) { assert(resultTypes.empty() || withElseRegion); + OpBuilder::InsertionGuard guard(builder); + result.addTypes(resultTypes); result.addOperands(args); result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set)); Region *thenRegion = result.addRegion(); - thenRegion->push_back(new Block()); + builder.createBlock(thenRegion); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); Region *elseRegion = result.addRegion(); if (withElseRegion) { - elseRegion->push_back(new Block()); + builder.createBlock(elseRegion); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); } @@ -3688,6 +3691,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, "expected upper bound maps to have as many inputs as upper bound " "operands"); + OpBuilder::InsertionGuard guard(builder); result.addTypes(resultTypes); // Convert the reductions to integer attributes. @@ -3733,11 +3737,11 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result, // Create a region and a block for the body. auto *bodyRegion = result.addRegion(); - auto *body = new Block(); + Block *body = builder.createBlock(bodyRegion); + // Add all the block arguments. for (unsigned i = 0, e = steps.size(); i < e; ++i) body->addArgument(IndexType::get(builder.getContext()), result.location); - bodyRegion->push_back(body); if (resultTypes.empty()) ensureTerminator(*bodyRegion, builder, result.location); } diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 5f583f36cd2cb..a3e3f80954efc 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -68,7 +68,7 @@ void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, void ExecuteOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange dependencies, ValueRange operands, BodyBuilderFn bodyBuilder) { - + OpBuilder::InsertionGuard guard(builder); result.addOperands(dependencies); result.addOperands(operands); @@ -87,26 +87,21 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result, // Add a body region with block arguments as unwrapped async value operands. Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + Block *bodyBlock = builder.createBlock(bodyRegion); for (Value operand : operands) { auto valueType = llvm::dyn_cast(operand.getType()); - bodyBlock.addArgument(valueType ? valueType.getValueType() - : operand.getType(), - operand.getLoc()); + bodyBlock->addArgument(valueType ? valueType.getValueType() + : operand.getType(), + operand.getLoc()); } // Create the default terminator if the builder is not provided and if the // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. if (resultTypes.empty() && !bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); builder.create(result.location, ValueRange()); } else if (bodyBuilder) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, result.location, bodyBlock.getArguments()); + bodyBuilder(builder, result.location, bodyBlock->getArguments()); } } diff --git a/mlir/lib/Dialect/EmitC/CMakeLists.txt b/mlir/lib/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/lib/Dialect/EmitC/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt index 4665c41a62e80..4cc54201d2745 100644 --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -9,8 +9,10 @@ add_mlir_dialect_library(MLIREmitCDialect MLIREmitCAttributesIncGen LINK_LIBS PUBLIC + MLIRCallInterfaces MLIRCastInterfaces MLIRControlFlowInterfaces + MLIRFunctionInterfaces MLIRIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 2d578d47aa4a8..5db0777bc30ab 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -7,8 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -50,6 +54,32 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create(loc); } +/// Check that the type of the initial value is compatible with the operations +/// result type. +static LogicalResult verifyInitializationAttribute(Operation *op, + Attribute value) { + assert(op->getNumResults() == 1 && "operation must have 1 result"); + + if (llvm::isa(value)) + return success(); + + if (llvm::isa(value)) + return op->emitOpError() + << "string attributes are not supported, use #emitc.opaque instead"; + + Type resultType = op->getResult(0).getType(); + Type attrType = cast(value).getType(); + + if (resultType != attrType) + return op->emitOpError() + << "requires attribute to either be an #emitc.opaque attribute or " + "it's type (" + << attrType << ") to match the op's result type (" << resultType + << ")"; + + return success(); +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -170,46 +200,83 @@ LogicalResult emitc::CallOpaqueOp::verify() { // ConstantOp //===----------------------------------------------------------------------===// -/// The constant op requires that the attribute's type matches the return type. LogicalResult emitc::ConstantOp::verify() { - if (llvm::isa(getValueAttr())) - return success(); - - // Value must not be empty - StringAttr strAttr = llvm::dyn_cast(getValueAttr()); - if (strAttr && strAttr.empty()) - return emitOpError() << "value must not be empty"; - - auto value = cast(getValueAttr()); - Type type = getType(); - if (!llvm::isa(value.getType()) && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; + Attribute value = getValueAttr(); + if (failed(verifyInitializationAttribute(getOperation(), value))) + return failure(); + if (auto opaqueValue = llvm::dyn_cast(value)) { + if (opaqueValue.getValue().empty()) + return emitOpError() << "value must not be empty"; + } return success(); } OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } +//===----------------------------------------------------------------------===// +// ExpressionOp +//===----------------------------------------------------------------------===// + +Operation *ExpressionOp::getRootOp() { + auto yieldOp = cast(getBody()->getTerminator()); + Value yieldedValue = yieldOp.getResult(); + Operation *rootOp = yieldedValue.getDefiningOp(); + assert(rootOp && "Yielded value not defined within expression"); + return rootOp; +} + +LogicalResult ExpressionOp::verify() { + Type resultType = getResult().getType(); + Region ®ion = getRegion(); + + Block &body = region.front(); + + if (!body.mightHaveTerminator()) + return emitOpError("must yield a value at termination"); + + auto yield = cast(body.getTerminator()); + Value yieldResult = yield.getResult(); + + if (!yieldResult) + return emitOpError("must yield a value at termination"); + + Type yieldType = yieldResult.getType(); + + if (resultType != yieldType) + return emitOpError("requires yielded type to match return type"); + + for (Operation &op : region.front().without_terminator()) { + if (!op.hasTrait()) + return emitOpError("contains an unsupported operation"); + if (op.getNumResults() != 1) + return emitOpError("requires exactly one result for each operation"); + if (!op.getResult(0).hasOneUse()) + return emitOpError("requires exactly one use for each operation"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, BodyBuilderFn bodyBuilder) { + OpBuilder::InsertionGuard g(builder); result.addOperands({lb, ub, step}); Type t = lb.getType(); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(t, result.location); + Block *bodyBlock = builder.createBlock(bodyRegion); + bodyBlock->addArgument(t, result.location); // Create the default terminator if the builder is not provided. if (!bodyBuilder) { ForOp::ensureTerminator(*bodyRegion, builder, result.location); } else { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, result.location, bodyBlock.getArgument(0)); + builder.setInsertionPointToStart(bodyBlock); + bodyBuilder(builder, result.location, bodyBlock->getArgument(0)); } } @@ -285,6 +352,137 @@ LogicalResult ForOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +FunctionType CallOp::getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// DeclareFuncOp +//===----------------------------------------------------------------------===// + +LogicalResult +DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the sym_name attribute was specified. + auto fnAttr = getSymNameAttr(); + if (!fnAttr) + return emitOpError("requires a 'sym_name' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +LogicalResult FuncOp::verify() { + if (getNumResults() > 1) + return emitOpError("requires zero or exactly one result, but has ") + << getNumResults(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + if (getNumOperands() != function.getNumResults()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << function.getNumResults(); + + if (function.getNumResults() == 1) + if (getOperand().getType() != function.getResultTypes()[0]) + return emitError() << "type of the return operand (" + << getOperand().getType() + << ") doesn't match function result type (" + << function.getResultTypes()[0] << ")" + << " in function @" << function.getName(); + return success(); +} + //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// @@ -518,17 +716,8 @@ LogicalResult SubOp::verify() { // VariableOp //===----------------------------------------------------------------------===// -/// The variable op requires that the attribute's type matches the return type. LogicalResult emitc::VariableOp::verify() { - if (llvm::isa(getValueAttr())) - return success(); - - auto value = cast(getValueAttr()); - Type type = getType(); - if (!llvm::isa(value.getType()) && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - return success(); + return verifyInitializationAttribute(getOperation(), getValueAttr()); } //===----------------------------------------------------------------------===// @@ -545,6 +734,23 @@ LogicalResult emitc::SubscriptOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult emitc::YieldOp::verify() { + Value result = getResult(); + Operation *containingOp = getOperation()->getParentOp(); + + if (result && containingOp->getNumResults() != 1) + return emitOpError() << "yields a value not returned by parent"; + + if (!result && containingOp->getNumResults() != 0) + return emitOpError() << "does not yield a value to be returned by parent"; + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -565,27 +771,6 @@ LogicalResult emitc::SubscriptOp::verify() { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" -Attribute emitc::OpaqueAttr::parse(AsmParser &parser, Type type) { - if (parser.parseLess()) - return Attribute(); - std::string value; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseOptionalString(&value)) { - parser.emitError(loc) << "expected string"; - return Attribute(); - } - if (parser.parseGreater()) - return Attribute(); - - return get(parser.getContext(), value); -} - -void emitc::OpaqueAttr::print(AsmPrinter &printer) const { - printer << "<\""; - llvm::printEscapedString(getValue(), printer.getStream()); - printer << "\">"; -} - //===----------------------------------------------------------------------===// // EmitC Types //===----------------------------------------------------------------------===// @@ -660,27 +845,15 @@ emitc::ArrayType::cloneWith(std::optional> shape, // OpaqueType //===----------------------------------------------------------------------===// -Type emitc::OpaqueType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - std::string value; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseOptionalString(&value) || value.empty()) { - parser.emitError(loc) << "expected non empty string in !emitc.opaque type"; - return Type(); +LogicalResult mlir::emitc::OpaqueType::verify( + llvm::function_ref emitError, + llvm::StringRef value) { + if (value.empty()) { + return emitError() << "expected non empty string in !emitc.opaque type"; } if (value.back() == '*') { - parser.emitError(loc) << "pointer not allowed as outer type with " - "!emitc.opaque, use !emitc.ptr instead"; - return Type(); + return emitError() << "pointer not allowed as outer type with " + "!emitc.opaque, use !emitc.ptr instead"; } - if (parser.parseGreater()) - return Type(); - return get(parser.getContext(), value); -} - -void emitc::OpaqueType::print(AsmPrinter &printer) const { - printer << "<\""; - llvm::printEscapedString(getValue(), printer.getStream()); - printer << "\">"; + return success(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..bfcc14523f137 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIREmitCTransforms + Transforms.cpp + FormExpressions.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms + + DEPENDS + MLIREmitCTransformsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIREmitCDialect + MLIRTransforms +) diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp new file mode 100644 index 0000000000000..5b03f81b305fd --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -0,0 +1,60 @@ +//===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that forms EmitC operations modeling C operators +// into C-style expressions using the emitc.expression op. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace emitc { +#define GEN_PASS_DEF_FORMEXPRESSIONS +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace emitc; + +namespace { +struct FormExpressionsPass + : public emitc::impl::FormExpressionsBase { + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = rootOp->getContext(); + + // Wrap each C operator op with an expression op. + OpBuilder builder(context); + auto matchFun = [&](Operation *op) { + if (op->hasTrait()) + createExpression(op, builder); + }; + rootOp->walk(matchFun); + + // Fold expressions where possible. + RewritePatternSet patterns(context); + populateExpressionPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns)))) + return signalPassFailure(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // namespace + +std::unique_ptr mlir::emitc::createFormExpressionsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp new file mode 100644 index 0000000000000..87350ecdceaaa --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -0,0 +1,112 @@ +//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace emitc { + +ExpressionOp createExpression(Operation *op, OpBuilder &builder) { + assert(op->hasTrait() && + "Expected a C expression"); + + // Create an expression yielding the value returned by op. + assert(op->getNumResults() == 1 && "Expected exactly one result"); + Value result = op->getResult(0); + Type resultType = result.getType(); + Location loc = op->getLoc(); + + builder.setInsertionPointAfter(op); + auto expressionOp = builder.create(loc, resultType); + + // Replace all op's uses with the new expression's result. + result.replaceAllUsesWith(expressionOp.getResult()); + + // Create an op to yield op's value. + Region ®ion = expressionOp.getRegion(); + Block &block = region.emplaceBlock(); + builder.setInsertionPointToEnd(&block); + auto yieldOp = builder.create(loc, result); + + // Move op into the new expression. + op->moveBefore(yieldOp); + + return expressionOp; +} + +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace mlir::emitc; + +namespace { + +struct FoldExpressionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ExpressionOp expressionOp, + PatternRewriter &rewriter) const override { + bool anythingFolded = false; + for (Operation &op : llvm::make_early_inc_range( + expressionOp.getBody()->without_terminator())) { + // Don't fold expressions whose result value has its address taken. + auto applyOp = dyn_cast(op); + if (applyOp && applyOp.getApplicableOperator() == "&") + continue; + + for (Value operand : op.getOperands()) { + auto usedExpression = + dyn_cast_if_present(operand.getDefiningOp()); + + if (!usedExpression) + continue; + + // Don't fold expressions with multiple users: assume any + // re-materialization was done separately. + if (!usedExpression.getResult().hasOneUse()) + continue; + + // Don't fold expressions with side effects. + if (usedExpression.hasSideEffects()) + continue; + + // Fold the used expression into this expression by cloning all + // instructions in the used expression just before the operation using + // its value. + rewriter.setInsertionPoint(&op); + IRMapping mapper; + for (Operation &opToClone : + usedExpression.getBody()->without_terminator()) { + Operation *clone = rewriter.clone(opToClone, mapper); + mapper.map(&opToClone, clone); + } + + Operation *expressionRoot = usedExpression.getRootOp(); + Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); + assert(clonedExpressionRootOp && + "Expected cloned expression root to be in mapper"); + assert(clonedExpressionRootOp->getNumResults() == 1 && + "Expected cloned root to have a single result"); + + rewriter.replaceOp(usedExpression, clonedExpressionRootOp); + anythingFolded = true; + } + } + return anythingFolded ? success() : failure(); + } +}; + +} // namespace + +void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 458bf83eac17f..95b0d6ef1ae2a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2166,11 +2166,10 @@ LogicalResult ShuffleVectorOp::verify() { //===----------------------------------------------------------------------===// // Add the entry block to the function. -Block *LLVMFuncOp::addEntryBlock() { +Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) { assert(empty() && "function already has an entry block"); - - auto *entry = new Block; - push_back(entry); + OpBuilder::InsertionGuard g(builder); + Block *entry = builder.createBlock(&getBody()); // FIXME: Allow passing in proper locations for the entry arguments. LLVMFunctionType type = getFunctionType(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 6fbf351455787..370dee4448eb4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -132,12 +132,10 @@ struct MoveInitOperandsToInput : public OpRewritePattern { newIndexingMaps, genericOp.getIteratorTypesArray(), /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + OpBuilder::InsertionGuard guard(rewriter); Region ®ion = newOp.getRegion(); - Block *block = new Block(); - region.push_back(block); + Block *block = rewriter.createBlock(®ion); IRMapping mapper; - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(block); for (auto bbarg : genericOp.getRegionInputArgs()) mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3eb91190751ef..55ff33792de61 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -178,11 +178,9 @@ static void generateFusedElementwiseOpRegion( // Build the region of the fused op. Block &producerBlock = producer->getRegion(0).front(); Block &consumerBlock = consumer->getRegion(0).front(); - Block *fusedBlock = new Block(); - fusedOp.getRegion().push_back(fusedBlock); - IRMapping mapper; OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(fusedBlock); + Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); + IRMapping mapper; // 2. Add an index operation for every fused loop dimension and use the // `consumerToProducerLoopsMap` to map the producer indices. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 75c8cd3e1d95a..6bdbc89608921 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -275,14 +275,13 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, auto transposeOp = b.create(loc, resultTensorType, inputTensor, outputTensor, indexingMaps, iteratorTypes); - Region &body = transposeOp.getRegion(); - body.push_back(new Block()); - body.front().addArguments({elementType, elementType}, {loc, loc}); // Create the body of the transpose operation. OpBuilder::InsertionGuard g(b); - b.setInsertionPointToEnd(&body.front()); - b.create(loc, transposeOp.getRegion().front().getArgument(0)); + Region &body = transposeOp.getRegion(); + Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{}, + {elementType, elementType}, {loc, loc}); + b.create(loc, bodyBlock->getArgument(0)); return transposeOp; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 93327a28234ea..03017afe95dbd 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1420,6 +1420,7 @@ OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange ivs) { + OpBuilder::InsertionGuard g(builder); result.addOperands(memref); result.addOperands(ivs); @@ -1428,7 +1429,7 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, result.addTypes(elementType); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block()); + builder.createBlock(bodyRegion); bodyRegion->addArgument(elementType, memref.getLoc()); } } diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 580782043c81b..7170a899069ee 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -365,12 +365,11 @@ Block *LoopOp::getMergeBlock() { return &getBody().back(); } -void LoopOp::addEntryAndMergeBlock() { +void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) { assert(getBody().empty() && "entry and merge block already exist"); - getBody().push_back(new Block()); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); + OpBuilder::InsertionGuard g(builder); + builder.createBlock(&getBody()); + builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. builder.create(getLoc()); @@ -525,11 +524,10 @@ Block *SelectionOp::getMergeBlock() { return &getBody().back(); } -void SelectionOp::addMergeBlock() { +void SelectionOp::addMergeBlock(OpBuilder &builder) { assert(getBody().empty() && "entry and merge block already exist"); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); + OpBuilder::InsertionGuard guard(builder); + builder.createBlock(&getBody()); // Add a spirv.mlir.merge op into the merge block. builder.create(getLoc()); @@ -542,7 +540,7 @@ SelectionOp::createIfThen(Location loc, Value condition, auto selectionOp = builder.create(loc, spirv::SelectionControl::None); - selectionOp.addMergeBlock(); + selectionOp.addMergeBlock(builder); Block *mergeBlock = selectionOp.getMergeBlock(); Block *thenBlock = nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 4f829db1305c8..d9ee39a4e8dd3 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -375,15 +375,13 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op, void AssumingOp::build( OpBuilder &builder, OperationState &result, Value witness, function_ref(OpBuilder &, Location)> bodyBuilder) { + OpBuilder::InsertionGuard g(builder); result.addOperands(witness); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); + builder.createBlock(bodyRegion); // Build body. - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); SmallVector yieldValues = bodyBuilder(builder, result.location); builder.create(result.location, yieldValues); @@ -1904,23 +1902,23 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, ValueRange initVals) { + OpBuilder::InsertionGuard g(builder); result.addOperands(shape); result.addOperands(initVals); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(builder.getIndexType(), result.location); + Block *bodyBlock = builder.createBlock( + bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location); Type elementType; if (auto tensorType = llvm::dyn_cast(shape.getType())) elementType = tensorType.getElementType(); else elementType = SizeType::get(builder.getContext()); - bodyBlock.addArgument(elementType, shape.getLoc()); + bodyBlock->addArgument(elementType, shape.getLoc()); for (Value initVal : initVals) { - bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); + bodyBlock->addArgument(initVal.getType(), initVal.getLoc()); result.addTypes(initVal.getType()); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 3b9685b8ae1e0..9a483078a4a44 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -299,8 +299,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { Block &prodBlock = prod.getRegion().front(); Block &consBlock = op.getRegion().front(); IRMapping mapper; - Block *fusedBlock = new Block(); - fusedOp.getRegion().push_back(fusedBlock); + Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); unsigned num = prodBlock.getNumArguments(); for (unsigned i = 0; i < num - 1; i++) addArg(mapper, fusedBlock, prodBlock.getArgument(i)); @@ -309,7 +308,6 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { // Clone bodies of the producer and consumer in new evaluation order. auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp(); auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp(); - rewriter.setInsertionPointToStart(fusedBlock); Value last; for (auto &op : prodBlock.without_terminator()) if (&op != acc) { diff --git a/mlir/lib/Target/Cpp/CMakeLists.txt b/mlir/lib/Target/Cpp/CMakeLists.txt index 5521e7909a8ab..d8f372cf16245 100644 --- a/mlir/lib/Target/Cpp/CMakeLists.txt +++ b/mlir/lib/Target/Cpp/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_translation_library(MLIRTargetCpp ${EMITC_MAIN_INCLUDE_DIR}/emitc/Target/Cpp LINK_LIBS PUBLIC - MLIRArithDialect MLIRControlFlowDialect MLIREmitCDialect MLIRFuncDialect diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index b486e5429ea6a..4104b177d7d9a 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -41,8 +40,7 @@ void registerToCppTranslation() { }, [](DialectRegistry ®istry) { // clang-format off - registry.insert getOperatorPrecedence(Operation *operation) { + return llvm::TypeSwitch>(operation) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 7; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 5; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 6; }) + .Case([&](auto op) { return 16; }) + .Case([&](auto op) { return 16; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) -> FailureOr { + switch (op.getPredicate()) { + case emitc::CmpPredicate::eq: + case emitc::CmpPredicate::ne: + return 8; + case emitc::CmpPredicate::lt: + case emitc::CmpPredicate::le: + case emitc::CmpPredicate::gt: + case emitc::CmpPredicate::ge: + return 9; + case emitc::CmpPredicate::three_way: + return 10; + } + return op->emitError("unsupported cmp predicate"); + }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 4; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 3; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 15; }) + .Default([](auto op) { return op->emitError("unsupported operation"); }); +} + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -119,6 +162,12 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); + /// Emits value as an operands of an operation + LogicalResult emitOperand(Value value); + + /// Emit an expression as a C expression. + LogicalResult emitExpression(ExpressionOp expressionOp); + /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); @@ -163,6 +212,21 @@ struct CppEmitter { /// be declared at the beginning of a function. bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + /// Get expression currently being emitted. + ExpressionOp getEmittedExpression() { return emittedExpression; } + + /// Determine whether given value is part of the expression potentially being + /// emitted. + bool isPartOfCurrentExpression(Value value) { + if (!emittedExpression) + return false; + Operation *def = value.getDefiningOp(); + if (!def) + return false; + auto operandExpression = dyn_cast(def->getParentOp()); + return operandExpression == emittedExpression; + }; + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -185,9 +249,50 @@ struct CppEmitter { /// names of values in a scope. std::stack valueInScopeCount; std::stack labelInScopeCount; + + /// State of the current expression being emitted. + ExpressionOp emittedExpression; + SmallVector emittedExpressionPrecedence; + + void pushExpressionPrecedence(int precedence) { + emittedExpressionPrecedence.push_back(precedence); + } + void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); } + static int lowestPrecedence() { return 0; } + int getExpressionPrecedence() { + if (emittedExpressionPrecedence.empty()) + return lowestPrecedence(); + return emittedExpressionPrecedence.back(); + } }; } // namespace +/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// as part of its user. This function recommends inlining of any expressions +/// that can be inlined unless it is used by another expression, under the +/// assumption that any expression fusion/re-materialization was taken care of +/// by transformations run by the backend. +static bool shouldBeInlined(ExpressionOp expressionOp) { + // Do not inline if expression is marked as such. + if (expressionOp.getDoNotInline()) + return false; + + // Do not inline expressions with side effects to prevent side-effect + // reordering. + if (expressionOp.hasSideEffects()) + return false; + + // Do not inline expressions with multiple uses. + Value result = expressionOp.getResult(); + if (!result.hasOneUse()) + return false; + + // Do not inline expressions used by other expressions, as any desired + // expression folding was taken care of by transformations. + Operation *user = *result.getUsers().begin(); + return !user->getParentOfType(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -236,22 +341,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return printConstantOp(emitter, operation, value); } -static LogicalResult printOperation(CppEmitter &emitter, - arith::ConstantOp constantOp) { - Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValue(); - - return printConstantOp(emitter, operation, value); -} - -static LogicalResult printOperation(CppEmitter &emitter, - func::ConstantOp constantOp) { - Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValueAttr(); - - return printConstantOp(emitter, operation, value); -} - static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); @@ -259,9 +348,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (failed(emitter.emitVariableAssignment(result))) return failure(); - emitter.ostream() << emitter.getOrCreateName(assignOp.getValue()); - - return success(); + return emitter.emitOperand(assignOp.getValue()); } static LogicalResult printOperation(CppEmitter &emitter, @@ -278,9 +365,30 @@ static LogicalResult printBinaryOperation(CppEmitter &emitter, if (failed(emitter.emitAssignPrefix(*operation))) return failure(); - os << emitter.getOrCreateName(operation->getOperand(0)); - os << " " << binaryOperator; - os << " " << emitter.getOrCreateName(operation->getOperand(1)); + + if (failed(emitter.emitOperand(operation->getOperand(0)))) + return failure(); + + os << " " << binaryOperator << " "; + + if (failed(emitter.emitOperand(operation->getOperand(1)))) + return failure(); + + return success(); +} + +static LogicalResult printUnaryOperation(CppEmitter &emitter, + Operation *operation, + StringRef unaryOperator) { + raw_ostream &os = emitter.ostream(); + + if (failed(emitter.emitAssignPrefix(*operation))) + return failure(); + + os << unaryOperator; + + if (failed(emitter.emitOperand(operation->getOperand(0)))) + return failure(); return success(); } @@ -347,6 +455,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { return printBinaryOperation(emitter, operation, binaryOperator); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::VerbatimOp verbatimOp) { + raw_ostream &os = emitter.ostream(); + + os << verbatimOp.getValue(); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, cf::BranchOp branchOp) { raw_ostream &os = emitter.ostream(); @@ -413,18 +530,33 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } -static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { - if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) +static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, + StringRef callee) { + if (failed(emitter.emitAssignPrefix(*callOp))) return failure(); raw_ostream &os = emitter.ostream(); - os << callOp.getCallee() << "("; - if (failed(emitter.emitOperands(*callOp.getOperation()))) + os << callee << "("; + if (failed(emitter.emitOperands(*callOp))) return failure(); os << ")"; return success(); } +static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { + Operation *operation = callOp.getOperation(); + StringRef callee = callOp.getCallee(); + + return printCallOperation(emitter, operation, callee); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOpaqueOp callOpaqueOp) { raw_ostream &os = emitter.ostream(); @@ -488,6 +620,56 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseAndOp bitwiseAndOp) { + Operation *operation = bitwiseAndOp.getOperation(); + return printBinaryOperation(emitter, operation, "&"); +} + +static LogicalResult +printOperation(CppEmitter &emitter, + emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) { + Operation *operation = bitwiseLeftShiftOp.getOperation(); + return printBinaryOperation(emitter, operation, "<<"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseNotOp bitwiseNotOp) { + Operation *operation = bitwiseNotOp.getOperation(); + return printUnaryOperation(emitter, operation, "~"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseOrOp bitwiseOrOp) { + Operation *operation = bitwiseOrOp.getOperation(); + return printBinaryOperation(emitter, operation, "|"); +} + +static LogicalResult +printOperation(CppEmitter &emitter, + emitc::BitwiseRightShiftOp bitwiseRightShiftOp) { + Operation *operation = bitwiseRightShiftOp.getOperation(); + return printBinaryOperation(emitter, operation, ">>"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::BitwiseXorOp bitwiseXorOp) { + Operation *operation = bitwiseXorOp.getOperation(); + return printBinaryOperation(emitter, operation, "^"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryPlusOp unaryPlusOp) { + Operation *operation = unaryPlusOp.getOperation(); + return printUnaryOperation(emitter, operation, "+"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryMinusOp unaryMinusOp) { + Operation *operation = unaryMinusOp.getOperation(); + return printUnaryOperation(emitter, operation, "-"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { raw_ostream &os = emitter.ostream(); Operation &op = *castOp.getOperation(); @@ -498,9 +680,20 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) return failure(); os << ") "; - os << emitter.getOrCreateName(castOp.getOperand()); + return emitter.emitOperand(castOp.getOperand()); +} - return success(); +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ExpressionOp expressionOp) { + if (shouldBeInlined(expressionOp)) + return success(); + + Operation &op = *expressionOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + + return emitter.emitExpression(expressionOp); } static LogicalResult printOperation(CppEmitter &emitter, @@ -516,10 +709,39 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalAndOp logicalAndOp) { + Operation *operation = logicalAndOp.getOperation(); + return printBinaryOperation(emitter, operation, "&&"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalNotOp logicalNotOp) { + Operation *operation = logicalNotOp.getOperation(); + return printUnaryOperation(emitter, operation, "!"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LogicalOrOp logicalOrOp) { + Operation *operation = logicalOrOp.getOperation(); + return printBinaryOperation(emitter, operation, "||"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { raw_indented_ostream &os = emitter.ostream(); + // Utility function to determine whether a value is an expression that will be + // inlined, and as such should be wrapped in parentheses in order to guarantee + // its precedence and associativity. + auto requiresParentheses = [&](Value value) { + auto expressionOp = + dyn_cast_if_present(value.getDefiningOp()); + if (!expressionOp) + return false; + return shouldBeInlined(expressionOp); + }; + os << "for ("; if (failed( emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) @@ -527,15 +749,24 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { os << " "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " = "; - os << emitter.getOrCreateName(forOp.getLowerBound()); + if (failed(emitter.emitOperand(forOp.getLowerBound()))) + return failure(); os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " < "; - os << emitter.getOrCreateName(forOp.getUpperBound()); + Value upperBound = forOp.getUpperBound(); + bool upperBoundRequiresParentheses = requiresParentheses(upperBound); + if (upperBoundRequiresParentheses) + os << "("; + if (failed(emitter.emitOperand(upperBound))) + return failure(); + if (upperBoundRequiresParentheses) + os << ")"; os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " += "; - os << emitter.getOrCreateName(forOp.getStep()); + if (failed(emitter.emitOperand(forOp.getStep()))) + return failure(); os << ") {\n"; os.indent(); @@ -570,7 +801,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) { }; os << "if ("; - if (failed(emitter.emitOperands(*ifOp.getOperation()))) + if (failed(emitter.emitOperand(ifOp.getCondition()))) return failure(); os << ") {\n"; os.indent(); @@ -598,8 +829,10 @@ static LogicalResult printOperation(CppEmitter &emitter, case 0: return success(); case 1: - os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); - return success(emitter.hasValueInScope(returnOp.getOperand(0))); + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand(0)))) + return failure(); + return success(); default: os << " std::make_tuple("; if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) @@ -609,6 +842,19 @@ static LogicalResult printOperation(CppEmitter &emitter, } } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ReturnOp returnOp) { + raw_ostream &os = emitter.ostream(); + os << "return"; + if (returnOp.getNumOperands() == 0) + return success(); + + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand()))) + return failure(); + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { CppEmitter::Scope scope(emitter); @@ -619,38 +865,48 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } -static LogicalResult printOperation(CppEmitter &emitter, - func::FuncOp functionOp) { - // We need to declare variables at top if the function has multiple blocks. - if (!emitter.shouldDeclareVariablesAtTop() && - functionOp.getBlocks().size() > 1) { - return functionOp.emitOpError( - "with multiple blocks needs variables declared at top"); - } +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + ArrayRef arguments) { + raw_indented_ostream &os = emitter.ostream(); - CppEmitter::Scope scope(emitter); + return ( + interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult { + return emitter.emitType(functionOp->getLoc(), arg); + })); +} + +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + Region::BlockArgListType arguments) { raw_indented_ostream &os = emitter.ostream(); - if (failed(emitter.emitTypes(functionOp.getLoc(), - functionOp.getFunctionType().getResults()))) - return failure(); - os << " " << functionOp.getName(); - os << "("; - if (failed(interleaveCommaWithError(functionOp.getArguments(), os, + if (failed(interleaveCommaWithError(arguments, os, [&](BlockArgument arg) -> LogicalResult { return emitter.emitVariableDeclaration( - functionOp.getLoc(), arg.getType(), + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg)); }))) return failure(); - os << ") {\n"; + + return success(); +} + +static LogicalResult printFunctionBody(CppEmitter &emitter, + Operation *functionOp, + Region::BlockListType &blocks) { + raw_indented_ostream &os = emitter.ostream(); os.indent(); + if (emitter.shouldDeclareVariablesAtTop()) { // Declare all variables that hold op results including those from nested // regions. WalkResult result = - functionOp.walk([&](Operation *op) -> WalkResult { - if (isa(op)) + functionOp->walk([&](Operation *op) -> WalkResult { + if (isa(op) || + isa(op->getParentOp()) || + (isa(op) && + shouldBeInlined(cast(op)))) return WalkResult::skip(); for (OpResult result : op->getResults()) { if (failed(emitter.emitVariableDeclaration( @@ -665,7 +921,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); } - Region::BlockListType &blocks = functionOp.getBlocks(); // Create label names for basic blocks. for (Block &block : blocks) { emitter.getOrCreateName(block); @@ -675,7 +930,7 @@ static LogicalResult printOperation(CppEmitter &emitter, for (Block &block : llvm::drop_begin(blocks)) { for (BlockArgument &arg : block.getArguments()) { if (emitter.hasValueInScope(arg)) - return functionOp.emitOpError(" block argument #") + return functionOp->emitOpError(" block argument #") << arg.getArgNumber() << " is out of scope"; if (failed( emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { @@ -694,18 +949,120 @@ static LogicalResult printOperation(CppEmitter &emitter, for (Operation &op : block.getOperations()) { // When generating code for an emitc.if or cf.cond_br op no semicolon // needs to be printed after the closing brace. - // When generating code for an emitc.for op, printing a trailing semicolon - // is handled within the printOperation function. + // When generating code for an emitc.for and emitc.verbatim op, printing a + // trailing semicolon is handled within the printOperation function. bool trailingSemicolon = - !isa( - op); + !isa(op); if (failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) return failure(); } } - os.unindent() << "}\n"; + + os.unindent(); + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + func::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::FuncOp functionOp) { + // We need to declare variables at top if the function has multiple blocks. + if (!emitter.shouldDeclareVariablesAtTop() && + functionOp.getBlocks().size() > 1) { + return functionOp.emitOpError( + "with multiple blocks needs variables declared at top"); + } + + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + if (functionOp.getSpecifiers()) { + for (Attribute specifier : functionOp.getSpecifiersAttr()) { + os << cast(specifier).str() << " "; + } + } + + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (functionOp.isExternal()) { + if (failed(printFunctionArgs(emitter, operation, + functionOp.getArgumentTypes()))) + return failure(); + os << ");"; + return success(); + } + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ") {\n"; + if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) + return failure(); + os << "}\n"; + + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + DeclareFuncOp declareFuncOp) { + CppEmitter::Scope scope(emitter); + raw_indented_ostream &os = emitter.ostream(); + + auto functionOp = SymbolTable::lookupNearestSymbolFrom( + declareFuncOp, declareFuncOp.getSymNameAttr()); + + if (!functionOp) + return failure(); + + if (functionOp.getSpecifiers()) { + for (Attribute specifier : functionOp.getSpecifiersAttr()) { + os << cast(specifier).str() << " "; + } + } + + if (failed(emitter.emitTypes(functionOp.getLoc(), + functionOp.getFunctionType().getResults()))) + return failure(); + os << " " << functionOp.getName(); + + os << "("; + Operation *operation = functionOp.getOperation(); + if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + return failure(); + os << ");"; + return success(); } @@ -868,15 +1225,75 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return emitError(loc, "cannot emit attribute: ") << attr; } +LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + Operation *rootOp = expressionOp.getRootOp(); + + emittedExpression = expressionOp; + FailureOr precedence = getOperatorPrecedence(rootOp); + if (failed(precedence)) + return failure(); + pushExpressionPrecedence(precedence.value()); + + if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) + return failure(); + + popExpressionPrecedence(); + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + emittedExpression = nullptr; + + return success(); +} + +LogicalResult CppEmitter::emitOperand(Value value) { + if (isPartOfCurrentExpression(value)) { + Operation *def = value.getDefiningOp(); + assert(def && "Expected operand to be defined by an operation"); + FailureOr precedence = getOperatorPrecedence(def); + if (failed(precedence)) + return failure(); + bool encloseInParenthesis = precedence.value() < getExpressionPrecedence(); + if (encloseInParenthesis) { + os << "("; + pushExpressionPrecedence(lowestPrecedence()); + } else + pushExpressionPrecedence(precedence.value()); + + if (failed(emitOperation(*def, /*trailingSemicolon=*/false))) + return failure(); + + if (encloseInParenthesis) + os << ")"; + + popExpressionPrecedence(); + return success(); + } + + auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); + if (expressionOp && shouldBeInlined(expressionOp)) + return emitExpression(expressionOp); + + auto literalOp = dyn_cast_if_present(value.getDefiningOp()); + if (!literalOp && !hasValueInScope(value)) + return failure(); + os << getOrCreateName(value); + return success(); +} + LogicalResult CppEmitter::emitOperands(Operation &op) { - auto emitOperandName = [&](Value result) -> LogicalResult { - auto literalDef = dyn_cast_if_present(result.getDefiningOp()); - if (!literalDef && !hasValueInScope(result)) - return op.emitOpError() << "operand value not in scope"; - os << getOrCreateName(result); + return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { + // If an expression is being emitted, push lowest precedence as these + // operands are either wrapped by parenthesis. + if (getEmittedExpression()) + pushExpressionPrecedence(lowestPrecedence()); + if (failed(emitOperand(operand))) + return failure(); + if (getEmittedExpression()) + popExpressionPrecedence(); return success(); - }; - return interleaveCommaWithError(op.getOperands(), os, emitOperandName); + }); } LogicalResult @@ -932,6 +1349,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, } LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { + // If op is being emitted as part of an expression, bail out. + if (getEmittedExpression()) + return success(); + switch (op.getNumResults()) { case 0: break; @@ -981,16 +1402,19 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { [&](auto op) { return printOperation(*this, op); }) // EmitC ops. .Case( + emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, + emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, + emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, + emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, + emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp, + emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. - .Case( - [&](auto op) { return printOperation(*this, op); }) - // Arithmetic ops. - .Case( + .Case( [&](auto op) { return printOperation(*this, op); }) .Case([&](auto op) { return success(); }) .Default([&](Operation *) { @@ -1003,7 +1427,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (isa(op)) return success(); + if (getEmittedExpression() || + (isa(op) && + shouldBeInlined(cast(op)))) + return success(); + os << (trailingSemicolon ? ";\n" : "\n"); + return success(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 89e2e7ad52fa7..b9455ea41e64b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1799,7 +1799,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { auto control = static_cast(selectionControl); auto selectionOp = builder.create(location, control); - selectionOp.addMergeBlock(); + selectionOp.addMergeBlock(builder); return selectionOp; } @@ -1811,7 +1811,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { auto control = static_cast(loopControl); auto loopOp = builder.create(location, control); - loopOp.addEntryAndMergeBlock(); + loopOp.addEntryAndMergeBlock(builder); return loopOp; } diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir deleted file mode 100644 index b13c6506787c5..0000000000000 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s - -func.func @arith_constant_complex_tensor() -> (tensor>) { - // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} - %c = arith.constant dense<(2, 2)> : tensor> - return %c : tensor> -} - -// ----- - -func.func @arith_constant_invalid_int_type() -> (i10) { - // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} - %c = arith.constant 0 : i10 - return %c : i10 -} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 2583dd832c314..2886810c01e91 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -19,3 +19,18 @@ func.func @arith_constants() { %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64> return } + +// ----- + +func.func @arith_ops(%arg0: f32, %arg1: f32) { + // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32 + %0 = arith.addf %arg0, %arg1 : f32 + // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32 + %1 = arith.divf %arg0, %arg1 : f32 + // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32 + %2 = arith.mulf %arg0, %arg1 : f32 + // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32 + %3 = arith.subf %arg0, %arg1 : f32 + + return +} diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir new file mode 100644 index 0000000000000..5c96cf1ce0d34 --- /dev/null +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s + +// CHECK-LABEL: emitc.func @foo() +// CHECK-NEXT: emitc.return +func.func @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func private @foo() attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return +func.func private @foo() { + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) +func.func @foo(%arg0: i32) { + emitc.call_opaque "bar"(%arg0) : (i32) -> () + return +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32) -> i32 +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func @foo(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func @foo(%arg0: i32, %arg1: i32) -> i32 +func.func @foo(%arg0: i32, %arg1: i32) -> i32 { + %0 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func private @return_i32(%arg0: i32) -> i32 attributes {specifiers = ["static"]} +// CHECK-NEXT: emitc.return %arg0 : i32 +func.func private @return_i32(%arg0: i32) -> i32 { + return %arg0 : i32 +} + +// CHECK-LABEL: emitc.func @call(%arg0: i32) -> i32 +// CHECK-NEXT: %0 = emitc.call @return_i32(%arg0) : (i32) -> i32 +// CHECK-NEXT: emitc.return %0 : i32 +func.func @call(%arg0: i32) -> i32 { + %0 = call @return_i32(%arg0) : (i32) -> (i32) + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: emitc.func private @return_i32(i32) -> i32 attributes {specifiers = ["extern"]} +func.func private @return_i32(%arg0: i32) -> i32 diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index fd79bbd8a1d30..51b68eecbbd56 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -1,7 +1,15 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +func.func @const_attribute_str() { + // expected-error @+1 {{'emitc.constant' op string attributes are not supported, use #emitc.opaque instead}} + %c0 = "emitc.constant"(){value = "NULL"} : () -> !emitc.ptr + return +} + +// ----- + func.func @const_attribute_return_type_1() { - // expected-error @+1 {{'emitc.constant' op requires attribute's type ('i64') to match op's return type ('i32')}} + // expected-error @+1 {{'emitc.constant' op requires attribute to either be an #emitc.opaque attribute or it's type ('i64') to match the op's result type ('i32')}} %c0 = "emitc.constant"(){value = 42: i64} : () -> i32 return } @@ -9,8 +17,8 @@ func.func @const_attribute_return_type_1() { // ----- func.func @const_attribute_return_type_2() { - // expected-error @+1 {{'emitc.constant' op requires attribute's type ('!emitc.opaque<"char">') to match op's return type ('!emitc.opaque<"mychar">')}} - %c0 = "emitc.constant"(){value = "CHAR_MIN" : !emitc.opaque<"char">} : () -> !emitc.opaque<"mychar"> + // expected-error @+1 {{'emitc.constant' op attribute 'value' failed to satisfy constraint: An opaque attribute or TypedAttr instance}} + %c0 = "emitc.constant"(){value = unit} : () -> i32 return } @@ -18,7 +26,7 @@ func.func @const_attribute_return_type_2() { func.func @empty_constant() { // expected-error @+1 {{'emitc.constant' op value must not be empty}} - %c0 = "emitc.constant"(){value = ""} : () -> i32 + %c0 = "emitc.constant"(){value = #emitc.opaque<"">} : () -> i32 return } @@ -98,7 +106,7 @@ func.func @illegal_operand() { // ----- func.func @var_attribute_return_type_1() { - // expected-error @+1 {{'emitc.variable' op requires attribute's type ('i64') to match op's return type ('i32')}} + // expected-error @+1 {{'emitc.variable' op requires attribute to either be an #emitc.opaque attribute or it's type ('i64') to match the op's result type ('i32')}} %c0 = "emitc.variable"(){value = 42: i64} : () -> i32 return } @@ -106,8 +114,8 @@ func.func @var_attribute_return_type_1() { // ----- func.func @var_attribute_return_type_2() { - // expected-error @+1 {{'emitc.variable' op requires attribute's type ('!emitc.ptr') to match op's return type ('!emitc.ptr')}} - %c0 = "emitc.variable"(){value = "nullptr" : !emitc.ptr} : () -> !emitc.ptr + // expected-error @+1 {{'emitc.variable' op attribute 'value' failed to satisfy constraint: An opaque attribute or TypedAttr instance}} + %c0 = "emitc.variable"(){value = unit} : () -> i32 return } @@ -203,7 +211,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr, %arg1: !emitc.ptr) { // ----- func.func @test_misplaced_yield() { - // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}} + // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.expression, emitc.if, emitc.for'}} emitc.yield return } @@ -232,3 +240,126 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: %0 = emitc.subscript %arg0[%arg2] : <4x8xf32> return } + +// ----- + +func.func @test_expression_no_yield() -> i32 { + // expected-error @+1 {{'emitc.expression' op must yield a value at termination}} + %r = emitc.expression : i32 { + %c7 = "emitc.constant"(){value = 7 : i32} : () -> i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_illegal_op(%arg0 : i1) -> i32 { + // expected-error @+1 {{'emitc.expression' op contains an unsupported operation}} + %r = emitc.expression : i32 { + %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 + emitc.yield %x : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.add %a, %arg0 : (i32, i32) -> i32 + %c = emitc.mul %arg1, %a : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_results(%arg0: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one result for each operation}} + %r = emitc.expression : i32 { + %a:2 = emitc.call_opaque "bar" (%arg0) : (i32) -> (i32, i32) + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} +emitc.func @multiple_results(%0: i32) -> (i32, i32) { + emitc.return %0 : i32 +} + +// ----- + +emitc.func @resulterror() -> i32 { +^bb42: + emitc.return // expected-error {{'emitc.return' op has 0 operands, but enclosing function (@resulterror) returns 1}} +} + +// ----- + +emitc.func @return_type_mismatch() -> i32 { + %0 = emitc.call_opaque "foo()"(): () -> f32 + emitc.return %0 : f32 // expected-error {{type of the return operand ('f32') doesn't match function result type ('i32') in function @return_type_mismatch}} +} + +// ----- + +func.func @return_inside_func.func(%0: i32) -> (i32) { + // expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}} + emitc.return %0 : i32 +} +// ----- + +// expected-error@+1 {{expected non-function type}} +emitc.func @func_variadic(...) + +// ----- + +// expected-error@+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}} +emitc.declare_func @bar + +// ----- + +// expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}} +"emitc.declare_func"() : () -> () + +// ----- + +func.func @logical_and_resulterror(%arg0: i32, %arg1: i32) { + // expected-error @+1 {{'emitc.logical_and' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_and"(%arg0, %arg1) : (i32, i32) -> i32 + return +} + +// ----- + +func.func @logical_not_resulterror(%arg0: i32) { + // expected-error @+1 {{'emitc.logical_not' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_not"(%arg0) : (i32) -> i32 + return +} + +// ----- + +func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) { + // expected-error @+1 {{'emitc.logical_or' op result #0 must be 1-bit signless integer, but got 'i32'}} + %0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32 + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index d280f12b78516..02294d13cef76 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -15,6 +15,25 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) { return } +emitc.declare_func @func + +emitc.func @func(%arg0 : i32) { + emitc.call_opaque "foo"(%arg0) : (i32) -> () + emitc.return +} + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo"(): () -> i32 + emitc.return %0 : i32 +} + +emitc.func @call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} + +emitc.func private @extern(i32) attributes {specifiers = ["extern"]} + func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 return @@ -42,6 +61,16 @@ func.func @add_pointer(%arg0: !emitc.ptr, %arg1: i32, %arg2: !emitc.opaque< return } +func.func @bitwise(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + %2 = emitc.bitwise_not %arg0 : (i32) -> i32 + %3 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + %4 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + %5 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + return +} + func.func @div_int(%arg0: i32, %arg1: i32) { %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32 return @@ -98,6 +127,19 @@ func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emit return } +func.func @logical(%arg0: i32, %arg1: i32) { + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + %1 = emitc.logical_not %arg0 : i32 + %2 = emitc.logical_or %arg0, %arg1 : i32, i32 + return +} + +func.func @unary(%arg0: i32) { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + return +} + func.func @test_if(%arg0: i1, %arg1: f32) { emitc.if %arg0 { %0 = emitc.call_opaque "func_const"(%arg1) : (f32) -> i32 @@ -128,6 +170,23 @@ func.func @test_assign(%arg1: f32) { return } +func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> i32 { + %c7 = "emitc.constant"() {value = 7 : i32} : () -> i32 + %q = emitc.expression : i32 { + %a = emitc.rem %arg1, %c7 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %r = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32) + %c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32 + %d = emitc.cast %c : f32 to i32 + %e = emitc.sub %b, %d : (i32, i32) -> i32 + emitc.yield %e : i32 + } + return %r : i32 +} + func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { emitc.for %i0 = %arg0 to %arg1 step %arg2 { %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32 @@ -157,3 +216,14 @@ func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5x emitc.assign %0 : f32 to %1 : f32 return } + +emitc.verbatim "#ifdef __cplusplus" +emitc.verbatim "extern \"C\" {" +emitc.verbatim "#endif // __cplusplus" + +emitc.verbatim "#ifdef __cplusplus" +emitc.verbatim "} // extern \"C\"" +emitc.verbatim "#endif // __cplusplus" + +emitc.verbatim "typedef int32_t i32;" +emitc.verbatim "typedef float f32;" diff --git a/mlir/test/Dialect/EmitC/transforms.mlir b/mlir/test/Dialect/EmitC/transforms.mlir new file mode 100644 index 0000000000000..ad167fa455a1a --- /dev/null +++ b/mlir/test/Dialect/EmitC/transforms.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s --form-expressions --verify-diagnostics --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @single_expression( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 + %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.cmp lt, %b, %arg3 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @multiple_expressions( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.sub %[[VAL_5]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_8:.*]] = emitc.add %[[VAL_1]], %[[VAL_3]] : (i32, i32) -> i32 +// CHECK: %[[VAL_9:.*]] = emitc.div %[[VAL_8]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return %[[VAL_4]], %[[VAL_7]] : i32, i32 +// CHECK: } + +func.func @multiple_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> (i32, i32) { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.add %arg1, %arg3 : (i32, i32) -> i32 + %d = emitc.div %c, %arg2 : (i32, i32) -> i32 + return %b, %d : i32, i32 +} + +// CHECK-LABEL: func.func @expression_with_call( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.call_opaque "foo"(%[[VAL_5]], %[[VAL_2]]) : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_7]] : i1 +// CHECK: } + +func.func @expression_with_call(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "foo" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_dereference( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.apply "*"(%[[VAL_2]]) : (!emitc.ptr) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.cmp lt, %[[VAL_6]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_7]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_dereference(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.apply "*"(%arg2) : (!emitc.ptr) -> (i32) + %c = emitc.cmp lt, %a, %b :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_address_taken( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.rem %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.apply "&"(%[[VAL_3]]) : (i32) -> !emitc.ptr +// CHECK: %[[VAL_7:.*]] = emitc.add %[[VAL_6]], %[[VAL_1]] : (!emitc.ptr, i32) -> !emitc.ptr +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_2]] : (!emitc.ptr, !emitc.ptr) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> (i32) + %b = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %c = emitc.add %b, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %d = emitc.cmp lt, %c, %arg2 :(!emitc.ptr, !emitc.ptr) -> i1 + return %d : i1 +} diff --git a/mlir/test/Target/Cpp/bitwise_operators.mlir b/mlir/test/Target/Cpp/bitwise_operators.mlir new file mode 100644 index 0000000000000..e666359fc82c9 --- /dev/null +++ b/mlir/test/Target/Cpp/bitwise_operators.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @bitwise(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.bitwise_and %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.bitwise_left_shift %arg0, %arg1 : (i32, i32) -> i32 + %2 = emitc.bitwise_not %arg0 : (i32) -> i32 + %3 = emitc.bitwise_or %arg0, %arg1 : (i32, i32) -> i32 + %4 = emitc.bitwise_right_shift %arg0, %arg1 : (i32, i32) -> i32 + %5 = emitc.bitwise_xor %arg0, %arg1 : (i32, i32) -> i32 + + return +} + +// CHECK-LABEL: void bitwise +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = [[V0:[^ ]*]] & [[V1:[^ ]*]]; +// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0]] << [[V1]]; +// CHECK-NEXT: int32_t [[V4:[^ ]*]] = ~[[V0]]; +// CHECK-NEXT: int32_t [[V5:[^ ]*]] = [[V0]] | [[V1]]; +// CHECK-NEXT: int32_t [[V6:[^ ]*]] = [[V0]] >> [[V1]]; +// CHECK-NEXT: int32_t [[V7:[^ ]*]] = [[V0]] ^ [[V1]]; diff --git a/mlir/test/Target/Cpp/call.mlir b/mlir/test/Target/Cpp/call.mlir index 2bcdc87205184..e3ac392f30b62 100644 --- a/mlir/test/Target/Cpp/call.mlir +++ b/mlir/test/Target/Cpp/call.mlir @@ -18,7 +18,7 @@ func.func @emitc_call_opaque() { func.func @emitc_call_opaque_two_results() { - %0 = arith.constant 0 : index + %0 = "emitc.constant"() <{value = 0 : index}> : () -> index %1:2 = emitc.call_opaque "two_results" () : () -> (i32, i32) return } diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir index e6c94732e9f6b..524d564b3b943 100644 --- a/mlir/test/Target/Cpp/const.mlir +++ b/mlir/test/Target/Cpp/const.mlir @@ -2,21 +2,31 @@ // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP func.func @emitc_constant() { - %c0 = "emitc.constant"(){value = #emitc.opaque<"">} : () -> i32 + %c0 = "emitc.constant"(){value = #emitc.opaque<"INT_MAX">} : () -> i32 %c1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %c2 = "emitc.constant"(){value = -1 : i32} : () -> i32 %c3 = "emitc.constant"(){value = -1 : si8} : () -> si8 %c4 = "emitc.constant"(){value = 255 : ui8} : () -> ui8 %c5 = "emitc.constant"(){value = #emitc.opaque<"CHAR_MIN">} : () -> !emitc.opaque<"char"> + %c6 = "emitc.constant"(){value = 2 : index} : () -> index + %c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32 + %c8 = "emitc.constant"(){value = dense<0> : tensor} : () -> tensor + %c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex> + %c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> return } // CPP-DEFAULT: void emitc_constant() { -// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = INT_MAX; // CPP-DEFAULT-NEXT: int32_t [[V1:[^ ]*]] = 42; // CPP-DEFAULT-NEXT: int32_t [[V2:[^ ]*]] = -1; // CPP-DEFAULT-NEXT: int8_t [[V3:[^ ]*]] = -1; // CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255; // CPP-DEFAULT-NEXT: char [[V5:[^ ]*]] = CHAR_MIN; +// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2; +// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = (float)2.000000000e+00; +// CPP-DEFAULT-NEXT: Tensor [[V8:[^ ]*]] = {0}; +// CPP-DEFAULT-NEXT: Tensor [[V9:[^ ]*]] = {0, 1}; +// CPP-DEFAULT-NEXT: Tensor [[V10:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; // CPP-DECLTOP: void emitc_constant() { // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; @@ -25,9 +35,19 @@ func.func @emitc_constant() { // CPP-DECLTOP-NEXT: int8_t [[V3:[^ ]*]]; // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]]; // CPP-DECLTOP-NEXT: char [[V5:[^ ]*]]; -// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]]; +// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V8:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V9:[^ ]*]]; +// CPP-DECLTOP-NEXT: Tensor [[V10:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX; // CPP-DECLTOP-NEXT: [[V1]] = 42; // CPP-DECLTOP-NEXT: [[V2]] = -1; // CPP-DECLTOP-NEXT: [[V3]] = -1; // CPP-DECLTOP-NEXT: [[V4]] = 255; // CPP-DECLTOP-NEXT: [[V5]] = CHAR_MIN; +// CPP-DECLTOP-NEXT: [[V6]] = 2; +// CPP-DECLTOP-NEXT: [[V7]] = (float)2.000000000e+00; +// CPP-DECLTOP-NEXT: [[V8]] = {0}; +// CPP-DECLTOP-NEXT: [[V9]] = {0, 1}; +// CPP-DECLTOP-NEXT: [[V10]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir new file mode 100644 index 0000000000000..72c087a3388e2 --- /dev/null +++ b/mlir/test/Target/Cpp/declare_func.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]); +emitc.declare_func @bar +// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) { +emitc.func @bar(%arg0: i32) -> i32 { + emitc.return %arg0 : i32 +} + + +// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]); +emitc.declare_func @foo +// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) { +emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} { + emitc.return %arg0 : i32 +} diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir new file mode 100644 index 0000000000000..9ec9dcc3c6a84 --- /dev/null +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -0,0 +1,212 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + +// CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %p0 = emitc.literal "M_PI" : i32 + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %p0 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DEFAULT-NEXT: return [[VAL_4]]; +// CPP-DEFAULT-NEXT:} + +// CPP-DECLTOP: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DECLTOP-NEXT: return [[VAL_4]]; +// CPP-DECLTOP-NEXT:} + +func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 { + %e = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + return %e : i32 +} + +// CPP-DEFAULT: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DECLTOP-NEXT: } + +func.func @paranthesis_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 { + %e = emitc.expression : f32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + %d = emitc.cast %b : i32 to f32 + emitc.yield %d : f32 + } + return %e : f32 +} + +// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + %q = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i1 + emitc.assign %e : i1 to %q : i1 + return %v : i32 +} + +// CPP-DEFAULT: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_7]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_7]]; +// CPP-DECLTOP-NEXT: } + +func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e1 = emitc.expression : i32 { + %a = emitc.rem %arg2, %arg3 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %e2 = emitc.expression : i32 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%e1, %a) : (i32, i32) -> (i32) + emitc.yield %b : i32 + } + %e3 = emitc.expression : i1 { + %c = emitc.sub %e2, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e3 { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.expression : i32 { + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + %c = emitc.expression : i1 { + %d = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %e = emitc.sub %d, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr, !emitc.ptr) -> i1 + emitc.yield %f : i1 + } + return %c : i1 +} diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index 90504b1347bb4..5225f3ddaff25 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -2,31 +2,43 @@ // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { - emitc.for %i0 = %arg0 to %arg1 step %arg2 { + %lb = emitc.expression : index { + %a = emitc.add %arg0, %arg1 : (index, index) -> index + emitc.yield %a : index + } + %ub = emitc.expression : index { + %a = emitc.mul %arg1, %arg2 : (index, index) -> index + emitc.yield %a : index + } + %step = emitc.expression : index { + %a = emitc.div %arg0, %arg2 : (index, index) -> index + emitc.yield %a : index + } + emitc.for %i0 = %lb to %ub step %step { %0 = emitc.call_opaque "f"() : () -> i32 } return } -// CPP-DEFAULT: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { -// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { +// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f(); // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; -// CPP-DECLTOP: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { +// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DECLTOP-NEXT: [[V4]] = f(); // CPP-DECLTOP-NEXT: } // CPP-DECLTOP-NEXT: return; func.func @test_for_yield() { - %start = arith.constant 0 : index - %stop = arith.constant 10 : index - %step = arith.constant 1 : index + %start = "emitc.constant"() <{value = 0 : index}> : () -> index + %stop = "emitc.constant"() <{value = 10 : index}> : () -> index + %step = "emitc.constant"() <{value = 1 : index}> : () -> index - %s0 = arith.constant 0 : i32 - %p0 = arith.constant 1.0 : f32 + %s0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 + %p0 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32 %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir new file mode 100644 index 0000000000000..a639cae6f623c --- /dev/null +++ b/mlir/test/Target/Cpp/func.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + + +emitc.func @emitc_func(%arg0 : i32) { + emitc.call_opaque "foo" (%arg0) : (i32) -> () + emitc.return +} +// CPP-DEFAULT: void emitc_func(int32_t [[V0:[^ ]*]]) { +// CPP-DEFAULT-NEXT: foo([[V0:[^ ]*]]); +// CPP-DEFAULT-NEXT: return; + + +emitc.func @return_i32() -> i32 attributes {specifiers = ["static","inline"]} { + %0 = emitc.call_opaque "foo" (): () -> i32 + emitc.return %0 : i32 +} +// CPP-DEFAULT: static inline int32_t return_i32() { +// CPP-DEFAULT-NEXT: [[V0:[^ ]*]] = foo(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: static inline int32_t return_i32() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:]] = foo(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; + + +emitc.func @emitc_call() -> i32 { + %0 = emitc.call @return_i32() : () -> (i32) + emitc.return %0 : i32 +} +// CPP-DEFAULT: int32_t emitc_call() { +// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = return_i32(); +// CPP-DEFAULT-NEXT: return [[V0:[^ ]*]]; + +// CPP-DECLTOP: int32_t emitc_call() { +// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; +// CPP-DECLTOP-NEXT: [[V0:[^ ]*]] = return_i32(); +// CPP-DECLTOP-NEXT: return [[V0:[^ ]*]]; + +emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void extern_func(int32_t); diff --git a/mlir/test/Target/Cpp/if.mlir b/mlir/test/Target/Cpp/if.mlir index 743f8ad396882..7b0e2da85d0eb 100644 --- a/mlir/test/Target/Cpp/if.mlir +++ b/mlir/test/Target/Cpp/if.mlir @@ -49,7 +49,7 @@ func.func @test_if_else(%arg0: i1, %arg1: f32) { func.func @test_if_yield(%arg0: i1, %arg1: f32) { - %0 = arith.constant 0 : i8 + %0 = "emitc.constant"() <{value = 0 : i8}> : () -> i8 %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 %y = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64 emitc.if %arg0 { diff --git a/mlir/test/Target/Cpp/logical_operators.mlir b/mlir/test/Target/Cpp/logical_operators.mlir new file mode 100644 index 0000000000000..7083dc218fca9 --- /dev/null +++ b/mlir/test/Target/Cpp/logical_operators.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @logical(%arg0: i32, %arg1: i32) -> () { + %0 = emitc.logical_and %arg0, %arg1 : i32, i32 + %1 = emitc.logical_not %arg0 : i32 + %2 = emitc.logical_or %arg0, %arg1 : i32, i32 + + return +} + +// CHECK-LABEL: void logical +// CHECK-NEXT: bool [[V2:[^ ]*]] = [[V0:[^ ]*]] && [[V1:[^ ]*]]; +// CHECK-NEXT: bool [[V3:[^ ]*]] = ![[V0]]; +// CHECK-NEXT: bool [[V4:[^ ]*]] = [[V0]] || [[V1]]; diff --git a/mlir/test/Target/Cpp/stdops.mlir b/mlir/test/Target/Cpp/stdops.mlir index 0723188a62c68..cc6bdbe376984 100644 --- a/mlir/test/Target/Cpp/stdops.mlir +++ b/mlir/test/Target/Cpp/stdops.mlir @@ -1,37 +1,6 @@ // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP -func.func @std_constant() { - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 2 : index - %c2 = arith.constant 2.0 : f32 - %c3 = arith.constant dense<0> : tensor - %c4 = arith.constant dense<[0, 1]> : tensor<2xindex> - %c5 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> - return -} -// CPP-DEFAULT: void std_constant() { -// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 0; -// CPP-DEFAULT-NEXT: size_t [[V1:[^ ]*]] = 2; -// CPP-DEFAULT-NEXT: float [[V2:[^ ]*]] = (float)2.000000000e+00; -// CPP-DEFAULT-NEXT: Tensor [[V3:[^ ]*]] = {0}; -// CPP-DEFAULT-NEXT: Tensor [[V4:[^ ]*]] = {0, 1}; -// CPP-DEFAULT-NEXT: Tensor [[V5:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; - -// CPP-DECLTOP: void std_constant() { -// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; -// CPP-DECLTOP-NEXT: size_t [[V1:[^ ]*]]; -// CPP-DECLTOP-NEXT: float [[V2:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V3:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: Tensor [[V5:[^ ]*]]; -// CPP-DECLTOP-NEXT: [[V0]] = 0; -// CPP-DECLTOP-NEXT: [[V1]] = 2; -// CPP-DECLTOP-NEXT: [[V2]] = (float)2.000000000e+00; -// CPP-DECLTOP-NEXT: [[V3]] = {0}; -// CPP-DECLTOP-NEXT: [[V4]] = {0, 1}; -// CPP-DECLTOP-NEXT: [[V5]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00}; - func.func @std_call() { %0 = call @one_result () : () -> i32 %1 = call @one_result () : () -> i32 @@ -49,13 +18,11 @@ func.func @std_call() { func.func @std_call_two_results() { - %c = arith.constant 0 : i8 %0:2 = call @two_results () : () -> (i32, f32) %1:2 = call @two_results () : () -> (i32, f32) return } // CPP-DEFAULT: void std_call_two_results() { -// CPP-DEFAULT-NEXT: int8_t [[V0:[^ ]*]] = 0; // CPP-DEFAULT-NEXT: int32_t [[V1:[^ ]*]]; // CPP-DEFAULT-NEXT: float [[V2:[^ ]*]]; // CPP-DEFAULT-NEXT: std::tie([[V1]], [[V2]]) = two_results(); @@ -64,18 +31,16 @@ func.func @std_call_two_results() { // CPP-DEFAULT-NEXT: std::tie([[V3]], [[V4]]) = two_results(); // CPP-DECLTOP: void std_call_two_results() { -// CPP-DECLTOP-NEXT: int8_t [[V0:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t [[V1:[^ ]*]]; // CPP-DECLTOP-NEXT: float [[V2:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t [[V3:[^ ]*]]; // CPP-DECLTOP-NEXT: float [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: [[V0]] = 0; // CPP-DECLTOP-NEXT: std::tie([[V1]], [[V2]]) = two_results(); // CPP-DECLTOP-NEXT: std::tie([[V3]], [[V4]]) = two_results(); func.func @one_result() -> i32 { - %0 = arith.constant 0 : i32 + %0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 return %0 : i32 } // CPP-DEFAULT: int32_t one_result() { @@ -89,8 +54,8 @@ func.func @one_result() -> i32 { func.func @two_results() -> (i32, f32) { - %0 = arith.constant 0 : i32 - %1 = arith.constant 1.0 : f32 + %0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32 + %1 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32 return %0, %1 : i32, f32 } // CPP-DEFAULT: std::tuple two_results() { diff --git a/mlir/test/Target/Cpp/unary_operators.mlir b/mlir/test/Target/Cpp/unary_operators.mlir new file mode 100644 index 0000000000000..8a89437a41cc5 --- /dev/null +++ b/mlir/test/Target/Cpp/unary_operators.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @unary(%arg0: i32) -> () { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + + return +} + +// CHECK-LABEL: void unary +// CHECK-NEXT: int32_t [[V1:[^ ]*]] = -[[V0:[^ ]*]]; +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = +[[V0]]; diff --git a/mlir/test/Target/Cpp/verbatim.mlir b/mlir/test/Target/Cpp/verbatim.mlir new file mode 100644 index 0000000000000..10465dd781a81 --- /dev/null +++ b/mlir/test/Target/Cpp/verbatim.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s + + +emitc.verbatim "#ifdef __cplusplus" +// CHECK: #ifdef __cplusplus +emitc.verbatim "extern \"C\" {" +// CHECK-NEXT: extern "C" { +emitc.verbatim "#endif // __cplusplus" +// CHECK-NEXT: #endif // __cplusplus +emitc.verbatim "#ifdef __cplusplus" +// CHECK-NEXT: #ifdef __cplusplus +emitc.verbatim "} // extern \"C\"" +// CHECK-NEXT: } // extern "C" +emitc.verbatim "#endif // __cplusplus" +// CHECK-NEXT: #endif // __cplusplus + +emitc.verbatim "typedef int32_t i32;" +// CHECK-NEXT: typedef int32_t i32; +emitc.verbatim "typedef float f32;" +// CHECK-NEXT: typedef float f32; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index cdb35d87992ed..f2d804477f1b3 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1560,8 +1560,10 @@ td_library( includes = ["include"], deps = [ ":BuiltinDialectTdFiles", + ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":FunctionInterfacesTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -1633,7 +1635,6 @@ cc_library( ]), hdrs = glob(["include/mlir/Target/Cpp/*.h"]), deps = [ - ":ArithDialect", ":ControlFlowDialect", ":EmitCDialect", ":FuncDialect", @@ -3633,10 +3634,12 @@ cc_library( ]), includes = ["include"], deps = [ + ":CallOpInterfaces", ":CastInterfaces", ":ControlFlowInterfaces", ":EmitCAttributesIncGen", ":EmitCOpsIncGen", + ":FunctionInterfaces", ":IR", ":SideEffectInterfaces", "//llvm:Support", @@ -3837,6 +3840,7 @@ cc_library( ":AMDGPUToROCDL", ":AffineToStandard", ":ArithToAMDGPU", + ":ArithToEmitC", ":ArithToLLVM", ":ArithToSPIRV", ":ArmNeon2dToIntr", @@ -3853,6 +3857,7 @@ cc_library( ":ControlFlowToSPIRV", ":ConversionPassIncGen", ":ConvertToLLVM", + ":FuncToEmitC", ":FuncToLLVM", ":FuncToSPIRV", ":GPUToGPURuntimeTransforms", @@ -6751,6 +6756,32 @@ cc_library( ], ) +cc_library( + name = "FuncToEmitC", + srcs = glob([ + "lib/Conversion/FuncToEmitC*.cpp", + "lib/Conversion/FuncToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/FuncToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/FuncToEmitC", + ], + deps = [ + ":ConversionPassIncGen", + ":FuncDialect", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "FuncToSPIRV", srcs = glob([ @@ -7936,6 +7967,32 @@ cc_library( ], ) +cc_library( + name = "ArithToEmitC", + srcs = glob([ + "lib/Conversion/ArithToEmitC/*.cpp", + "lib/Conversion/ArithToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/ArithToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/ArithToEmitC", + ], + deps = [ + ":ArithDialect", + ":ConversionPassIncGen", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "ArithToLLVM", srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),