Skip to content

Commit

Permalink
Fuse convolution and batch normalization (llvm#253)
Browse files Browse the repository at this point in the history
* Rewriting rule

* Fix formulas

* Reuse op results

* Const propagation for Div and Sqrt

* Explicitly use ONNXConstantOp

* Minor revise

* Const propagation for unsqueeze

* Do const propagation once all tensors have inferred shapes

* LIT tests for fusion

* Add LIT tests for constant propagation on Div, Sqrt, and Unsqueeze

* Missing dash

Co-authored-by: Tian Jin <tjingrant@gmail.com>
  • Loading branch information
tungld and tjingrant authored Aug 18, 2020
1 parent 38bd77e commit 7c1e678
Show file tree
Hide file tree
Showing 10 changed files with 409 additions and 37 deletions.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX BatchNormalization operation in test mode";
let hasCanonicalizer = 1;
let description = [{
"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
Expand Down
102 changes: 69 additions & 33 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -2894,17 +2894,29 @@ def ONNXNegOp:ONNX_Op<"Neg",
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
auto elementType = X.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), X);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}

def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression",
Expand Down Expand Up @@ -5098,17 +5110,29 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
auto elementType = X.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), X);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}

def ONNXSqueezeOp:ONNX_Op<"Squeeze",
Expand Down Expand Up @@ -5574,17 +5598,29 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data,
I64ArrayAttr:$axes);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes", [{
auto elementType = data.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), data, axes);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}

def ONNXUpsampleOp:ONNX_Op<"Upsample",
Expand Down
5 changes: 5 additions & 0 deletions src/MainUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createAttributePromotionPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass());
// There are more opportunities for const propagation once all tensors have
// inferred shapes.
pm.addPass(mlir::createConstPropONNXToONNXPass());
// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());
}

void addONNXToKrnlPasses(mlir::PassManager &pm) {
Expand Down
55 changes: 55 additions & 0 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"

#include <math.h>

using namespace mlir;

namespace {
Expand Down Expand Up @@ -120,6 +122,26 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
llvm_unreachable("constant propagation for MulOp: unkonwn data type");
}

/// Constant folding for ONNXDivOp: fold one scalar element lhs / rhs.
/// Returns a FloatAttr or IntegerAttr of 'elementType' holding the quotient.
template <>
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    assert(rhsVal != 0 && "division by a zero");
    double res = lhsVal / rhsVal;
    return rewriter.getFloatAttr(elementType, res);
  }
  if (elementType.isa<IntegerType>()) {
    // Use signed 64-bit values: IntegerAttr::getInt() returns a sign-extended
    // int64_t, and integer division must be signed — folding through uint64_t
    // would compute e.g. -6 / 2 incorrectly.
    int64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
    int64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
    assert(rhsVal != 0 && "division by a zero");
    int64_t res = lhsVal / rhsVal;
    return rewriter.getIntegerAttr(elementType, res);
  }
  llvm_unreachable("constant propagation for DivOp: unknown data type");
}
// Recursively process one dimension in the rank of the two references. There
// can be one of 3 cases.
// 1) We have fully defined accesses for both operands, launch the computations.
Expand Down Expand Up @@ -246,6 +268,17 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
llvm_unreachable("constant propagation for NegOp: unkonwn data type");
}

/// Constant folding for ONNXSqrtOp: fold one scalar element to its square
/// root. Only floating-point element types are supported.
template <>
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &attr) {
  if (!elementType.isa<FloatType>())
    llvm_unreachable("constant propagation for SqrtOp: unkonwn data type");
  double inputVal = attr.cast<FloatAttr>().getValueAsDouble();
  return rewriter.getFloatAttr(elementType, sqrt(inputVal));
}

template <typename ElementwiseUnaryOp>
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
Expand Down Expand Up @@ -340,6 +373,28 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
return DenseElementsAttr::get(resType, resRef);
}

//===----------------------------------------------------------------------===//
// Code to perform constant propagation for unsqueeze.
//===----------------------------------------------------------------------===//

// Constant propagation for unsqueeze: re-wrap the constant's data under the
// unsqueezed result type. Unsqueeze only inserts dimensions of size 1, so the
// linearized element order is unchanged and the data can be copied verbatim.
DenseElementsAttr ConstPropUnsqueeze(
    PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
  // The constant tensor being transformed must carry a dense attribute.
  DenseElementsAttr denseAttr =
      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
  assert(denseAttr && "expected dense attribute");
  ShapedType resType = resOperand.getType().cast<RankedTensorType>();

  // Copy every element as-is; only the shape differs between input and result.
  auto srcValues = denseAttr.getValues<Attribute>();
  std::vector<Attribute> elements(srcValues.begin(), srcValues.end());

  ArrayRef<Attribute> elementsRef(elements);
  return DenseElementsAttr::get(resType, elementsRef);
}

//===----------------------------------------------------------------------===//
// Pattern definition.
//===----------------------------------------------------------------------===//
Expand Down
43 changes: 40 additions & 3 deletions src/Transform/ONNX/ConstProp.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,21 @@ def CreateSubOfTwoConst :
def CreateNegOfConst :
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;

def CreateMulOfTwoConst :
// Fold sqrt(c) of a constant c (see ConstPropElementwiseUnary in ConstProp.cpp).
def CreateSqrtOfConst :
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXSqrtOp>($_builder, $0, $1)">;

// Fold c1 * c2 for two constants c1, c2.
def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;

// Fold c1 / c2 for two constants c1, c2.
def CreateDivOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXDivOp>($_builder, $0, $1, $2)">;

// Fold transpose(c) of a constant c (see ConstPropTranspose in ConstProp.cpp).
def CreateTransposeOfConst :
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;

// Fold unsqueeze(c) of a constant c (see ConstPropUnsqueeze in ConstProp.cpp).
def CreateUnsqueezeOfConst:
NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">;

//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise ADD operations.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -163,7 +172,14 @@ def SubConstToNeg : Pat<
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;


// Constant propagation for Sqrt: replace onnx.Sqrt of a constant tensor by a
// new constant holding the element-wise square roots.
def SqrtofConst : Pat<
// From onnx.Sqrt(c)
(ONNXSqrtOp (ONNXConstantOp:$constOp $s, $v)),
// To sqrt(c)
(ONNXConstantOp (GetNullAttr), (CreateSqrtOfConst $constOp, $v)),
// Only fires when the constant's sparse attribute is absent (dense constant).
[(AttributeIsNull:$s)]>;

//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise MUL operations.
// Exactly the same pattern as for the elementwise ADD operations.
Expand Down Expand Up @@ -232,6 +248,16 @@ def MulConstProp : Pat<
// Multiplication constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;

// Constant propagation for Div: replace a division of two constants by a new
// constant holding the quotient.
def DivConstProp : Pat<
  // From div(c1, c2).
  // Bind the op as $divOp (was $mulOp — a copy-paste leftover from
  // MulConstProp; renamed for consistency, purely local to this pattern).
  (ONNXDivOp:$divOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
  // To c1/c2
  (ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $divOp, $v1, $v2)),
  // Division constraints (no sparse)
  [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;


//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with Transpose operations.
//===----------------------------------------------------------------------===//
Expand All @@ -244,5 +270,16 @@ def TransposeofConst : Pat<
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
[(AttributeIsNull:$s)]>;


//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with Unsqueeze operations.
//===----------------------------------------------------------------------===//

// Constant propagation for Unsqueeze: replace onnx.Unsqueeze of a constant by
// a new constant with the unsqueezed shape (element data is unchanged).
def UnsqueezeofConst : Pat<
// From Unsqueeze (c, axis)
(ONNXUnsqueezeOp:$resOp (ONNXConstantOp $s, $v), $_),
// To c' where c' is the unsqueezed value.
(ONNXConstantOp (GetNullAttr), (CreateUnsqueezeOfConst $resOp, $v)),
// Only fires when the constant's sparse attribute is absent (dense constant).
[(AttributeIsNull:$s)]>;


#endif // ONNX_CONSTPROP
36 changes: 36 additions & 0 deletions src/Transform/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,36 @@ using namespace mlir;

namespace {

// Create a 1x1 DenseElementsAttr of 'elementType' from a float attribute.
// Goes through getValueAsDouble()/getFloatAttr so that any float element type
// (f16/f32/f64) works: the original APFloat::convertToFloat() path asserts
// unless the attribute has f32 semantics, and a float array would mismatch a
// non-f32 tensor element type.
DenseElementsAttr createDenseElementsAttrFromFloatAttr(
    PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
  SmallVector<int64_t, 1> dims(1, 1);
  auto tensorType = mlir::RankedTensorType::get(dims, elementType);
  Attribute value = rewriter.getFloatAttr(elementType, attr.getValueAsDouble());
  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(value));
}

// Emit 'lhs - rhs', or fold to '-rhs' when 'lhs' is the NoneType placeholder
// (i.e. the minuend is absent).
Value subtractOrNeg(
    PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
  // No minuend: a plain negation suffices.
  if (lhs.getType().isa<NoneType>())
    return rewriter.create<ONNXNegOp>(loc, rhs);
  return rewriter.create<ONNXSubOp>(loc, lhs, rhs);
}

// Create an ArrayAttr of IntegerAttr(s) holding the values 1, 2, ..., N.
ArrayAttr createArrayAttrOfOneToN(PatternRewriter &rewriter, int N) {
  SmallVector<int64_t, 4> axes;
  for (int64_t axis = 1; axis <= N; ++axis)
    axes.push_back(axis);
  return rewriter.getI64ArrayAttr(axes);
}

// Check whether an ArrayAttr contains non-zero values or not.
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
bool allZeros = true;
Expand Down Expand Up @@ -92,3 +122,9 @@ void ONNXConvOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvOpPaddingPattern>(context);
}

/// Register canonicalization patterns on ONNXBatchNormalizationTestModeOp:
/// FuseBatchNormTestModeConvPattern fuses the batch normalization (test mode)
/// with a preceding convolution.
void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<FuseBatchNormTestModeConvPattern>(context);
}
Loading

0 comments on commit 7c1e678

Please sign in to comment.