Skip to content

Commit

Permalink
Generate dialect and type for Dialect/ONNX with Tablegen (llvm#1141)
Browse files Browse the repository at this point in the history
* compiled without type

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* change ONNXOpsDialect to ONNXDialect

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* new CMAKE

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* first step to add type

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* ONNXOps.td compiled with type

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* refine CMake

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* compiled with stringtype conflict

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* chang StringType and pass lit test

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* cleanup

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* format

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* change MLIR.cmake

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* split files

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* fix

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* format

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* fix cmake

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* comments

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* line length

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* typo

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* cmake order

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* change for new llvm

Signed-off-by: Tong Chen <chentong@us.ibm.com>
  • Loading branch information
chentong319 committed Feb 7, 2022
1 parent 31648cf commit 099493f
Show file tree
Hide file tree
Showing 25 changed files with 219 additions and 213 deletions.
12 changes: 9 additions & 3 deletions MLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,16 @@ function(add_onnx_mlir_dialect_doc dialect dialect_tablegen_file)
endfunction()
add_custom_target(onnx-mlir-docs)

function(add_onnx_mlir_dialect dialect)
# If an extra parameter, the dialect name, is provided,
# this function will generate dialect and type from the td file
function(add_onnx_mlir_dialect dialect dialect_name)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.hpp.inc -gen-op-decls "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Ops.hpp.inc -gen-op-decls "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Ops.cpp.inc -gen-op-defs "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Dialect.hpp.inc -gen-dialect-decls -dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Dialect.cpp.inc -gen-dialect-defs -dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Types.hpp.inc -gen-typedef-decls -typedefs-dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}")
mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs -typedefs-dialect=${dialect_name} "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(OM${dialect}IncGen)
endfunction()

Expand Down
2 changes: 1 addition & 1 deletion src/Builder/FrontendDialectHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ mlir::Type convertONNXTypeToMLIRType(
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
return mlir::onnxmlir::StringType::get(builder_.getContext());
return mlir::ONNXStringType::get(builder_.getContext());

case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
Expand Down
6 changes: 4 additions & 2 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ class FrontendGenImpl {
assert(elem_type.value_case() == onnx::TypeProto::kTensorType &&
"expect tensor inside sequence type");
Type mlir_elem_type = ImportTensorType(elem_type);
Type seq_type = mlir::onnxmlir::SeqType::get(mlir_elem_type);
if (!mlir_elem_type.isa<ShapedType>())
llvm_unreachable("Seq type is incorrect");
Type seq_type = mlir::SeqType::get(mlir_elem_type.cast<ShapedType>(), -1);
return seq_type;
}
llvm_unreachable("unexpected type");
Expand Down Expand Up @@ -1274,7 +1276,7 @@ class FrontendGenImpl {
std::string comma = std::string("");

TypeSwitch<Type>(argType)
.Case<mlir::onnxmlir::SeqType>([&](mlir::onnxmlir::SeqType seqTy) {
.Case<mlir::SeqType>([&](mlir::SeqType seqTy) {
auto et = seqTy.getElementType();
dstream << " {\"seq\" : ";
concatTypeString(et, attr, dstream);
Expand Down
7 changes: 3 additions & 4 deletions src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ void registerDialects(mlir::MLIRContext &context) {
context.getOrLoadDialect<mlir::shape::ShapeDialect>();
context.getOrLoadDialect<mlir::math::MathDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::ONNXOpsDialect>();
context.getOrLoadDialect<mlir::ONNXDialect>();
context.getOrLoadDialect<mlir::KrnlOpsDialect>();
}

Expand Down Expand Up @@ -736,9 +736,8 @@ InputIRLevelType determineInputIRLevel(mlir::OwningModuleRef &module) {
});

// If there are ONNX ops, the input level is ONNX.
bool hasONNXOps = llvm::any_of(dialectNamespace, [&](StringRef ns) {
return (ns == ONNXOpsDialect::getDialectNamespace());
});
bool hasONNXOps = llvm::any_of(dialectNamespace,
[&](StringRef ns) { return (ns == ONNXDialect::getDialectNamespace()); });
if (hasONNXOps)
return ONNXLevel;

Expand Down
6 changes: 3 additions & 3 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ bool hasAllScalarValues(ArrayRef<Value> values) {
MemRefType convertToMemRefType(Type type) {
// Convert the element type of the (tensor or memref) to a valid Krnl type.
auto convertElemType = [](Type elemType) -> Type {
if (elemType.isa<onnxmlir::StringType>())
if (elemType.isa<ONNXStringType>())
return StringType::get(elemType.getContext());
return elemType;
};
Expand Down Expand Up @@ -755,13 +755,13 @@ KrnlTypeConverter::KrnlTypeConverter() {
// The order of type conversion is important: later ones are tried earlier.
addConversion([](Type type) { return type; });

addConversion([](onnxmlir::StringType stringType) {
addConversion([](ONNXStringType stringType) {
return StringType::get(stringType.getContext());
});

addConversion([](TensorType tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
if (tensorType.getElementType().isa<onnxmlir::StringType>()) {
if (tensorType.getElementType().isa<ONNXStringType>()) {
Type elementType = StringType::get(tensorType.getContext());
return MemRefType::get(tensorType.getShape(), elementType);
}
Expand Down
6 changes: 3 additions & 3 deletions src/Dialect/Krnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "s390x")
llvm_replace_compiler_option(CMAKE_CXX_FLAGS_RELEASE "-O3" "-O1")
endif()

add_onnx_mlir_dialect(KrnlOps)
add_onnx_mlir_dialect_doc(krnl KrnlOps.td)
add_onnx_mlir_dialect(Krnl krnl)
add_onnx_mlir_dialect_doc(krnl Krnl.td)

add_onnx_mlir_library(OMKrnlOps
KrnlOps.cpp
KrnlTypes.cpp
KrnlHelper.cpp

DEPENDS
OMKrnlOpsIncGen
OMKrnlIncGen

LINK_LIBS PUBLIC
OMONNXOps
Expand Down
File renamed without changes.
12 changes: 7 additions & 5 deletions src/Dialect/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "s390x")
llvm_replace_compiler_option(CMAKE_CXX_FLAGS_RELEASE "-O3" "-O1")
endif()

add_onnx_mlir_dialect(ONNXOps)
add_onnx_mlir_dialect_doc(onnx ONNXOps.td)
add_onnx_mlir_dialect(ONNX onnx)
add_onnx_mlir_dialect_doc(onnx ONNX.td)

add_onnx_mlir_rewriter(Rewrite)

add_onnx_mlir_library(OMONNXOps
ONNXOps.cpp
ONNXOpsHelper.cpp
IndexExpr.cpp
IndexExprDetail.cpp
MLIRDialectBuilder.cpp
ONNXDialect.cpp
ONNXOps.cpp
ONNXOpsHelper.cpp
ONNXTypes.cpp
Rewrite.cpp

ShapeInference/ArgMax.cpp
Expand Down Expand Up @@ -50,7 +52,7 @@ add_onnx_mlir_library(OMONNXOps

DEPENDS
OMHasOnnxSubgraphOpInterfaceIncGen
OMONNXOpsIncGen
OMONNXIncGen
OMONNXRewriteIncGen
OMResultTypeInferenceOpInterfaceIncGen
OMShapeInferenceOpInterfaceIncGen
Expand Down
50 changes: 49 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td → src/Dialect/ONNX/ONNX.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include "src/Interface/HasOnnxSubgraphOpInterface.td"
// Definition for rewrite rules for onnx dialect
// Can be used in other table gen files (.td) for onnx dialect

def StringType : Type<CPred<"$_self.isa<StringType>()">, "string type">;
def StringType : Type<CPred<"$_self.isa<ONNXStringType>()">, "string type">;

def IsSeqTypePred : CPred<"$_self.isa<SeqType>()">;

Expand Down Expand Up @@ -204,4 +204,52 @@ def ONNXBatchNormalizationInferenceModeOp: ONNX_Op<"BatchNormalizationInferenceM
}];
}

class ONNX_Type<string name, string typeMnemonic,
string baseCppClass = "::mlir::Type">
: TypeDef<ONNX_Dialect, name, [], baseCppClass> {
let mnemonic = typeMnemonic;
}

def ONNX_StringType : ONNX_Type<"ONNXString", "String"> {
let summary = " ONNX StringType";
let description = [{
An array of characters.
}];
}

def ONNX_SeqType : ONNX_Type<"Seq", "Seq"> {
let summary = " ONNX SeqType";
let description = [{
An list of tensors which may have different shape
}];
let parameters = (ins "::mlir::ShapedType":$ElementType, "int64_t":$Length);

// Previous implementation did not print/parse the length field
// May add the field in future
let printer = [{
$_printer << "<" << getImpl()->ElementType << ">";
}];
let parser = [{
if (parser.parseLess())
return Type();
Type elementType;
if ($_parser.parseType(elementType))
return Type();
if ($_parser.parseGreater())
return Type();
if (!elementType.isa<ShapedType>())
return Type();
else {
ShapedType ty = elementType.cast<ShapedType>();
return get($_ctxt, ty, -1);
}
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::ShapedType":$elementType,
"int64_t":$length), [{
return Base::get(elementType.getContext(), elementType, length);
}]>
];
}

#endif // ONNX_OPS
23 changes: 23 additions & 0 deletions src/Dialect/ONNX/ONNXDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------ ONNXDialect.cpp - ONNX Operations -----------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file provides definition of ONNX dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/DialectImplementation.h"

#include "src/Dialect/ONNX/ONNXDialect.hpp"

using namespace mlir;

// Code for ONNX_Dialect class
#include "src/Dialect/ONNX/ONNXDialect.cpp.inc"
19 changes: 19 additions & 0 deletions src/Dialect/ONNX/ONNXDialect.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ONNXDialect.hpp - ONNX Operations -------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines ONNX Dialect class.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/IR/Dialect.h"

#include "src/Dialect/ONNX/ONNXDialect.hpp.inc"
Loading

0 comments on commit 099493f

Please sign in to comment.