diff --git a/MLIR.cmake b/MLIR.cmake index 037d495b8967..557c64a9457b 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -70,10 +70,16 @@ function(add_onnx_mlir_dialect_doc dialect dialect_tablegen_file) endfunction() add_custom_target(onnx-mlir-docs) -function(add_onnx_mlir_dialect dialect) +# If an extra parameter, the dialect name, is provided, +# this function will generate dialect and type from the td file +function(add_onnx_mlir_dialect dialect dialect_name) set(LLVM_TARGET_DEFINITIONS ${dialect}.td) - mlir_tablegen(${dialect}.hpp.inc -gen-op-decls "-I${ONNX_MLIR_SRC_ROOT}") - mlir_tablegen(${dialect}.cpp.inc -gen-op-defs "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Ops.hpp.inc -gen-op-decls "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Ops.cpp.inc -gen-op-defs "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Dialect.hpp.inc -gen-dialect-decls -dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Dialect.cpp.inc -gen-dialect-defs -dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Types.hpp.inc -gen-typedef-decls -typedefs-dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}") + mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs -typedefs-dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(OM${dialect}IncGen) endfunction() diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 747171d53cbb..1c02063583c9 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -271,7 +271,7 @@ mlir::Type convertONNXTypeToMLIRType( case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return builder_.getI1Type(); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: - return mlir::onnxmlir::StringType::get(builder_.getContext()); + return mlir::ONNXStringType::get(builder_.getContext()); case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index eecd299f4dac..38ce7b512a4e 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -267,7 +267,9 @@ class FrontendGenImpl { assert(elem_type.value_case() == onnx::TypeProto::kTensorType && "expect tensor inside sequence type"); Type mlir_elem_type = ImportTensorType(elem_type); - Type seq_type = mlir::onnxmlir::SeqType::get(mlir_elem_type); + if (!mlir_elem_type.isa()) + llvm_unreachable("Seq type is incorrect"); + Type seq_type = mlir::SeqType::get(mlir_elem_type.cast(), -1); return seq_type; } llvm_unreachable("unexpected type"); @@ -1274,7 +1276,7 @@ class FrontendGenImpl { std::string comma = std::string(""); TypeSwitch(argType) - .Case([&](mlir::onnxmlir::SeqType seqTy) { + .Case([&](mlir::SeqType seqTy) { auto et = seqTy.getElementType(); dstream << " {\"seq\" : "; concatTypeString(et, attr, dstream); diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 95e82574c811..d94de06c104c 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -610,7 +610,7 @@ void registerDialects(mlir::MLIRContext &context) { context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); - context.getOrLoadDialect(); + context.getOrLoadDialect(); context.getOrLoadDialect(); } @@ -736,9 +736,8 @@ InputIRLevelType determineInputIRLevel(mlir::OwningModuleRef &module) { }); // If there are ONNX ops, the input level is ONNX. - bool hasONNXOps = llvm::any_of(dialectNamespace, [&](StringRef ns) { - return (ns == ONNXOpsDialect::getDialectNamespace()); - }); + bool hasONNXOps = llvm::any_of(dialectNamespace, + [&](StringRef ns) { return (ns == ONNXDialect::getDialectNamespace()); }); if (hasONNXOps) return ONNXLevel; diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index e537bfe8bda7..89b84bade02e 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -110,7 +110,7 @@ bool hasAllScalarValues(ArrayRef values) { MemRefType convertToMemRefType(Type type) { // Convert the element type of the (tensor or memref) to a valid Krnl type. auto convertElemType = [](Type elemType) -> Type { - if (elemType.isa()) + if (elemType.isa()) return StringType::get(elemType.getContext()); return elemType; }; @@ -755,13 +755,13 @@ KrnlTypeConverter::KrnlTypeConverter() { // The order of type conversion is important: later ones are tried earlier. addConversion([](Type type) { return type; }); - addConversion([](onnxmlir::StringType stringType) { + addConversion([](ONNXStringType stringType) { return StringType::get(stringType.getContext()); }); addConversion([](TensorType tensorType) { assert(tensorType.hasRank() && "expected only ranked shapes"); - if (tensorType.getElementType().isa()) { + if (tensorType.getElementType().isa()) { Type elementType = StringType::get(tensorType.getContext()); return MemRefType::get(tensorType.getShape(), elementType); } diff --git a/src/Dialect/Krnl/CMakeLists.txt b/src/Dialect/Krnl/CMakeLists.txt index c657f648c963..820e2d8d3827 100644 --- a/src/Dialect/Krnl/CMakeLists.txt +++ b/src/Dialect/Krnl/CMakeLists.txt @@ -6,8 +6,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "s390x") llvm_replace_compiler_option(CMAKE_CXX_FLAGS_RELEASE "-O3" "-O1") endif() -add_onnx_mlir_dialect(KrnlOps) -add_onnx_mlir_dialect_doc(krnl KrnlOps.td) +add_onnx_mlir_dialect(Krnl krnl) +add_onnx_mlir_dialect_doc(krnl Krnl.td) add_onnx_mlir_library(OMKrnlOps KrnlOps.cpp @@ -15,7 +15,7 @@ add_onnx_mlir_library(OMKrnlOps KrnlHelper.cpp DEPENDS - OMKrnlOpsIncGen + OMKrnlIncGen LINK_LIBS PUBLIC OMONNXOps diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/Krnl.td similarity index 100% rename from src/Dialect/Krnl/KrnlOps.td rename to src/Dialect/Krnl/Krnl.td diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index f522fd9958b1..b837c23eac1a 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -6,17 +6,19 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "s390x") llvm_replace_compiler_option(CMAKE_CXX_FLAGS_RELEASE "-O3" "-O1") endif() -add_onnx_mlir_dialect(ONNXOps) -add_onnx_mlir_dialect_doc(onnx ONNXOps.td) +add_onnx_mlir_dialect(ONNX onnx) +add_onnx_mlir_dialect_doc(onnx ONNX.td) add_onnx_mlir_rewriter(Rewrite) add_onnx_mlir_library(OMONNXOps - ONNXOps.cpp - ONNXOpsHelper.cpp IndexExpr.cpp IndexExprDetail.cpp MLIRDialectBuilder.cpp + ONNXDialect.cpp + ONNXOps.cpp + ONNXOpsHelper.cpp + ONNXTypes.cpp Rewrite.cpp ShapeInference/ArgMax.cpp @@ -50,7 +52,7 @@ add_onnx_mlir_library(OMONNXOps DEPENDS OMHasOnnxSubgraphOpInterfaceIncGen - OMONNXOpsIncGen + OMONNXIncGen OMONNXRewriteIncGen OMResultTypeInferenceOpInterfaceIncGen OMShapeInferenceOpInterfaceIncGen diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNX.td similarity index 84% rename from src/Dialect/ONNX/ONNXOps.td rename to src/Dialect/ONNX/ONNX.td index e3ca5e4087cf..6c79a0e18e3f 100644 --- a/src/Dialect/ONNX/ONNXOps.td +++ b/src/Dialect/ONNX/ONNX.td @@ -24,7 +24,7 @@ include "src/Interface/HasOnnxSubgraphOpInterface.td" // Definition for rewrite rules for onnx dialect // Can be used in other table gen files (.td) for onnx dialect -def StringType : Type()">, "string type">; +def StringType : Type()">, "string type">; def IsSeqTypePred : CPred<"$_self.isa()">; @@ -204,4 +204,52 @@ def ONNXBatchNormalizationInferenceModeOp: ONNX_Op<"BatchNormalizationInferenceM }]; } +class ONNX_Type + : TypeDef { + let mnemonic = typeMnemonic; +} + +def ONNX_StringType : ONNX_Type<"ONNXString", "String"> { + let summary = " ONNX StringType"; + let description = [{ + An array of characters. + }]; +} + +def ONNX_SeqType : ONNX_Type<"Seq", "Seq"> { + let summary = " ONNX SeqType"; + let description = [{ + An list of tensors which may have different shape + }]; + let parameters = (ins "::mlir::ShapedType":$ElementType, "int64_t":$Length); + + // Previous implementation did not print/parse the length field + // May add the field in future + let printer = [{ + $_printer << "<" << getImpl()->ElementType << ">"; + }]; + let parser = [{ + if (parser.parseLess()) + return Type(); + Type elementType; + if ($_parser.parseType(elementType)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + if (!elementType.isa()) + return Type(); + else { + ShapedType ty = elementType.cast(); + return get($_ctxt, ty, -1); + } + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins "::mlir::ShapedType":$elementType, + "int64_t":$length), [{ + return Base::get(elementType.getContext(), elementType, length); + }]> + ]; +} + #endif // ONNX_OPS diff --git a/src/Dialect/ONNX/ONNXDialect.cpp b/src/Dialect/ONNX/ONNXDialect.cpp new file mode 100644 index 000000000000..ec9cbe5df646 --- /dev/null +++ b/src/Dialect/ONNX/ONNXDialect.cpp @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ ONNXDialect.cpp - ONNX Operations -----------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file provides definition of ONNX dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/DialectImplementation.h" + +#include "src/Dialect/ONNX/ONNXDialect.hpp" + +using namespace mlir; + +// Code for ONNX_Dialect class +#include "src/Dialect/ONNX/ONNXDialect.cpp.inc" diff --git a/src/Dialect/ONNX/ONNXDialect.hpp b/src/Dialect/ONNX/ONNXDialect.hpp new file mode 100644 index 000000000000..298876f7021b --- /dev/null +++ b/src/Dialect/ONNX/ONNXDialect.hpp @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- ONNXDialect.hpp - ONNX Operations -------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file defines ONNX Dialect class. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/Dialect.h" + +#include "src/Dialect/ONNX/ONNXDialect.hpp.inc" diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 07077fcdc3d2..068bfe195dac 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -35,6 +35,32 @@ using namespace mlir; using namespace mlir::OpTrait::util; +//===----------------------------------------------------------------------===// +// Tablegen Type Definitions +//===----------------------------------------------------------------------===// +// Explanation: the type implementation is used in dialect initialization. +// If ONNXTypes.cpp.inc is included in ONNXTypes.cpp, compilation error occurs. +#define GET_TYPEDEF_CLASSES +#include "src/Dialect/ONNX/ONNXTypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// ONNXDialect initialization +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +void ONNXDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "src/Dialect/ONNX/ONNXOps.cpp.inc" + >(); + + addTypes< +#define GET_TYPEDEF_LIST +#include "src/Dialect/ONNX/ONNXTypes.cpp.inc" + >(); +} + //===----------------------------------------------------------------------===// // ONNX Helper functions for shape helpers //===----------------------------------------------------------------------===// @@ -564,59 +590,6 @@ static void insertConvTransposeSpatialDim(SmallVectorImpl &outputDims, } } -//===----------------------------------------------------------------------===// -// ONNXOpsDialect -//===----------------------------------------------------------------------===// - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx) - : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { - addOperations< -#define GET_OP_LIST -#include "src/Dialect/ONNX/ONNXOps.cpp.inc" - >(); - addTypes(); -} - -mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - - MLIRContext *context = getContext(); - if (keyword == "String") - return onnxmlir::StringType::get(context); - if (keyword == "Seq") { - if (parser.parseLess()) - return Type(); - - SmallVector elementTypes; - mlir::Type elementType; - if (parser.parseType(elementType)) - return Type(); - - if (parser.parseGreater()) - return Type(); - return onnxmlir::SeqType::get(elementType); - } - - parser.emitError(parser.getNameLoc(), "unknown onnx type: " + keyword); - return Type(); -} - -void ONNXOpsDialect::printType( - mlir::Type type, mlir::DialectAsmPrinter &os) const { - TypeSwitch(type) - .Case([&](Type) { os << "String"; }) - .Case([&](onnxmlir::SeqType type) { - os << "Seq<"; - os << type.getElementType(); - os << '>'; - }) - .Default([](Type) { llvm_unreachable("Unexpected 'onnx' type kind"); }); -} - void ONNXEntryPointOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::FuncOp function, int numInputs, int numOutputs, std::string signature) { @@ -881,8 +854,7 @@ LogicalResult ONNXSeluOp::inferShapes( LogicalResult ONNXSequenceInsertOp::inferShapes( std::function doShapeInference) { - onnxmlir::SeqType seqType = - input_sequence().getType().dyn_cast(); + SeqType seqType = input_sequence().getType().dyn_cast(); ShapedType tensorType = tensor().getType().dyn_cast(); ShapedType seqTensorType = seqType.getElementType().cast(); @@ -892,7 +864,7 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( // When the input seq is empty, inherit the tensor type if (seqType.getLength() == 0) { - getResult().setType(onnxmlir::SeqType::get(tensorType, 1)); + getResult().setType(SeqType::get(tensorType, 1)); return success(); } @@ -900,11 +872,11 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( // When one of the tensor is unranked if (!tensorType.hasRank()) { - getResult().setType(onnxmlir::SeqType::get(tensorType, newLength)); + getResult().setType(SeqType::get(tensorType, newLength)); return success(); } if (!seqTensorType.hasRank()) { - getResult().setType(onnxmlir::SeqType::get(seqTensorType, newLength)); + getResult().setType(SeqType::get(seqTensorType, newLength)); return success(); } @@ -922,7 +894,7 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( for (auto i = 0; i < tensorRank; i++) { dims.emplace_back(seqShape[i] != tensorShape[i] ? -1 : tensorShape[i]); } - getResult().setType(onnxmlir::SeqType::get( + getResult().setType(SeqType::get( mlir::RankedTensorType::get(dims, tensorType.getElementType()), newLength)); @@ -935,7 +907,7 @@ static LogicalResult verify(ONNXSequenceInsertOp op) { // These cast should be guaranteed by default verifier Type seqElementType = operandAdaptor.input_sequence() .getType() - .dyn_cast() + .dyn_cast() .getElementType(); Type elementType1 = seqElementType.dyn_cast().getElementType(); ShapedType insertType = @@ -966,22 +938,22 @@ LogicalResult ONNXSequenceConstructOp::inferShapes( LogicalResult ONNXSequenceEmptyOp::inferShapes( std::function doShapeInference) { - auto originTy = getResult().getType().cast(); + auto originTy = getResult().getType().cast(); auto elementTy = originTy.getElementType(); - auto returnTy = onnxmlir::SeqType::get(elementTy, 0); + auto returnTy = SeqType::get(elementTy, 0); getResult().setType(returnTy); return success(); } LogicalResult ONNXSequenceEraseOp::inferShapes( std::function doShapeInference) { - auto inputTy = input_sequence().getType().cast(); + auto inputTy = input_sequence().getType().cast(); int64_t length = inputTy.getLength(); if (length == 0) return emitError("SequenceErase from an empty seq"); - getResult().setType(onnxmlir::SeqType::get( - inputTy.getElementType(), length == -1 ? -1 : length - 1)); + getResult().setType( + SeqType::get(inputTy.getElementType(), length == -1 ? -1 : length - 1)); return success(); } @@ -4700,7 +4672,7 @@ static LogicalResult verify(ONNXCategoryMapperOp op) { ShapedType inputType = X.getType().cast(); Type elementType = inputType.getElementType(); - if (!elementType.isInteger(64) && !elementType.isa()) + if (!elementType.isInteger(64) && !elementType.isa()) return op.emitError("input must be a tensor of int64 or string"); // Check attributes. @@ -4714,7 +4686,7 @@ static LogicalResult verify(ONNXCategoryMapperOp op) { if (elementType.isInteger(64) && !op.default_stringAttr()) return op.emitError("'default_string' attribute is missing."); - if (elementType.isa() && !op.default_int64Attr()) + if (elementType.isa() && !op.default_int64Attr()) return op.emitError("'default_int64' attribute is missing."); return success(); @@ -4728,7 +4700,7 @@ LogicalResult ONNXCategoryMapperOp::inferShapes( Type inputElementType = X().getType().cast().getElementType(); assert((inputElementType.isInteger(64) || - inputElementType.isa()) && + inputElementType.isa()) && "Input tensor must have int64 or string element type."); ONNXCategoryMapperOpAdaptor operandAdaptor(*this); @@ -4738,7 +4710,7 @@ LogicalResult ONNXCategoryMapperOp::inferShapes( Type outputElementType; if (inputElementType.isInteger(64)) - outputElementType = onnxmlir::StringType::get(getContext()); + outputElementType = ONNXStringType::get(getContext()); else outputElementType = IntegerType::get(getContext(), /*width=*/64); @@ -4971,68 +4943,12 @@ LogicalResult ONNXCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { FunctionType ONNXCallOp::getCalleeType() { return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } -//===----------------------------------------------------------------------===// -// ONNX type related code -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace onnxmlir { -namespace detail { -struct SeqTypeStorage : public mlir::TypeStorage { - // std::tuple, instead of std::pair, is used as the key for seq Type - // because the list of elements may be added later for lowering seq - using KeyTy = std::tuple; - - SeqTypeStorage(mlir::Type elementType, int64_t length) - : elementType(elementType), seqLength(length) {} - - bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, seqLength); - } - static llvm::hash_code hasKey(const KeyTy &key) { - mlir::Type eT; - int64_t l; - std::tie(eT, l) = key; - return llvm::hash_combine(eT, l); - } - - static KeyTy getKey(mlir::Type elementType, int64_t length) { - return KeyTy(elementType, length); - } - - static SeqTypeStorage *construct( - mlir::TypeStorageAllocator &allocator, const KeyTy &key) { - mlir::Type eT; - int64_t l; - std::tie(eT, l) = key; - return new (allocator.allocate()) SeqTypeStorage(eT, l); - } - mlir::Type elementType; // Type for element of Seq - int64_t seqLength; // Length of Seq. -1 when is not statically known -}; -} // end namespace detail -} // end namespace onnxmlir -} // end namespace mlir - -onnxmlir::SeqType onnxmlir::SeqType::get( - mlir::Type elementType, int64_t length) { - mlir::MLIRContext *ctx = elementType.getContext(); - return Base::get(ctx, elementType, length); -} - -mlir::Type onnxmlir::SeqType::getElementType() const { - return getImpl()->elementType; -} - -int64_t onnxmlir::SeqType::getLength() const { return getImpl()->seqLength; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES - -using namespace onnxmlir; #include "src/Dialect/ONNX/ONNXOps.cpp.inc" template struct ONNXGenericPoolShapeHelper { - -public: - using Base::Base; - using Base::getChecked; - - static StringType get(MLIRContext *ctx) { return Base::get(ctx); } -}; - -namespace detail { -struct SeqTypeStorage; -} // namespace detail - -class SeqType - : public mlir::Type::TypeBase { -public: - using Base::Base; - - static SeqType get(mlir::Type elementType, int64_t length = -1); - - mlir::Type getElementType() const; - - // Return the length of the sequence. - // 0 : if the seq is empty - // -1 if unknown at compiler time - int64_t getLength() const; -}; - -} // end namespace onnxmlir -} // end namespace mlir diff --git a/src/Dialect/ONNX/ONNXOpsHelper.cpp b/src/Dialect/ONNX/ONNXOpsHelper.cpp index 9ce4efe0360e..fd24e52e3bdf 100644 --- a/src/Dialect/ONNX/ONNXOpsHelper.cpp +++ b/src/Dialect/ONNX/ONNXOpsHelper.cpp @@ -20,7 +20,6 @@ // Identity affine using namespace mlir; -using namespace mlir::onnxmlir; //====-------------------------- ONNX Builder ---------------------------===// diff --git a/src/Dialect/ONNX/ONNXTypes.cpp b/src/Dialect/ONNX/ONNXTypes.cpp new file mode 100644 index 000000000000..05c31baa420d --- /dev/null +++ b/src/Dialect/ONNX/ONNXTypes.cpp @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ ONNXTypes.cpp - ONNX Types ------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file provides definition of utility functions for ONNX Types. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "src/Dialect/ONNX/ONNXTypes.hpp" + +using namespace mlir; + +// ONNXTyps.cpp.inc is NOT included here, but in ONNXOps.cpp. +// The reason is that the type is used in dialect initialization + +// This file is for utility functions for type definition if there is any. diff --git a/src/Dialect/ONNX/ONNXTypes.hpp b/src/Dialect/ONNX/ONNXTypes.hpp new file mode 100644 index 000000000000..4922f65ea113 --- /dev/null +++ b/src/Dialect/ONNX/ONNXTypes.hpp @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===--------------------- ONNXTypes.hpp - ONNX Types ---------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file defines types in ONNX Dialect. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/BuiltinTypes.h" +#include "src/Dialect/ONNX/ONNXDialect.hpp" + +#define GET_TYPEDEF_CLASSES +#include "src/Dialect/ONNX/ONNXTypes.hpp.inc" diff --git a/src/Dialect/ONNX/Rewrite.td b/src/Dialect/ONNX/Rewrite.td index badd4de315ca..0aed37b3092b 100644 --- a/src/Dialect/ONNX/Rewrite.td +++ b/src/Dialect/ONNX/Rewrite.td @@ -15,7 +15,7 @@ #define ONNX_REWRITE #ifndef OP_BASE -include "src/Dialect/ONNX/ONNXOps.td" +include "src/Dialect/ONNX/ONNX.td" #endif // OP_BASE /// Note: The DRR definition used for defining patterns is shown below: diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp index c23eccf6733e..c8b1c61ecb21 100644 --- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp +++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp @@ -101,7 +101,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registerTransformsPasses(); diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp index f7d7546ab347..fb9f897dc6d8 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -763,7 +763,7 @@ void ConstPropONNXToONNXPass::runOnOperation() { MLIRContext *context = &getContext(); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(context); populateWithGenerated(patterns); diff --git a/src/Transform/ONNX/ConstProp.td b/src/Transform/ONNX/ConstProp.td index 0a8896f94656..518b051fa9de 100644 --- a/src/Transform/ONNX/ConstProp.td +++ b/src/Transform/ONNX/ConstProp.td @@ -15,7 +15,7 @@ #define ONNX_CONSTPROP #ifndef OP_BASE -include "src/Dialect/ONNX/ONNXOps.td" +include "src/Dialect/ONNX/ONNX.td" #endif // OP_BASE //===----------------------------------------------------------------------===// diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp index 24509e7688de..c4bcdc24bbc1 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Transform/ONNX/Decompose.cpp @@ -141,7 +141,7 @@ void DecomposeONNXToONNXPass::runOnOperation() { MLIRContext *context = &getContext(); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); // These ops will be decomposed into other ONNX ops. Hence, they will not be // available after this pass. diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td index 07b88eb36ed8..954e92ec18e3 100644 --- a/src/Transform/ONNX/Decompose.td +++ b/src/Transform/ONNX/Decompose.td @@ -15,7 +15,7 @@ #define ONNX_DECOMPOSE #ifndef OP_BASE -include "src/Dialect/ONNX/ONNXOps.td" +include "src/Dialect/ONNX/ONNX.td" #endif // OP_BASE /// Note: The DRR definition used for defining patterns is shown below: diff --git a/src/Transform/ONNX/InstrumentONNXPass.cpp b/src/Transform/ONNX/InstrumentONNXPass.cpp index 62f41c29f876..3675131ec7e1 100644 --- a/src/Transform/ONNX/InstrumentONNXPass.cpp +++ b/src/Transform/ONNX/InstrumentONNXPass.cpp @@ -91,7 +91,7 @@ class InstrumentONNXPass // Iterate on the operations nested in this function getOperation().walk([&](mlir::Operation *op) { - if (isa(op->getDialect())) { + if (isa(op->getDialect())) { // Skip the prefix "onnx." of onnx op name const char *opName = op->getName().getStringRef().data() + 5; if (!allOpsAllowed && allowedOps.find(opName) == allowedOps.end()) diff --git a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp index 5b5937b52733..4b97e82a67a5 100644 --- a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp +++ b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp @@ -51,7 +51,7 @@ class ONNXPreKrnlVerifyPass // Iterate on the operations for (Operation &op : funcBody.getOps()) { - if (isa(op.getDialect())) { + if (isa(op.getDialect())) { if (failed(verifyRanked(op))) signalPassFailure(); } diff --git a/test/onnx2mlir/CustomFnTest.cpp b/test/onnx2mlir/CustomFnTest.cpp index 811dba22e017..0a7a36c1f9d7 100644 --- a/test/onnx2mlir/CustomFnTest.cpp +++ b/test/onnx2mlir/CustomFnTest.cpp @@ -73,7 +73,7 @@ void RegisterFunSchema() { void registerDialects(mlir::MLIRContext &context) { context.getOrLoadDialect(); - context.getOrLoadDialect(); + context.getOrLoadDialect(); } int check(ModelProto &model) {