diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index e664a1749216d1..485f0801226af4 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { def EmitC_CastOp : EmitC_Op<"cast", [CExpression, - DeclareOpInterfaceMethods, - SameOperandsAndResultShape]> { + DeclareOpInterfaceMethods]> { let summary = "Cast operation"; let description = [{ The `cast` operation performs an explicit type conversion and is emitted diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index 29c64487a2bc09..8389920f587d1c 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -37,16 +37,17 @@ class CallOpConversion final : public OpConversionPattern { callOp, "only functions with zero or one result can be converted"); // Convert the original function results. - Type resultTy = nullptr; + SmallVector types; if (callOp.getNumResults()) { - resultTy = typeConverter->convertType(callOp.getResult(0).getType()); + auto resultTy = typeConverter->convertType(callOp.getResult(0).getType()); if (!resultTy) return rewriter.notifyMatchFailure( callOp, "function return type conversion failed"); + types.push_back(resultTy); } rewriter.replaceOpWithNewOp( - callOp, resultTy, adaptor.getOperands(), callOp->getAttrs()); + callOp, types, adaptor.getOperands(), callOp->getAttrs()); return success(); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e0c421741b3055..3e519813b04826 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -166,6 +166,51 @@ struct ConvertStore final : public OpConversionPattern { return success(); } }; + +struct ConvertCollapseShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + rewriter.replaceOpWithNewOp(op, resultTy, operands.getSrc()); + return success(); + } +}; + +struct ConvertExpandShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + rewriter.replaceOpWithNewOp(op, resultTy, operands.getSrc()); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -186,6 +231,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c959f3b54fb9b3..322be2fe0c179e 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -236,13 +236,12 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); - return ( - (llvm::isa( - input)) && - (llvm::isa( - output))); + return ((llvm::isa(input)) && + (llvm::isa(output))); } OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index e6efec14e31a60..960c77b9c74051 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2327,11 +2327,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) { auto insertOp = extractOp.getSource().getDefiningOp(); - auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; - if (insertOp && insertOp.getSource().getType() == extractOp.getType() && - insertOp.isSameAs(extractOp, isSame)) - return insertOp.getSource(); - + while (insertOp) { + auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; + if (insertOp.getSource().getType() == extractOp.getType() && + insertOp.isSameAs(extractOp, isSame)) + return insertOp.getSource(); + // TODO: Need to stop at the first insert_slice that has some overlap with + // the extracted range to avoid returning an early insert_slice that was + // (partially) overwritten by later ones. + insertOp = insertOp.getDest().getDefiningOp(); + } return {}; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 39eb28300c0ca3..ccd7e69acf6ee7 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -136,11 +136,11 @@ struct CppEmitter { LogicalResult emitTupleType(Location loc, ArrayRef types); /// Emits an assignment for a variable which has been declared previously. - LogicalResult emitVariableAssignment(OpResult result); + LogicalResult emitVariableAssignment(OpResult result, StringRef prefix = "v"); /// Emits a variable declaration for a result of an operation. - LogicalResult emitVariableDeclaration(OpResult result, - bool trailingSemicolon); + LogicalResult emitVariableDeclaration(OpResult result, bool trailingSemicolon, + StringRef prefix = "v"); /// Emits a declaration of a variable with the given type and name. LogicalResult emitVariableDeclaration(Location loc, Type type, @@ -152,7 +152,7 @@ struct CppEmitter { /// - emits nothing if no value produced by op; /// Emits final '=' operator where a type is produced. Returns failure if /// any result type could not be converted. - LogicalResult emitAssignPrefix(Operation &op); + LogicalResult emitAssignPrefix(Operation &op, StringRef prefix = "v"); /// Emits a global variable declaration or definition. LogicalResult emitGlobalVariable(GlobalOp op); @@ -175,7 +175,7 @@ struct CppEmitter { LogicalResult emitExpression(ExpressionOp expressionOp); /// Return the existing or a new name for a Value. - StringRef getOrCreateName(Value val); + StringRef getOrCreateName(Value val, StringRef prefix = "v"); // Returns the textual representation of a subscript operation. std::string getSubscriptName(emitc::SubscriptOp op); @@ -303,6 +303,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); + std::string prefix = "v"; + if (auto c = dyn_cast(operation)) { + Attribute val = c.getValue(); + if (auto ia = dyn_cast(val)) { + if (ia.getInt() > 0) + prefix = "c_" + std::to_string(ia.getInt()) + "_"; + else + prefix = "c_n" + std::to_string(-ia.getInt()) + "_"; + } + } + // Only emit an assignment as the variable was already declared when printing // the FuncOp. if (emitter.shouldDeclareVariablesAtTop()) { @@ -312,7 +323,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, return success(); } - if (failed(emitter.emitVariableAssignment(result))) + if (failed(emitter.emitVariableAssignment(result, prefix))) return failure(); return emitter.emitAttribute(operation->getLoc(), value); } @@ -326,7 +337,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, } // Emit a variable declaration. - if (failed(emitter.emitAssignPrefix(*operation))) + if (failed(emitter.emitAssignPrefix(*operation, prefix))) return failure(); return emitter.emitAttribute(operation->getLoc(), value); } @@ -716,6 +727,21 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { raw_ostream &os = emitter.ostream(); Operation &op = *castOp.getOperation(); + if (auto arrType = dyn_cast(castOp.getType())) { + std::string shapeStr = ""; + for (auto i : arrType.getShape()) { + shapeStr += "["; + shapeStr += std::to_string(i); + shapeStr += "]"; + } + os << "float (&" << emitter.getOrCreateName(castOp.getResult()) << ")" + << shapeStr << " = *reinterpret_cast("; + if (failed(emitter.emitOperand(castOp.getOperand()))) + return failure(); + os << ")"; + return success(); + } + if (failed(emitter.emitAssignPrefix(op))) return failure(); os << "("; @@ -1128,7 +1154,7 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { } /// Return the existing or a new name for a Value. -StringRef CppEmitter::getOrCreateName(Value val) { +StringRef CppEmitter::getOrCreateName(Value val, StringRef prefix) { if (auto literal = dyn_cast_if_present(val.getDefiningOp())) return literal.getValue(); if (!valueMapper.count(val)) { @@ -1139,7 +1165,8 @@ StringRef CppEmitter::getOrCreateName(Value val) { val.getDefiningOp())) { valueMapper.insert(val, getGlobal.getName().str()); } else { - valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + valueMapper.insert(val, + formatv("{0}{1}", prefix, ++valueInScopeCount.top())); } } return *valueMapper.begin(val); @@ -1377,17 +1404,19 @@ CppEmitter::emitOperandsAndAttributes(Operation &op, return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); } -LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { +LogicalResult CppEmitter::emitVariableAssignment(OpResult result, + StringRef prefix) { if (!hasValueInScope(result)) { return result.getDefiningOp()->emitOpError( "result variable for the operation has not been declared"); } - os << getOrCreateName(result) << " = "; + os << getOrCreateName(result, prefix) << " = "; return success(); } LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, - bool trailingSemicolon) { + bool trailingSemicolon, + StringRef prefix) { if (isa(result.getDefiningOp())) return success(); if (hasValueInScope(result)) { @@ -1396,7 +1425,7 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, } if (failed(emitVariableDeclaration(result.getOwner()->getLoc(), result.getType(), - getOrCreateName(result)))) + getOrCreateName(result, prefix)))) return failure(); if (trailingSemicolon) os << ";\n"; @@ -1427,7 +1456,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { return success(); } -LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { +LogicalResult CppEmitter::emitAssignPrefix(Operation &op, StringRef prefix) { // If op is being emitted as part of an expression, bail out. if (getEmittedExpression()) return success(); @@ -1438,10 +1467,11 @@ LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { case 1: { OpResult result = op.getResult(0); if (shouldDeclareVariablesAtTop()) { - if (failed(emitVariableAssignment(result))) + if (failed(emitVariableAssignment(result, prefix))) return failure(); } else { - if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) + if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false, + prefix))) return failure(); os << " = "; } @@ -1450,7 +1480,8 @@ LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { default: if (!shouldDeclareVariablesAtTop()) { for (OpResult result : op.getResults()) { - if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) + if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true, + prefix))) return failure(); } } @@ -1512,7 +1543,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { shouldBeInlined(cast(op)))) return success(); - os << (trailingSemicolon ? ";\n" : "\n"); + os << (trailingSemicolon ? ";" : ""); + + if (!isa(op.getLoc())) { + os << " // "; + op.getLoc().print(os); + } + os << "\n"; return success(); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 563d035f155267..fb7f13e87939aa 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -124,6 +124,14 @@ declare_mlir_dialect_python_bindings( dialects/func.py DIALECT_NAME func) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/EmitCOps.td + SOURCES + dialects/emitc.py + DIALECT_NAME emitc) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/EmitCOps.td b/mlir/python/mlir/dialects/EmitCOps.td new file mode 100644 index 00000000000000..e27d589920eda1 --- /dev/null +++ b/mlir/python/mlir/dialects/EmitCOps.td @@ -0,0 +1,19 @@ +//===-- EmitCOps.td - Entry point for Func bind -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the main file from which the Python bindings for the Func dialect +// are generated. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_FUNC +#define PYTHON_BINDINGS_FUNC + +include "mlir/Dialect/EmitC/IR/EmitC.td" + +#endif diff --git a/mlir/python/mlir/dialects/emitc.py b/mlir/python/mlir/dialects/emitc.py new file mode 100644 index 00000000000000..9bc0b48452d91c --- /dev/null +++ b/mlir/python/mlir/dialects/emitc.py @@ -0,0 +1,116 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._emitc_ops_gen import * +from ._emitc_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + _cext as _ods_cext, + ) + + from typing import Optional, Sequence, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute