This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Add decomposition for aten::native_layer_norm #13

Merged (1 commit, Jul 20, 2022)
86 changes: 86 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9,6 +9,7 @@

#include "PassDetail.h"

#include <numeric>
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -1499,6 +1500,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace

namespace {
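// Decompose `aten.native_layer_norm` into primitive aten ops:
//   mean = mean(x) over the normalized dims (keepdim)
//   var  = mean((x - mean)^2) over the same dims
//   out  = (x - mean) * rsqrt(var + eps) * weight + bias
// returning (out, mean, rsqrt(var + eps)) as the op's three results.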
class DecomposeAtenNativeLayerNormOp
: public OpRewritePattern<AtenNativeLayerNormOp> {
using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();

auto inputTy = op.input().getType().cast<BaseTensorType>();
if (!inputTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
int64_t inputRank = inputTy.getSizes().size();
Value normalizedShape = op.normalized_shape();
SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
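    // Reduce over the trailing normalized dims: [axis, inputRank).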
SmallVector<int64_t> reduceDimInts(normalizedShapeSizesTorchInt.size());
std::iota(reduceDimInts.begin(), reduceDimInts.end(), axis);
auto reducedTy = op.getResult(1).getType();
auto sizeListType = ListType::get(IntType::get(context));

// build reduce dims
SmallVector<Value> reduceDimVals;
reduceDimVals.reserve(reduceDimInts.size());
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
std::back_inserter(reduceDimVals), [&](int64_t d) {
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(d));
});
Value reduceDimList =
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));

Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
// mean(x)
Value inputMean = rewriter.create<AtenMeanDimOp>(
loc, reducedTy, op.input(), reduceDimList, cstTrue, none);

// x - mean(x)
    Value inputMeanExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.input());
    Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
        loc, inputTy, op.input(), inputMeanExpanded, one);

[Review comment] inputZeroMean => inputOneMean ?

// var(x) = mean((x - mean(x))^2)
Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
loc, inputTy, inputZeroMean, inputZeroMean);
Value inputVar = rewriter.create<AtenMeanDimOp>(
loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none);

// rsqrt(var(x) + eps)
Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
loc, reducedTy, inputVar, op.eps(), one);
Value inputRsqrtVar =
rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);

// (x - mean(x)) * rsqrt(var(x) + eps)
Value inputRsqrtVarExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, inputRsqrtVar, op.input());
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
    Value out = rewriter.create<TensorStaticInfoCastOp>(
        loc, op.getResult(0).getType(), inputNormalized);

Value weight = op.weight();
Value bias = op.bias();
if (!weight.getType().isa<Torch::NoneType>()) {
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
out =
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
}
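    // aten.native_layer_norm has three results: the normalized output plus
    // the mean and rsqrt(var + eps) computed above.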
rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});

return success();
}
};
} // namespace

namespace {
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
Expand Down Expand Up @@ -2239,6 +2322,9 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);

target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
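
Aside, not part of the patch: a minimal plain-C++ sketch of the arithmetic the decomposition assembles from aten ops, assuming the input is viewed as a flattened [outer, inner] buffer where `inner` is the product of `normalized_shape`. The names `LayerNormResult` and `nativeLayerNormRef` are hypothetical, introduced only for illustration.

// Hypothetical reference, not torch-mlir code: computes the same three
// values the decomposition builds (out, mean, rsqrt(var + eps)).
#include <cmath>
#include <cstddef>
#include <vector>

struct LayerNormResult {
  std::vector<float> out;  // normalized (and optionally affine-mapped) input
  std::vector<float> mean; // per-row mean (keepdim-reduced in the real op)
  std::vector<float> rstd; // per-row rsqrt(var + eps)
};

LayerNormResult nativeLayerNormRef(const std::vector<float> &x,
                                   const std::vector<float> *weight,
                                   const std::vector<float> *bias,
                                   std::size_t outer, std::size_t inner,
                                   float eps) {
  LayerNormResult r{std::vector<float>(x.size()), std::vector<float>(outer),
                    std::vector<float>(outer)};
  for (std::size_t i = 0; i < outer; ++i) {
    const float *row = &x[i * inner];
    // mean(x)
    float mean = 0.0f;
    for (std::size_t j = 0; j < inner; ++j)
      mean += row[j];
    mean /= static_cast<float>(inner);
    // var(x) = mean((x - mean(x))^2) -- the biased variance the pass builds
    float var = 0.0f;
    for (std::size_t j = 0; j < inner; ++j) {
      float d = row[j] - mean;
      var += d * d;
    }
    var /= static_cast<float>(inner);
    // rsqrt(var(x) + eps)
    float rstd = 1.0f / std::sqrt(var + eps);
    // (x - mean(x)) * rsqrt(var(x) + eps), then optional weight and bias
    for (std::size_t j = 0; j < inner; ++j) {
      float v = (row[j] - mean) * rstd;
      if (weight)
        v *= (*weight)[j];
      if (bias)
        v += (*bias)[j];
      r.out[i * inner + j] = v;
    }
    r.mean[i] = mean;
    r.rstd[i] = rstd;
  }
  return r;
}

For example, a [2, 3] input normalized over its last dimension gives outer = 2 and inner = 3, matching axis = inputRank - normalizedShapeSizesTorchInt.size() in the pattern; passing nullptr for weight or bias mirrors the Torch::NoneType checks in the pattern.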