This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

add support for stable diffusion #25

Merged
merged 1 commit into from Dec 26, 2022
91 changes: 91 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4244,6 +4244,97 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}
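For orientation, the new op corresponds to what PyTorch's functional API emits. A minimal sketch of the Python-level entry point, assuming the standard torch.nn.functional signature (the example function name is illustrative, not part of this patch):

import torch
import torch.nn.functional as F

def instance_norm_example(x: torch.Tensor) -> torch.Tensor:
    # F.instance_norm forwards to aten::instance_norm; with
    # use_input_stats=True, per-sample/per-channel statistics of x are
    # used and running_mean/running_var may stay None.
    return F.instance_norm(x, running_mean=None, running_var=None,
                           weight=None, bias=None, use_input_stats=True,
                           momentum=0.1, eps=1e-5)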

def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_IntType:$num_groups,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
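Similarly, aten::group_norm is what torch.nn.functional.group_norm dispatches to; the trailing cudnn_enabled operand is filled in by PyTorch rather than by the user. A hedged usage sketch (names illustrative):

import torch
import torch.nn.functional as F

def group_norm_example(x: torch.Tensor, num_groups: int,
                       weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # F.group_norm(input, num_groups, weight, bias, eps) maps onto the
    # operands of Torch_AtenGroupNormOp above.
    return F.group_norm(x, num_groups, weight=weight, bias=bias, eps=1e-5)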

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_IntType:$N,
Torch_IntType:$C,
Torch_IntType:$HxW,
Torch_IntType:$group,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 3);
}
void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 3);
}
}];
}
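aten::native_group_norm additionally returns the per-(sample, group) mean and reciprocal standard deviation it normalized with, which is why the op has three results. An illustrative reference in plain PyTorch; this is a sketch of the op's math, not the kernel PyTorch actually uses:

import torch

def native_group_norm_reference(x, weight, bias, N, C, HxW, group, eps):
    # Statistics are computed per (sample, group) over C // group * HxW
    # elements; out, mean, and rstd are returned, matching the three results.
    xg = x.reshape(N, group, -1)
    mean = xg.mean(dim=-1)
    rstd = (xg.var(dim=-1, unbiased=False) + eps).rsqrt()
    out = (xg - mean.unsqueeze(-1)) * rstd.unsqueeze(-1)
    out = out.reshape(x.shape)
    if weight is not None:
        out = out * weight.reshape(1, C, *([1] * (x.dim() - 2)))
    if bias is not None:
        out = out + bias.reshape(1, C, *([1] * (x.dim() - 2)))
    return out, mean, rstd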

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
121 changes: 115 additions & 6 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1800,6 +1800,119 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
return success();
}
};

class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenGroupNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
Value input = op.input();
auto inputTy = input.getType().cast<BaseTensorType>();
if (!inputTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
ArrayRef<int64_t> inputSize = inputTy.getSizes();
    int64_t inputRank = inputSize.size();
    if (inputRank != 4) {
      return rewriter.notifyMatchFailure(
          op, "group norm only supports 4D input now");
    }

Review comment on the error message — LLVM coding style: start the first sentence with a lower-case letter, and finish the last sentence without a period.
Ref: https://llvm.org/docs/CodingStandards.html#error-and-warning-messages
Value num_groups = op.num_groups();
int64_t num_groups_int;
if (!matchPattern(num_groups, m_TorchConstantInt(&num_groups_int)))
return rewriter.notifyMatchFailure(
op, "non const num_groups for AtenGroupNormOp");

    // reshape input -> [N, G, -1 (C // G), H, W]
Value negOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
SmallVector<int64_t, 5> inputForNormTySize(inputRank + 1,
ShapedType::kDynamicSize);
inputForNormTySize[1] = num_groups_int;
Type inputForNormTy = inputTy.getWithSizesAndDtype(
llvm::makeArrayRef(inputForNormTySize), inputTy.getDtype());
    SmallVector<Value> originInputSize;
for (int i = 0; i < inputRank; ++i) {
Value index =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
      originInputSize.push_back(
rewriter.create<AtenSizeIntOp>(loc, input, index));
}
    SmallVector<Value, 5> inputForNormSize{originInputSize.begin(),
                                           originInputSize.end()};
inputForNormSize.insert(inputForNormSize.begin() + 1, num_groups);
inputForNormSize[2] = negOne;
Value inputForNormSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), inputForNormSize);
Value reshapedInput = rewriter.create<AtenViewOp>(
loc, inputForNormTy, input, inputForNormSizeList);
    // only keep N, G; normalize over the remaining C // G, H, W dims
int64_t axis = 2;
std::vector<int64_t> meanVarTySizes(inputForNormTySize.size(), 1);
for (int i = 0; i < axis; i++)
meanVarTySizes[i] = inputForNormTySize[i];
auto meanVarTy = inputTy.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarTySizes), inputTy.getDtype());
SmallVector<Value> normalizedShapeSize{inputForNormSize.begin() + axis,
inputForNormSize.end()};
auto normalizedSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), normalizedShapeSize);

auto nativeLayerNorm =
rewriter
.create<AtenNativeLayerNormOp>(
loc, inputForNormTy, meanVarTy, meanVarTy, reshapedInput,
normalizedSizeList, none, none, op.eps())
.getResult(0);
    // reshape back to the original shape
Value inputSizeList = rewriter.create<PrimListConstructOp>(
        loc, ListType::get(IntType::get(context)), originInputSize);
Value originOutput = rewriter.create<AtenViewOp>(
loc, op.getType(), nativeLayerNorm, inputSizeList);
    // reshape weight and bias to [C, 1, 1] so they broadcast over [N, C, H, W]
Value weight = op.weight();
Value bias = op.bias();
if (!weight.getType().isa<Torch::NoneType>() ||
!bias.getType().isa<Torch::NoneType>()) {
SmallVector<Value> weightsAndBiasSize(inputRank - 1, one);
      weightsAndBiasSize[0] = originInputSize[1];

SmallVector<int64_t> weightsAndBiasTySize(inputRank - 1,
ShapedType::kDynamicSize);

Value weightsAndBiasSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), weightsAndBiasSize);
if (!weight.getType().isa<Torch::NoneType>()) {
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
Type weightTy = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(weightsAndBiasTySize), weightType.getDtype());
weight = rewriter.create<AtenViewOp>(loc, weightTy, weight,
weightsAndBiasSizeList);
originOutput = rewriter.create<AtenMulTensorOp>(loc, op.getType(),
originOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
Type biasTy = biasType.getWithSizesAndDtype(
llvm::makeArrayRef(weightsAndBiasTySize), biasType.getDtype());
bias = rewriter.create<AtenViewOp>(loc, biasTy, bias,
weightsAndBiasSizeList);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
originOutput = rewriter.create<AtenAddTensorOp>(
loc, op.getType(), originOutput, bias, alpha);
}
}
rewriter.replaceOp(op, {originOutput});
return success();
}
};
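At the PyTorch level, the rewrite above amounts to decomposing group_norm into a reshape, a layer_norm over the trailing dims, a reshape back, and a per-channel affine step. A hedged sketch assuming a 4D NCHW input (function and variable names are illustrative, not from this patch); comparing it against torch.nn.functional.group_norm on random inputs is a quick way to sanity-check the decomposition:

import torch
import torch.nn.functional as F

def decomposed_group_norm(x, num_groups, weight, bias, eps):
    # Mirrors DecomposeAtenGroupNormOp: view as [N, G, C // G, H, W],
    # normalize over every dim after the group dim, view back, then apply
    # weight/bias broadcast as [C, 1, 1].
    N, C, H, W = x.shape
    xg = x.view(N, num_groups, -1, H, W)
    normalized = F.layer_norm(xg, xg.shape[2:], weight=None, bias=None, eps=eps)
    out = normalized.view(N, C, H, W)
    if weight is not None:
        out = out * weight.view(C, 1, 1)
    if bias is not None:
        out = out + bias.view(C, 1, 1)
    return out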

} // namespace

namespace {
@@ -2506,12 +2619,6 @@ class DecomposeAtenToDtypeLayoutOp
op, "unimplemented: pin_memory is expected to be false");
}

// TODO: Add support for non-None device arg.
if (!op.device().getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: device arg must be None");
}

// TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0.
if (!op.layout().getType().isa<Torch::NoneType>()) {
@@ -3378,6 +3485,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
target.addIllegalOp<AtenGroupNormOp>();
patterns.add<DecomposeAtenGroupNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);

@@ -366,6 +366,15 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)