From 24e62148dd14a41f00f06a56af6c0892074176ee Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 20 Jul 2022 10:50:39 +0800 Subject: [PATCH] Add decomposition for aten::native_layer_norm (#13) --- .../Torch/Transforms/DecomposeComplexOps.cpp | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 77c0ec6d2610..2bbb3fa36b52 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -1670,6 +1671,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNativeLayerNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeLayerNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.input().getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputTy.getSizes().size(); + Value normalizedShape = op.normalized_shape(); + SmallVector normalizedShapeSizesTorchInt; + getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); + int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); + SmallVector reduceDimInts(normalizedShapeSizesTorchInt.size()); + std::iota(reduceDimInts.begin(), reduceDimInts.end(), axis); + auto reducedTy = op.getResult(1).getType(); + auto sizeListType = ListType::get(IntType::get(context)); + + // build reduce dims + SmallVector reduceDimVals; + reduceDimVals.reserve(reduceDimInts.size()); + std::transform(reduceDimInts.begin(), reduceDimInts.end(), + std::back_inserter(reduceDimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.input(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = rewriter.create(loc, inputTy, inputMean, op.input()); + Value inputZeroMean = rewriter.create( + loc, inputTy, op.input(), inputMeanExpanded, one); + // var(x) = mean((x - mean(x))^2) + Value inputZeroMeanSquare = rewriter.create( + loc, inputTy, inputZeroMean, inputZeroMean); + Value inputVar = rewriter.create( + loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none); + + // rsqrt(var(x) + eps) + Value inputVarPlusEps = rewriter.create( + loc, reducedTy, inputVar, op.eps(), one); + Value inputRsqrtVar = + rewriter.create(loc, reducedTy, inputVarPlusEps); + + // (x - mean(x)) * rsqrt(var(x) + eps) + Value inputRsqrtVarExpanded = + rewriter.create(loc, inputTy, inputRsqrtVar, op.input()); + Value inputNormalized = rewriter.create( + loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); + Value out = rewriter.create(loc, op.getResult(0).getType(), + inputNormalized); + + Value weight = op.weight(); + Value bias = op.bias(); + if (!weight.getType().isa()) { + out = rewriter.create(loc, out.getType(), out, weight); + } + if (!bias.getType().isa()) { + out = + rewriter.create(loc, out.getType(), out, bias, one); + } + rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops. class DecomposeAtenEmptyLikeOp : public OpRewritePattern { @@ -2676,6 +2759,9 @@ class DecomposeComplexOpsPass target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp();