add group norm op #946

Closed
32 changes: 32 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3932,6 +3932,38 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
}];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_IntType:$N,
Torch_IntType:$C,
Torch_IntType:$HxW,
Torch_IntType:$group,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 3);
}
void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 3);
}
}];
}
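// A minimal sketch of how the generated op is expected to print, assuming the
// default torch-op assembly format used by parseDefaultTorchOp/printDefaultTorchOp;
// the operand and result types shown are illustrative only:
//
//   %out, %mean, %rstd = torch.aten.native_group_norm %input, %weight, %bias,
//       %n, %c, %hxw, %group, %eps : !torch.vtensor<[2,5,2,2,3],f32>,
//       !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32>, !torch.int,
//       !torch.int, !torch.int, !torch.int, !torch.float
//       -> !torch.vtensor<[2,5,2,2,3],f32>, !torch.vtensor<[2,1],f32>,
//       !torch.vtensor<[2,1],f32>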

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
233 changes: 233 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -7,6 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
@@ -1621,6 +1623,235 @@ class ConvertTensorStaticInfoCastOp
};
} // namespace

namespace {
class ConvertAtenNativeGroupNormOp
: public OpConversionPattern<AtenNativeGroupNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNativeGroupNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
Value N = adaptor.N();
Value C = adaptor.C();
Value HxW = adaptor.HxW();
Value group = adaptor.group();
Value eps = adaptor.eps();

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)))
return failure();

auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
Type elemTy = inputType.getElementType();

// Common parts to be used for getting mean and var.

    // The shape of mean and var is always {N, num_groups}, so their rank is 2.
int64_t meanAndVarShapeRank = 2;
// Get sizes and affineMaps needed for mean and var.
SmallVector<AffineExpr> inputExprs, meanAndVarShapeExprs;
for (int i = 0; i < inputRank; i++)
inputExprs.push_back(mlir::getAffineDimExpr(i, context));
meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(0, context));
meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(inputRank, context));
auto inputShapeAffineMap = AffineMap::get(
/*dimCount=*/inputRank + 1,
/*symbolCount=*/0, inputExprs, context);
auto meanAndVarShapeAffineMap = AffineMap::get(
/*dimCount=*/inputRank + 1,
/*symbolCount=*/0, meanAndVarShapeExprs, context);
SmallVector<Value> meanAndVarShapeSizes;
meanAndVarShapeSizes.push_back(castIntToIndex(rewriter, loc, N));
meanAndVarShapeSizes.push_back(castIntToIndex(rewriter, loc, group));
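    // Concretely, for a rank-5 input the iteration space has inputRank + 1 = 6
    // dims: inputShapeAffineMap = (d0, ..., d5) -> (d0, d1, d2, d3, d4) and
    // meanAndVarShapeAffineMap = (d0, ..., d5) -> (d0, d5), i.e. the trailing
    // d5 indexes the group dimension of the {N, group} mean/var tensors.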

// Get number of elements to be used for calculating mean and var.
Value elemCnts = rewriter.create<arith::MulIOp>(loc, C, HxW);
elemCnts = rewriter.create<arith::DivSIOp>(loc, elemCnts, group);
Value elemCntsFloat =
rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
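    // For example, with N = 2, C = 5, HxW = 12, group = 1 (as in the e2e test),
    // each (N, group) entry averages over C * HxW / group = 60 input elements.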

// Get iterator types for input shape.
SmallVector<StringRef> meanAndVarIterationTypes(
meanAndVarShapeRank, getParallelIteratorTypeName());
SmallVector<StringRef> iteratorTypes;
iteratorTypes.push_back(getParallelIteratorTypeName());
for (unsigned i = 1; i < inputRank; i++)
iteratorTypes.push_back(getReductionIteratorTypeName());
iteratorTypes.push_back(getParallelIteratorTypeName());
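    // The resulting layout is: d0 (batch) parallel, d1 ... d_{inputRank-1}
    // (channel and spatial dims) reduction, and the trailing group dim
    // parallel, matching the affine maps built above.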

// Helper to calculate mean and var.
auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
SmallVector<AffineMap> indexingMaps(
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
loc, meanAndVarShapeSizes, elemTy);
return rewriter
.create<linalg::GenericOp>(
loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/meanAndVarIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
                  Value sumOrSquareSum = args[0];
                  Value result =
                      b.create<arith::DivFOp>(loc, sumOrSquareSum, elemCntsFloat);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
};

// Get mean.

// Get sum to be used for calculating mean.
SmallVector<AffineMap, 2> sumIndexingMaps = {
inputShapeAffineMap, // input
meanAndVarShapeAffineMap, // output
};
auto initSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value sum = rewriter
.create<linalg::GenericOp>(
loc, initSumTensor.getType(), input, initSumTensor,
/*indexingMaps=*/sumIndexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], sum = args[1];
Value result =
rewriter.create<arith::AddFOp>(loc, sum, input);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Value mean = genMeanOrVarCalculation(sum);
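    // `mean` now has shape {N, group}; each entry is the corresponding
    // accumulated sum divided by elemCntsFloat.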

// Get rSTD.

// Calculate squareSum for the layer.
SmallVector<AffineMap> squareSumIndexingMaps{
inputShapeAffineMap,
meanAndVarShapeAffineMap,
meanAndVarShapeAffineMap,
};
auto initSquareSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value squareSum =
rewriter
.create<linalg::GenericOp>(
loc, initSquareSumTensor.getType(), ValueRange{input, mean},
initSquareSumTensor,
/*indexingMaps=*/squareSumIndexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], squareSum = args[2];
Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
Value result =
rewriter.create<arith::AddFOp>(loc, squareSum, square);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Value var = genMeanOrVarCalculation(squareSum);
Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
loc, meanAndVarShapeSizes, elemTy);
SmallVector<AffineMap> rSTDIndexingMap(
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));

Value rSTD = rewriter
.create<linalg::GenericOp>(
loc, rSTDTensor.getType(), var, rSTDTensor,
rSTDIndexingMap, meanAndVarIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value result =
calculateRSTD(b, loc, elemTy, eps, args[0]);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
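    // `calculateRSTD` is an existing helper; it is expected to compute
    // 1 / sqrt(var + eps) elementwise over the {N, group} variance tensor.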

    // Get the group norm output.

// Get affineMap for normalized shape.
SmallVector<AffineExpr> normalizedShapeExprs;
for (int i = meanAndVarShapeRank; i < inputRank; i++)
normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
auto normalizedShapeAffineMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, normalizedShapeExprs, context);
auto inputSizes = getTensorSizes(rewriter, loc, input);
Value initLayerNormTensor =
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
indexingMaps.resize(3, meanAndVarShapeAffineMap);
indexingMaps.push_back(inputShapeAffineMap);
SmallVector<StringRef> layerNormIterationTypes(
inputRank, getParallelIteratorTypeName());
Value layerNorm =
rewriter
.create<linalg::GenericOp>(
loc, initLayerNormTensor.getType(),
ValueRange{input, mean, rSTD, weight, bias},
initLayerNormTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/layerNormIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], rSTD = args[2],
weight = args[3], bias = args[4];
Value result =
createLinalgPayloadCalculationForNormOpsWithRSTD(
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
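    // `createLinalgPayloadCalculationForNormOpsWithRSTD` is an existing helper;
    // it is expected to compute (input - mean) * rstd * weight + bias
    // elementwise for each output element.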
SmallVector<int64_t> expandShape(inputRank, 1);
for (int i = 0; i < meanAndVarShapeRank; i++) {
      // `mean` and `rstd` have not yet been cast, so they still have dynamic
      // shapes; to match them, mark each dimension corresponding to `mean` or
      // `rstd` as dynamic (-1).
expandShape[i] = -1;
}
auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
    // Rank of the trailing dims that `mean`/`rstd` are expanded over; assumed
    // to be everything past the leading {N, group} dims.
    int64_t normalizedShapeRank = inputRank - meanAndVarShapeRank;
    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
reassociation[i].push_back(i);
if (i == meanAndVarShapeRank - 1) {
for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
reassociation[i].push_back(i + j + 1);
}
}
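    // For a rank-5 input this yields expandShape = {-1, -1, 1, 1, 1} and
    // reassociation = {{0}, {1, 2, 3, 4}}, i.e. the {N, group} mean/rstd
    // tensors are expanded to {N, group, 1, 1, 1} before being cast to the
    // expected result types below.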
Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, mean, reassociation);
Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, rSTD, reassociation);
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
Value layerNorm_ =
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
Value mean_ =
rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
    Value rSTD_ =
        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
    rewriter.replaceOp(op, {layerNorm_, mean_, rSTD_});
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
Expand Down Expand Up @@ -1650,4 +1881,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
target.addIllegalOp<TensorStaticInfoCastOp>();
target.addIllegalOp<AtenNativeGroupNormOp>();
patterns.add<ConvertAtenNativeGroupNormOp>(typeConverter, context);
}
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -793,7 +793,8 @@ ChangeResult TypeAnalyzer::visitOperation(
}

// 3 results take dtype from first operand.
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp, AtenNativeGroupNormOp>(
op)) {
auto self = operands[0]->getValue();
auto result0Knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
9 changes: 9 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6384,6 +6384,15 @@ module {
}
return %0 : !torch.tuple<list<int>, list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.native_group_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%2 = torch.aten.append.t %0, %1 : !torch.list<int>, !torch.int -> !torch.list<int>
%3 = torch.aten.append.t %0, %arg6 : !torch.list<int>, !torch.int -> !torch.list<int>
%4 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
return %4 : !torch.tuple<list<int>, list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.constant_pad_nd"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
@@ -938,6 +938,12 @@ def aten〇native_batch_norm(input: List[int], weight: Optional[List[int]], bias
return input, [input[1]], [input[1]]
return input, [0], [0]

def aten〇native_group_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
reduction_shape: List[int] = []
reduction_shape.append(input[0])
reduction_shape.append(group)
return input, reduction_shape, reduction_shape
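# For example, an input of shape [2, 5, 2, 2, 3] with group=1 (matching the
# NativeGroupNormModule e2e test) yields ([2, 5, 2, 2, 3], [2, 1], [2, 1]).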

# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def pad_shape_fn(input: List[int], pad: List[int]):
@@ -373,6 +373,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)")

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
24 changes: 24 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -332,3 +332,27 @@ def forward(self, x):
@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 3))


# ==============================================================================

class NativeGroupNormModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([2, 5, 2, 2, 3], torch.float32, True),
([5], torch.float32, True),
([5], torch.float32, True),
])
def forward(self, x, weight, bias):
        # N=2, C=5, HxW=2*2*3=12, num_groups=1 for the annotated input above.
return torch.ops.aten.native_group_norm(
x, weight, bias, 2, 5, 12, 1, eps=0.5)


@register_test_case(module_factory=lambda: NativeGroupNormModule())
def NativeGroupNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5))