Skip to content

Commit

Permalink
Simplify rules in zhigh-to-onnx pass and change some pass order
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <tung@jp.ibm.com>
  • Loading branch information
tungld committed Sep 24, 2024
1 parent 087f069 commit d5ac54f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 86 deletions.
32 changes: 17 additions & 15 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {

pm.addPass(onnx_mlir::createONNXToZHighPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());

// There are more opportunities for constant propagation once all ZHigh ops
// have been generated.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
pm.addPass(mlir::createCanonicalizerPass());

// Layout propagation at ZHighIR.
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighLayoutPropagationPass());
Expand All @@ -109,14 +111,7 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
onnx_mlir::zhigh::createZHighClipToDLFloatPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
}

// After all optimizations, if there are still light-weight ops (e.g. add,
// sub, ...) in the form `stick -> light-weight op -> unstick`, it's better to
// use CPU instead of NNPA to avoid stick/unstick. CPU handles these ops
// efficiently, e.g. by vectorizing the computation.
if (nnpaEnableZHighToOnnx)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());


// One more call to ONNX shape inference/canonicalization/... to update shape
// if possible.
if (enableONNXHybridPass) {
Expand All @@ -134,13 +129,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
pm.addPass(createScrubDisposablePass());

// Constant propagation at ZHighIR: constant stickify.
// Only supported on big-endian (BE) machines.
bool isBE = llvm::endianness::native == llvm::endianness::big;
if (isBE)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighConstPropagationPass());

// Experimental feature: Decompose stick/unstick into two phases: layout
// transform and data conversion. Do some optimizations after decomposing.
// Then, recompose again layout and data conversion if they are not optimized.
Expand All @@ -152,6 +140,20 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
onnx_mlir::zhigh::createZHighRecomposeToStickUnstickPass());
}

// After all optimizations, if there are still light-weight ops (e.g. add,
// sub, ...) in the form `stick -> light-weight op -> unstick`, it's better to
// use CPU instead of NNPA to avoid stick/unstick. CPU handles these ops
// efficiently, e.g. by vectorizing the computation.
if (nnpaEnableZHighToOnnx)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());

// Constant propagation at ZHighIR: constant stickify.
// Only supported on big-endian (BE) machines.
bool isBE = llvm::endianness::native == llvm::endianness::big;
if (isBE)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighConstPropagationPass());

// Remove common sub-expressions.
pm.addPass(mlir::createCSEPass());

Expand Down
96 changes: 25 additions & 71 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,107 +37,61 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
// ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
def replaceZHighAddPattern1 : Pat<
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
(ONNXAddOp $x, (ZHighUnstickOp $y)),
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
>;

def replaceZHighAddPattern2 : Pat<
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXAddOp (ZHighUnstickOp $x), $y),
[(NotBlockArgument:$y), (HasOneUse:$s_y)]
def replaceZHighAddPattern : Pat<
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXAddOp $x, $y),
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
>;

//===----------------------------------------------------------------------===//
// ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
def replaceZHighMulPattern1 : Pat<
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
(ONNXMulOp $x, (ZHighUnstickOp $y)),
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
(addBenefit 1)
>;

def replaceZHighMulPattern2 : Pat<
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXMulOp (ZHighUnstickOp $x), $y),
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
(addBenefit 0)
def replaceZHighMulPattern : Pat<
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXMulOp $x, $y),
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
>;

//===----------------------------------------------------------------------===//
// ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
def replaceZHighSubPattern1 : Pat<
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
(ONNXSubOp $x, (ZHighUnstickOp $y)),
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
(addBenefit 1)
>;

def replaceZHighSubPattern2 : Pat<
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXSubOp (ZHighUnstickOp $x), $y),
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
(addBenefit 0)
def replaceZHighSubPattern : Pat<
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
(ONNXSubOp $x, $y),
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
>;

//===----------------------------------------------------------------------===//
// ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
// Note: turn off this pattern since NNPA is faster at this moment.
//===----------------------------------------------------------------------===//
//def replaceZHighDivPattern1 : Pat<
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
// (ONNXDivOp $x, (ZHighUnstickOp $y)),
// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
// (addBenefit 1)
//>;
//
//def replaceZHighDivPattern2 : Pat<
// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
// (ONNXDivOp (ZHighUnstickOp $x), $y),
// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
// (addBenefit 0)
//>;
// def replaceZHighDivPattern : Pat<
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
// (ONNXDivOp $x, $y),
// [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
// >;

//===----------------------------------------------------------------------===//
// ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
def replaceZHighMinPattern1 : Pat<
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
(CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
(addBenefit 1)
>;

def replaceZHighMinPattern2 : Pat<
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
(CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
(addBenefit 0)
def replaceZHighMinPattern : Pat<
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
(CreateONNXMinOp $u, $x, $y),
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
>;

//===----------------------------------------------------------------------===//
// ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
def replaceZHighMaxPattern1 : Pat<
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
(CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
(addBenefit 1)
>;

def replaceZHighMaxPattern2 : Pat<
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
(CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
(addBenefit 0)
def replaceZHighMaxPattern : Pat<
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
(CreateONNXMaxOp $u, $x, $y),
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
>;

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit d5ac54f

Please sign in to comment.