From a2fe7afab9d1fe4179c765ae30dfb4805b9d0b95 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 6 Nov 2023 16:35:44 -0800 Subject: [PATCH 1/7] collapsed commit. Fails the pattern with dynamic shapes. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 234 +++++++++++++++++- .../Transforms/LowerToBackendContract.cpp | 1 + lib/Dialect/Torch/Utils/Utils.cpp | 2 +- .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 32 +++ 6 files changed, 285 insertions(+), 9 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c0ac6a9419e8..41f996a75469 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6419,6 +6419,30 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ }]; } +def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$upscale_factor + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPixelShuffleOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index e315bab00e68..384ada260792 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -93,7 +94,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, keepDimCst, dtype); } -// Redunction function to calculate max along given `dim`. +// Reduction function to calculate max along given `dim`. static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { @@ -211,6 +212,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { Location loc = op.getLoc(); SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); @@ -227,8 +229,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { } // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. 
- std::sort(dims.begin(), dims.end()); - std::reverse(dims.begin(), dims.end()); + std::sort(dims.rbegin(), dims.rend()); for (int64_t dimInt : dims) { int64_t inputRank = inputTy.getSizes().size(); dimInt = toPositiveDim(dimInt, inputRank); @@ -255,6 +256,7 @@ class DecomposeAtenSizeOp : public OpRewritePattern { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); + std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); @@ -386,9 +388,10 @@ class DecomposeAtenGluOp : public OpRewritePattern { Value remainder = rewriter.create(loc, dimSize, two); Value eqOrNot = rewriter.create(loc, remainder, zero); + rewriter.create( loc, eqOrNot, - rewriter.getStringAttr("AtenGluOp's dim size must be multiply of 2")); + rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2")); Value splitLength = rewriter.create(loc, dimSize, two); Value a = rewriter.create(loc, outputTy, self, dim, zero, @@ -443,6 +446,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); int64_t n; + if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) return rewriter.notifyMatchFailure(op, "unimplemented: n must be constant"); @@ -1092,9 +1096,223 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace +// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations. +// +// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where +// leading_dims is of size N, then +// X = pixel_shuffle(input, upscale_factor) +// +// gets replaced with +// A = input.reshape(*leading_dims, C, r, r, H, W) +// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2) +// X = B.reshape(*leading_dims, C, r*H, r*W) +// +// 'r' above is referred to as the 'upscale factor' or just 'factor' below. +namespace { +class DecomposeAtenPixelShuffleOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenPixelShuffleOp op, + PatternRewriter &rewriter) const override { + + + Location loc = op.getLoc(); + Value inValue = op.getSelf(); + auto inType = inValue.getType().cast(); + auto maybeSizes = inType.getOptionalSizes(); + if (!maybeSizes) { + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have known rank."); + } + auto inShape = maybeSizes.value(); + auto inRank = inShape.size(); + + // TODO: Handle dynamic shapes. + // + // Currently the decomposition of inputs with dynamic shapes results in a + // failure when lowering to linalg. Specifically, the error: + // + // "Unhandled case of expand/collapse. " + // + // is hit in the lowering of aten.view to the linalg dialect (see + // ConvertAtenViewOp in TorchToLinalg/DataMovement.cpp where this error is + // thrown). + // + // The issue is that there is no way for the lowering from torch to linalg + // to know that the view ops (reshape ops) created in this decomposition can + // be mapped directly to tensor.expand_shape and tensor_collapse_shape. We + // DO know that this is the case: the first reshape is an expand_shape + // and the second one is a collapse_shape. But we only know this because of + // the semantics of pixel_shuffle, by the time the lowering sees the view + // ops, this special context is lost. + // + // There are a few possible options to fixing the current approach, as I see + // it: + // + // 1) Lower pixel_shuffle to linalg directly. 
This would mean a bit of + // duplication of the logic for lowering from aten.permute to linalg, and + // would not be the most 'gradual' lowering approach. But perhaps the most + // straightforward and understable. + // + // 2) Rather than use aten.reshape ops in the decomposition below, insert + // tensor.expand_shape and tensor.collapse_shape ops directly. This would + // require inserting unrealized_conversion_cast ops to convert the types of + // the operands from !torch.tensor to to the MLIR built-in tensor type, + // which isn't ideal (unrealized_conversion_cast should only be used in + // conversion passes AFAIK (?)). + // + // 3) Create 2 new ops in the torch dialect, with the same semantics as + // tensor.expand_shape and tensor.collapse_shape, but with !torch.tensor + // operands and results. Then use these ops in the decomposition below, + // instead of aten.reshape ops. + // + // For now, we just fail the decomposition here so that a sensible error is + // provided: + for (auto dimSize : inShape) { + if (dimSize == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "Currently we only decompose pixel_shuffle if the input tensor " + "is statically shaped"); + } + } + + // At least 3 dimensions are needed + // (case when leading_dims is empty). + if (inRank < 3) + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have rank greater than 2."); + + auto nLeadingDims = inRank - 3; + + // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead + // of 'create': if the dimension size is known, then the AtenSizeIntOp is + // folded to a ConstantOp. + auto getDimSize = [&](uint64_t i) -> Value { + Value dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + return rewriter.createOrFold(loc, inValue, dim); + }; + + auto inC = getDimSize(inRank - 3); + auto inH = getDimSize(inRank - 2); + auto inW = getDimSize(inRank - 1); + + auto factor = op.getUpscaleFactor(); + + + Value factorSquared = + rewriter.createOrFold(loc, factor, factor); + Value outC = + rewriter.createOrFold(loc, inC, factorSquared); + + Value outH = rewriter.createOrFold(loc, inH, factor); + Value outW = rewriter.createOrFold(loc, inW, factor); + + // Shape of 'A' in the comment at the top + SmallVector prePermuteShape; + prePermuteShape.reserve(nLeadingDims + 5); + + // Shape of 'B' in the comment at the top. + SmallVector postPermuteShape; + postPermuteShape.reserve(nLeadingDims + 5); + + SmallVector outShape; + outShape.reserve(nLeadingDims + 3); + + SmallVector permutation; + permutation.reserve(nLeadingDims + 5); + + for (unsigned i = 0; i < nLeadingDims; ++i) { + auto dimensionAttr = rewriter.getI64IntegerAttr(i); + Value dimensionValue = rewriter.create(loc, dimensionAttr); + Value leadingDimSize = + rewriter.createOrFold(loc, inValue, dimensionValue); + prePermuteShape.push_back(leadingDimSize); + postPermuteShape.push_back(leadingDimSize); + outShape.push_back(leadingDimSize); + permutation.push_back(dimensionValue); + + } + + const auto inOptionalDType = inType.getOptionalDtype(); + + auto getTypeFromShape = [inOptionalDType](auto &&vals) { + // Get a vector of integers from a vector of Values. 
+ auto getIntShape = [](auto &&vals) { + SmallVector shape; + shape.reserve(vals.size()); + for (auto v : vals) { + int64_t cst_val; + if (matchPattern(v, m_TorchConstantInt(&cst_val))) { + shape.push_back(cst_val); + } else { + shape.push_back(kUnknownSize); + } + } + return shape; + }; + + const auto intShape = getIntShape(vals); + return ValueTensorType::get(vals[0].getContext(), + llvm::ArrayRef(intShape), inOptionalDType); + }; + + prePermuteShape.insert(prePermuteShape.end(), + {outC, factor, factor, inH, inW}); + + postPermuteShape.insert(postPermuteShape.end(), + {outC, inH, factor, inW, factor}); + + outShape.insert(outShape.end(), {outC, outH, outW}); + + SmallVector permutationTail{0, 3, 1, 4, 2}; + for (uint64_t d : permutationTail) { + permutation.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(nLeadingDims + d))); + } + + auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); + + + Value shapeA = + rewriter.create(loc, listType, prePermuteShape); + + Value A = rewriter.create( + loc, getTypeFromShape(prePermuteShape), inValue, shapeA); + + Value permuteDimsOrder = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + permutation); + + Value B = rewriter.create( + loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder); + + Value outShapeList = + rewriter.create(loc, listType, outShape); + + // TODO(jn) figure out why the type of the returned value must be made + // undefined. + auto finalShape = getTypeFromShape(outShape); + auto finalType = finalShape.getWithSizesAndDtype({}, {}); + + + rewriter.replaceOpWithNewOp(op, finalType, B, outShapeList); + +// Value out = +// rewriter.createOrFold(loc, finalType, B, outShapeList); +// rewriter.replaceAllUsesWith(op.getResult(), out); +// rewriter.eraseOp(op); + + + return success(); + } +}; +} // namespace + // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) -static Value getRelu6Results(PatternRewriter &rewriter, Location loc, - Value input) { +static Value +getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { BaseTensorType inputType = input.getType().cast(); Value relu = rewriter.create(loc, inputType, input); @@ -4717,8 +4935,7 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "all dimensions must be constant ints"); - std::sort(dimensions.begin(), dimensions.end()); - std::reverse(dimensions.begin(), dimensions.end()); + std::sort(dimensions.rbegin(), dimensions.rend()); if (dimensions.size() == 0) { rewriter.replaceOp(op, input); @@ -5463,6 +5680,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index d40f99066fc9..a8d065f122f0 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -391,6 +391,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](AtenMatmulOp op) { diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 
14a264ada342..c4e0950fab6a 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -206,7 +206,7 @@ bool Torch::isViewLikeOp(Operation *op) { TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, - AtenViewAsComplexOp, AtenViewAsRealOp>(op); + AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 7042ab5756e6..a7afcbe4e6cb 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -481,6 +481,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") + emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 3bd9b91ebf99..ddb1ab2a76e1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -595,6 +595,38 @@ def PermuteModule_basic(module, tu: TestUtils): # ============================================================================== +class PixelShuffleModuleStatic_3_18_2_2(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 18, 2, 2], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 3) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_3_18_2_2()) +def PixelShuffleModuleStatic_3_18_2_2_basic(module, tu: TestUtils): + module.forward(tu.rand(3,18,2,2)) + + +class PixelShuffleModuleStatic_12_2_3(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([12, 2, 3], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_12_2_3()) +def PixelShuffleModuleStatic_12_2_3_basic(module, tu: TestUtils): + module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) + + + +# ============================================================================== + + class PermuteNegativeIndexModule(torch.nn.Module): def __init__(self): From daa65d4325be69d5e9c347d8fdb1352b782bb5cc Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 6 Nov 2023 21:13:28 -0800 Subject: [PATCH 2/7] address a few review comments --- .../Torch/Transforms/DecomposeComplexOps.cpp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 384ada260792..efcf8f006429 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1160,7 +1160,8 
@@ class DecomposeAtenPixelShuffleOp // require inserting unrealized_conversion_cast ops to convert the types of // the operands from !torch.tensor to to the MLIR built-in tensor type, // which isn't ideal (unrealized_conversion_cast should only be used in - // conversion passes AFAIK (?)). + // conversion passes AFAIK (?)). UPDATE: not a viable option as we cannot + // make assumptions about the dialects that torch is lowered to. // // 3) Create 2 new ops in the torch dialect, with the same semantics as // tensor.expand_shape and tensor.collapse_shape, but with !torch.tensor @@ -1176,9 +1177,11 @@ class DecomposeAtenPixelShuffleOp "is statically shaped"); } } - - // At least 3 dimensions are needed - // (case when leading_dims is empty). + // The input tensor must have at least 3 dimensions: (1) the channel + // dimension which gets smaller by 'factor*factor', (2) the H channel which + // gets larger by 'factor' and (3) the W channel which get larger by + // 'factor'. The total number of dimensions is 3 + N, where N is the number + // of leading dimensions, and N >= 0 so the input must have rank at least 3. if (inRank < 3) return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank greater than 2."); @@ -1291,19 +1294,16 @@ class DecomposeAtenPixelShuffleOp Value outShapeList = rewriter.create(loc, listType, outShape); - // TODO(jn) figure out why the type of the returned value must be made - // undefined. - auto finalShape = getTypeFromShape(outShape); - auto finalType = finalShape.getWithSizesAndDtype({}, {}); - - - rewriter.replaceOpWithNewOp(op, finalType, B, outShapeList); - -// Value out = -// rewriter.createOrFold(loc, finalType, B, outShapeList); -// rewriter.replaceAllUsesWith(op.getResult(), out); -// rewriter.eraseOp(op); + // TODO(jn) figure out why the deduced return type (like + // !torch.vtensor<[3,4,6],si64>) cannot be used for the type of the + // replacement of the pixel_shuffle op's output. The pattern seems to + // require the generic !torch.vtensor type. + + auto deducedReturnType = getTypeFromShape(outShape); + auto genericReturnType = deducedReturnType.getWithSizesAndDtype({}, {}); + rewriter.replaceOpWithNewOp(op, genericReturnType, B, + outShapeList); return success(); } From 4ad0731a0844c60bde46a524ee7d4233f808028b Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 7 Nov 2023 10:03:25 -0800 Subject: [PATCH 3/7] implement type inference, remove large comment, remove unused include. 
--- .../Transforms/AbstractInterpLibrary.cpp | 44 ++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 81 +++++-------------- .../build_tools/abstract_interp_lib_gen.py | 19 +++++ 3 files changed, 82 insertions(+), 62 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 7ffa776e2d41..853b362a064e 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6749,6 +6749,46 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %arg2, %2) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pixel_shuffle\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int1 = torch.constant.int 1\n" +" %str = torch.constant.str \"AssertionError: number of input channels must be divisible by upscale_factor^2 in pixel_shuffle\"\n" +" %int-3 = torch.constant.int -3\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_shuffle\"\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.remainder.int %3, %2 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list, !torch.int, !torch.int, !torch.int -> !torch.list\n" +" %7 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.floordiv.int %7, %2 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.append.t %6, %8 : !torch.list, !torch.int -> !torch.list\n" +" %10 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %10, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" %13 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.append.t %6, %14 : !torch.list, !torch.int -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8333,6 +8373,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : 
(!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index efcf8f006429..57d8e9c22a93 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,7 +9,6 @@ #include "PassDetail.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -1128,45 +1127,11 @@ class DecomposeAtenPixelShuffleOp auto inShape = maybeSizes.value(); auto inRank = inShape.size(); - // TODO: Handle dynamic shapes. + // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg + // directly. Pixel shuffle does a reshape that is hard to recover + // through pure torch (view) ops, especially in dynamic cases. // - // Currently the decomposition of inputs with dynamic shapes results in a - // failure when lowering to linalg. Specifically, the error: - // - // "Unhandled case of expand/collapse. " - // - // is hit in the lowering of aten.view to the linalg dialect (see - // ConvertAtenViewOp in TorchToLinalg/DataMovement.cpp where this error is - // thrown). - // - // The issue is that there is no way for the lowering from torch to linalg - // to know that the view ops (reshape ops) created in this decomposition can - // be mapped directly to tensor.expand_shape and tensor_collapse_shape. We - // DO know that this is the case: the first reshape is an expand_shape - // and the second one is a collapse_shape. But we only know this because of - // the semantics of pixel_shuffle, by the time the lowering sees the view - // ops, this special context is lost. - // - // There are a few possible options to fixing the current approach, as I see - // it: - // - // 1) Lower pixel_shuffle to linalg directly. This would mean a bit of - // duplication of the logic for lowering from aten.permute to linalg, and - // would not be the most 'gradual' lowering approach. But perhaps the most - // straightforward and understable. - // - // 2) Rather than use aten.reshape ops in the decomposition below, insert - // tensor.expand_shape and tensor.collapse_shape ops directly. This would - // require inserting unrealized_conversion_cast ops to convert the types of - // the operands from !torch.tensor to to the MLIR built-in tensor type, - // which isn't ideal (unrealized_conversion_cast should only be used in - // conversion passes AFAIK (?)). UPDATE: not a viable option as we cannot - // make assumptions about the dialects that torch is lowered to. - // - // 3) Create 2 new ops in the torch dialect, with the same semantics as - // tensor.expand_shape and tensor.collapse_shape, but with !torch.tensor - // operands and results. Then use these ops in the decomposition below, - // instead of aten.reshape ops. 
+ // See: https://github.com/llvm/torch-mlir/issues/2559 // // For now, we just fail the decomposition here so that a sensible error is // provided: @@ -1277,35 +1242,27 @@ class DecomposeAtenPixelShuffleOp auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); + Value shapeA = + rewriter.create(loc, listType, prePermuteShape); - Value shapeA = - rewriter.create(loc, listType, prePermuteShape); - - Value A = rewriter.create( - loc, getTypeFromShape(prePermuteShape), inValue, shapeA); - - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), - permutation); + Value A = rewriter.create( + loc, getTypeFromShape(prePermuteShape), inValue, shapeA); - Value B = rewriter.create( - loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder); + Value permuteDimsOrder = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + permutation); - Value outShapeList = - rewriter.create(loc, listType, outShape); + Value B = rewriter.create( + loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder); - // TODO(jn) figure out why the deduced return type (like - // !torch.vtensor<[3,4,6],si64>) cannot be used for the type of the - // replacement of the pixel_shuffle op's output. The pattern seems to - // require the generic !torch.vtensor type. - - auto deducedReturnType = getTypeFromShape(outShape); - auto genericReturnType = deducedReturnType.getWithSizesAndDtype({}, {}); + Value outShapeList = + rewriter.create(loc, listType, outShape); - rewriter.replaceOpWithNewOp(op, genericReturnType, B, - outShapeList); + auto deducedReturnType = getTypeFromShape(outShape); - return success(); + rewriter.replaceOpWithNewOp(op, deducedReturnType, B, + outShapeList); + return success(); } }; } // namespace diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 9839c34be618..9144229f9484 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -412,6 +412,23 @@ def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype) +def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[int]: + + assert len(self) >= 3, "input must be at least rank-3 in pixel_shuffle" + upscale_factor_squared = upscale_factor * upscale_factor + assert self[-3] % (upscale_factor_squared) == 0, "number of input channels must be divisible by upscale_factor^2 in pixel_shuffle" + + out = self[0:-3] + out.append(self[-3] // upscale_factor_squared) + out.append(self[-2] * upscale_factor) + out.append(self[-1] * upscale_factor) + return out + +def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + + def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -1436,6 +1453,7 @@ def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) + 
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1893,6 +1911,7 @@ def aten〇permute〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> self_rank, self_dtype = self_rank_dtype return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype From d49820626b7d01e2433bded977c45ad1ddc684c2 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 7 Nov 2023 10:35:27 -0800 Subject: [PATCH 4/7] rerun ./build_tools/update_abstract_interp_lib.sh --- lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 853b362a064e..6f6f7ea127e1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6789,6 +6789,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %15 = torch.aten.append.t %6, %14 : !torch.list, !torch.int -> !torch.list\n" " return %6 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8373,10 +8377,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %0#1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" From 09ec2a067728f81df19b49081ac0002bfe20f593 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 7 Nov 2023 13:35:27 -0800 Subject: [PATCH 5/7] add to tosa pass set --- projects/pt1/e2e_testing/xfail_sets.py | 2 + .../torch_mlir_e2e_test/test_suite/basic.py | 64 ++++++++++--------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e671b9241287..d0f3feec26e7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -941,6 +941,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "PixelShuffleModuleStatic_12_2_3_basic", + "PixelShuffleModuleStatic_3_18_2_2_basic", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index ddb1ab2a76e1..f4d36e6f23fc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -592,37 +592,6 @@ def PermuteModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) -# ============================================================================== - - -class PixelShuffleModuleStatic_3_18_2_2(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([None, ([3, 18, 2, 2], torch.float32, True)]) - def forward(self, x): - return torch.ops.aten.pixel_shuffle(x, 3) - -@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_3_18_2_2()) -def PixelShuffleModuleStatic_3_18_2_2_basic(module, tu: TestUtils): - module.forward(tu.rand(3,18,2,2)) - - -class PixelShuffleModuleStatic_12_2_3(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([None, ([12, 2, 3], torch.int64, True)]) - def forward(self, x): - return torch.ops.aten.pixel_shuffle(x, 2) - -@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_12_2_3()) -def PixelShuffleModuleStatic_12_2_3_basic(module, tu: TestUtils): - module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) - - # ============================================================================== @@ -687,6 +656,39 @@ def TransposeIntNegDimsModule_basic(module, tu: TestUtils): # ============================================================================== +class PixelShuffleModuleStatic_3_18_2_2(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 18, 2, 2], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 3) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_3_18_2_2()) +def PixelShuffleModuleStatic_3_18_2_2_basic(module, tu: TestUtils): + module.forward(tu.rand(3,18,2,2)) + + +# ============================================================================== + + +class PixelShuffleModuleStatic_12_2_3(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([12, 2, 3], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_12_2_3()) +def PixelShuffleModuleStatic_12_2_3_basic(module, tu: TestUtils): + module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) + + + + class TensorsConcatModule(torch.nn.Module): def __init__(self): From 3858f5ab4aee658c1c4a391ae19f163ce779c479 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 7 Nov 2023 15:06:20 -0800 Subject: [PATCH 6/7] add ltc fail expectation --- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d0f3feec26e7..32fa461ffefc 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1352,6 +1352,8 @@ } LTC_XFAIL_SET = { + "PixelShuffleModuleStatic_12_2_3_basic", + "PixelShuffleModuleStatic_3_18_2_2_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", 
"_Convolution2DCudnnModule_basic", From 019245ddc18ed90187251725435aa10a14f063f2 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 7 Nov 2023 18:06:50 -0800 Subject: [PATCH 7/7] update xfail set name, after e2e test name change --- .../Torch/Transforms/AbstractInterpLibrary.cpp | 8 ++++---- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 7 +++---- projects/pt1/e2e_testing/xfail_sets.py | 8 ++++---- .../jit_ir/build_tools/abstract_interp_lib_gen.py | 10 +++++++--- .../python/torch_mlir_e2e_test/test_suite/basic.py | 12 ++++++------ 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 6f6f7ea127e1..2dfe5c058bad 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6789,10 +6789,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %15 = torch.aten.append.t %6, %14 : !torch.list, !torch.int -> !torch.list\n" " return %6 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %0#1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8550,6 +8546,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 57d8e9c22a93..6efe3d131415 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1142,11 +1142,12 @@ class DecomposeAtenPixelShuffleOp "is statically shaped"); } } + // The input tensor must have at least 3 dimensions: (1) the channel // dimension which gets smaller by 'factor*factor', (2) the H channel which // gets larger by 'factor' and (3) the W channel which get larger by // 'factor'. The total number of dimensions is 3 + N, where N is the number - // of leading dimensions, and N >= 0 so the input must have rank at least 3. + // of leading dimensions, and N >= 0 so the input must have rank at least 3. 
if (inRank < 3) return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank greater than 2."); @@ -1258,9 +1259,7 @@ class DecomposeAtenPixelShuffleOp Value outShapeList = rewriter.create(loc, listType, outShape); - auto deducedReturnType = getTypeFromShape(outShape); - - rewriter.replaceOpWithNewOp(op, deducedReturnType, B, + rewriter.replaceOpWithNewOp(op, op.getType(), B, outShapeList); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 32fa461ffefc..307b8a635eb2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -941,8 +941,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "PixelShuffleModuleStatic_12_2_3_basic", - "PixelShuffleModuleStatic_3_18_2_2_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", @@ -1352,8 +1352,8 @@ } LTC_XFAIL_SET = { - "PixelShuffleModuleStatic_12_2_3_basic", - "PixelShuffleModuleStatic_3_18_2_2_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 9144229f9484..af66de0b7e5c 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -424,9 +424,6 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i out.append(self[-1] * upscale_factor) return out -def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int: - self_rank, self_dtype = self_rank_dtype - return self_dtype def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: @@ -1559,6 +1556,11 @@ def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_ self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 1, 1)], upscale_factor = 2)) +def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2])) def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int: self_rank, self_dtype = self_rank_dtype @@ -3672,6 +3674,8 @@ def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int] a_rank, a_dtype = a_rank_dtype return a_dtype + + # ============================================================================== # Main # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index f4d36e6f23fc..c513dd607d61 100644 --- 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -656,7 +656,7 @@ def TransposeIntNegDimsModule_basic(module, tu: TestUtils): # ============================================================================== -class PixelShuffleModuleStatic_3_18_2_2(torch.nn.Module): +class PixelShuffleModuleStaticRank4Float32(torch.nn.Module): def __init__(self): super().__init__() @@ -665,15 +665,15 @@ def __init__(self): def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 3) -@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_3_18_2_2()) -def PixelShuffleModuleStatic_3_18_2_2_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank4Float32()) +def PixelShuffleModuleStaticRank4Float32_basic(module, tu: TestUtils): module.forward(tu.rand(3,18,2,2)) # ============================================================================== -class PixelShuffleModuleStatic_12_2_3(torch.nn.Module): +class PixelShuffleModuleStaticRank3Int64(torch.nn.Module): def __init__(self): super().__init__() @@ -682,8 +682,8 @@ def __init__(self): def forward(self, x): return torch.ops.aten.pixel_shuffle(x, 2) -@register_test_case(module_factory=lambda: PixelShuffleModuleStatic_12_2_3()) -def PixelShuffleModuleStatic_12_2_3_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank3Int64()) +def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils): module.forward(tu.randint(12, 2, 3, low = 0, high = 100))
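---

For reference, below is a minimal PyTorch sketch of the reshape → permute → reshape decomposition that `DecomposeAtenPixelShuffleOp` performs (see the comment above that pattern in `DecomposeComplexOps.cpp`). It is illustrative only and not part of the patch series; the input shape and upscale factor are taken from the `PixelShuffleModuleStaticRank4Float32` e2e test added above, and the intermediate names `a`, `b`, `out` correspond to `A`, `B`, `X` in that comment.

```python
import torch

# Input of shape (*leading_dims, C*r*r, H, W); here leading_dims = (3,), so N = 1.
x = torch.rand(3, 18, 2, 2)
r = 3                                  # upscale_factor
C, H, W = 18 // (r * r), 2, 2          # shape function: C_out = C_in // r^2

a = x.reshape(3, C, r, r, H, W)        # 'A': expand C*r*r into (C, r, r)
b = a.permute(0, 1, 4, 2, 5, 3)        # 'B': (0, ..., N, N+3, N+1, N+4, N+2) with N = 1
out = b.reshape(3, C, r * H, r * W)    # 'X': collapse (r, H) and (r, W)

# The decomposition reproduces aten.pixel_shuffle exactly.
assert torch.equal(out, torch.ops.aten.pixel_shuffle(x, r))
assert list(out.shape) == [3, C, r * H, r * W]   # matches the shape fn in abstract_interp_lib_gen.py
```

The first reshape is a pure expand and the second a pure collapse, which is why the patch notes that dynamic shapes currently fail when `aten.view` is lowered to linalg: that expand/collapse structure is no longer visible by the time the view ops reach `ConvertAtenViewOp`.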