support group_norm lowering (#25)
zzpmiracle committed Dec 29, 2022
1 parent ee781e6 commit 0f88ea9
Showing 3 changed files with 215 additions and 6 deletions.
91 changes: 91 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4244,6 +4244,97 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}

def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_IntType:$num_groups,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_IntType:$N,
Torch_IntType:$C,
Torch_IntType:$HxW,
Torch_IntType:$group,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 3);
}
void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 3);
}
}];
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
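The three ODS records above mirror the corresponding PyTorch operator schemas one-for-one. As an illustration only (not part of this commit), the same operators can be called directly through the dispatcher from Python, with arguments in the same order as the `ins` lists above:

# Illustrative only: the aten operators modeled by the ODS definitions above,
# invoked directly through the PyTorch dispatcher.
import torch

x = torch.randn(2, 6, 4, 4)   # [N, C, H, W]
weight = torch.ones(6)
bias = torch.zeros(6)

# aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)
y = torch.ops.aten.group_norm(x, 3, weight, bias, 1e-5, False)

# aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float)
#                         -> (Tensor, Tensor, Tensor)
out, mean, rstd = torch.ops.aten.native_group_norm(x, weight, bias, 2, 6, 16, 3, 1e-5)

print(y.shape, out.shape, mean.shape, rstd.shape)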
121 changes: 115 additions & 6 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1787,6 +1787,119 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
return success();
}
};

class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenGroupNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
Value input = op.input();
auto inputTy = input.getType().cast<BaseTensorType>();
if (!inputTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
ArrayRef<int64_t> inputSize = inputTy.getSizes();
int64_t inputRank = inputTy.getSizes().size();
if (inputRank != 4) {
return rewriter.notifyMatchFailure(
op, "group norm only support 4D input now.");
}
Value num_groups = op.num_groups();
int64_t num_groups_int;
if (!matchPattern(num_groups, m_TorchConstantInt(&num_groups_int)))
return rewriter.notifyMatchFailure(
op, "non const num_groups for AtenGroupNormOp");

// reshape input -> [N, G, -1 (i.e. C//G), H, W]
Value negOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
SmallVector<int64_t, 5> inputForNormTySize(inputRank + 1,
ShapedType::kDynamicSize);
inputForNormTySize[1] = num_groups_int;
Type inputForNormTy = inputTy.getWithSizesAndDtype(
llvm::makeArrayRef(inputForNormTySize), inputTy.getDtype());
SmallVector<Value> originInputSize;
for (int i = 0; i < inputRank; ++i) {
Value index =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
originInputSize.push_back(
rewriter.create<AtenSizeIntOp>(loc, input, index));
}
SmallVector<Value, 5> inputForNormSize{originInputSize.begin(),
originInputSize.end()};
inputForNormSize.insert(inputForNormSize.begin() + 1, num_groups);
inputForNormSize[2] = negOne;
Value inputForNormSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), inputForNormSize);
Value reshapedInput = rewriter.create<AtenViewOp>(
loc, inputForNormTy, input, inputForNormSizeList);
// only keep N, G; reduce over C//G, H, W
int64_t axis = 2;
std::vector<int64_t> meanVarTySizes(inputForNormTySize.size(), 1);
for (int i = 0; i < axis; i++)
meanVarTySizes[i] = inputForNormTySize[i];
auto meanVarTy = inputTy.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarTySizes), inputTy.getDtype());
SmallVector<Value> normalizedShapeSize{inputForNormSize.begin() + axis,
inputForNormSize.end()};
auto normalizedSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), normalizedShapeSize);

auto nativeLayerNorm =
rewriter
.create<AtenNativeLayerNormOp>(
loc, inputForNormTy, meanVarTy, meanVarTy, reshapedInput,
normalizedSizeList, none, none, op.eps())
.getResult(0);
// reshape back to the original shape
Value inputSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), originInputSize);
Value originOutput = rewriter.create<AtenViewOp>(
loc, op.getType(), nativeLayerNorm, inputSizeList);
// reshape weight and bias to [C, 1, 1] for per-channel broadcasting
Value weight = op.weight();
Value bias = op.bias();
if (!weight.getType().isa<Torch::NoneType>() ||
!bias.getType().isa<Torch::NoneType>()) {
SmallVector<Value> weightsAndBiasSize(inputRank - 1, one);
weightsAndBiasSize[0] = originInputSize[1];

SmallVector<int64_t> weightsAndBiasTySize(inputRank - 1,
ShapedType::kDynamicSize);
// weightsAndBiasTySize[1] = ShapedType::kDynamicSize;

Value weightsAndBiasSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), weightsAndBiasSize);
if (!weight.getType().isa<Torch::NoneType>()) {
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
Type weightTy = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(weightsAndBiasTySize), weightType.getDtype());
weight = rewriter.create<AtenViewOp>(loc, weightTy, weight,
weightsAndBiasSizeList);
originOutput = rewriter.create<AtenMulTensorOp>(loc, op.getType(),
originOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
Type biasTy = biasType.getWithSizesAndDtype(
llvm::makeArrayRef(weightsAndBiasTySize), biasType.getDtype());
bias = rewriter.create<AtenViewOp>(loc, biasTy, bias,
weightsAndBiasSizeList);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
originOutput = rewriter.create<AtenAddTensorOp>(
loc, op.getType(), originOutput, bias, alpha);
}
}
rewriter.replaceOp(op, {originOutput});
return success();
}
};

} // namespace
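
In PyTorch terms, the DecomposeAtenGroupNormOp pattern above rewrites aten.group_norm as: reshape the [N, C, H, W] input to [N, G, C//G, H, W], run layer norm (without affine parameters) over the trailing C//G, H, W dimensions, reshape back, and apply the per-channel weight and bias with broadcasting. A minimal sketch of that computation, assuming a 4D input as the pattern requires (illustrative only, not part of this commit):

# Sketch of the DecomposeAtenGroupNormOp rewrite, expressed with PyTorch ops.
import torch
import torch.nn.functional as F

def group_norm_via_layer_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    N, C, H, W = x.shape
    # reshape input -> [N, G, C//G, H, W]
    x_grouped = x.reshape(N, num_groups, -1, H, W)
    # normalize over the trailing (C//G, H, W) dims; no affine parameters here
    y = F.layer_norm(x_grouped, x_grouped.shape[2:], eps=eps)
    # reshape back to the original shape
    y = y.reshape(N, C, H, W)
    # per-channel affine parameters, broadcast as [C, 1, 1]
    if weight is not None:
        y = y * weight.reshape(C, 1, 1)
    if bias is not None:
        y = y + bias.reshape(C, 1, 1)
    return y

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
assert torch.allclose(group_norm_via_layer_norm(x, 3, w, b),
                      F.group_norm(x, 3, w, b), atol=1e-6)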

namespace {
@@ -2510,12 +2623,6 @@ class DecomposeAtenToDtypeLayoutOp
op, "unimplemented: pin_memory is expected to be false");
}

// TODO: Add support for non-None device arg.
if (!op.device().getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: device arg must be None");
}

// TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0.
if (!op.layout().getType().isa<Torch::NoneType>()) {
@@ -3380,6 +3487,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
target.addIllegalOp<AtenGroupNormOp>();
patterns.add<DecomposeAtenGroupNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);

@@ -366,6 +366,15 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
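
The schema strings passed to emit() above must match the registered JIT operator schemas exactly; GeneratedTorchOps.td is regenerated from these emit() calls. One way to check the strings (illustrative only, not part of this commit) is to list the matching schemas from the registry the generator reads:

# Illustrative check of the operator schemas referenced by the emit() calls above.
import torch

for schema in torch._C._jit_get_all_schemas():
    if schema.name in ("aten::instance_norm", "aten::group_norm",
                       "aten::native_group_norm"):
        print(schema)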
