diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index a51a3c4d971..8fd13bf5c1f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6445,6 +6445,30 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ }]; } +def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$upscale_factor + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPixelShuffleOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 6f578d1ad5f..b0bd21f69a1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6759,6 +6759,46 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %1, %arg2, %2) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pixel_shuffle\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int1 = torch.constant.int 1\n" +" %str = torch.constant.str \"AssertionError: number of input channels must be divisible by upscale_factor^2 in pixel_shuffle\"\n" +" %int-3 = torch.constant.int -3\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_shuffle\"\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.remainder.int %3, %2 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list, !torch.int, !torch.int, !torch.int -> !torch.list\n" +" %7 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.floordiv.int %7, %2 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.append.t %6, %8 : !torch.list, !torch.int -> !torch.list\n" +" %10 = torch.aten.__getitem__.t 
%arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %10, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" %13 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.append.t %6, %14 : !torch.list, !torch.int -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8516,6 +8556,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d3e7a47c4aa..b19d3f949f0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -93,7 +93,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, keepDimCst, dtype); } -// Redunction function to calculate max along given `dim`. +// Reduction function to calculate max along given `dim`. static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { @@ -211,6 +211,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { Location loc = op.getLoc(); SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); @@ -227,8 +228,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { } // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. 
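+    // Processing the entries of `dims` from the highest dimension down means
+    // each aten.max.dim call leaves the indices of the not-yet-processed
+    // (lower) dimensions valid, e.g. for dims = {1, 3} on a rank-4 input,
+    // dim 3 is reduced before dim 1.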
- std::sort(dims.begin(), dims.end()); - std::reverse(dims.begin(), dims.end()); + std::sort(dims.rbegin(), dims.rend()); for (int64_t dimInt : dims) { int64_t inputRank = inputTy.getSizes().size(); dimInt = toPositiveDim(dimInt, inputRank); @@ -255,6 +255,7 @@ class DecomposeAtenSizeOp : public OpRewritePattern { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); + std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); @@ -386,9 +387,10 @@ class DecomposeAtenGluOp : public OpRewritePattern { Value remainder = rewriter.create(loc, dimSize, two); Value eqOrNot = rewriter.create(loc, remainder, zero); + rewriter.create( loc, eqOrNot, - rewriter.getStringAttr("AtenGluOp's dim size must be multiply of 2")); + rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2")); Value splitLength = rewriter.create(loc, dimSize, two); Value a = rewriter.create(loc, outputTy, self, dim, zero, @@ -443,6 +445,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); int64_t n; + if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) return rewriter.notifyMatchFailure(op, "unimplemented: n must be constant"); @@ -1092,9 +1095,180 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace +// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations. +// +// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where +// leading_dims is of size N, then +// X = pixel_shuffle(input, upscale_factor) +// +// gets replaced with +// A = input.reshape(*leading_dims, C, r, r, H, W) +// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2) +// X = B.reshape(*leading_dims, C, r*H, r*W) +// +// 'r' above is referred to as the 'upscale factor' or just 'factor' below. +namespace { +class DecomposeAtenPixelShuffleOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenPixelShuffleOp op, + PatternRewriter &rewriter) const override { + + + Location loc = op.getLoc(); + Value inValue = op.getSelf(); + auto inType = inValue.getType().cast(); + auto maybeSizes = inType.getOptionalSizes(); + if (!maybeSizes) { + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have known rank."); + } + auto inShape = maybeSizes.value(); + auto inRank = inShape.size(); + + // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg + // directly. Pixel shuffle does a reshape that is hard to recover + // through pure torch (view) ops, especially in dynamic cases. + // + // See: https://github.com/llvm/torch-mlir/issues/2559 + // + // For now, we just fail the decomposition here so that a sensible error is + // provided: + for (auto dimSize : inShape) { + if (dimSize == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "Currently we only decompose pixel_shuffle if the input tensor " + "is statically shaped"); + } + } + + // The input tensor must have at least 3 dimensions: (1) the channel + // dimension which gets smaller by 'factor*factor', (2) the H channel which + // gets larger by 'factor' and (3) the W channel which get larger by + // 'factor'. The total number of dimensions is 3 + N, where N is the number + // of leading dimensions, and N >= 0 so the input must have rank at least 3. 
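+    // As a concrete example of the reshape -> permute -> reshape scheme
+    // described above: an input of shape (1, 8, 2, 3) with upscale_factor
+    // r = 2 gives A of shape (1, 2, 2, 2, 2, 3), which permute(0, 1, 4, 2,
+    // 5, 3) turns into B of shape (1, 2, 2, 2, 3, 2), and the final reshape
+    // produces the result of shape (1, 2, 4, 6).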
+ if (inRank < 3) + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have rank greater than 2."); + + auto nLeadingDims = inRank - 3; + + // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead + // of 'create': if the dimension size is known, then the AtenSizeIntOp is + // folded to a ConstantOp. + auto getDimSize = [&](uint64_t i) -> Value { + Value dim = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + return rewriter.createOrFold(loc, inValue, dim); + }; + + auto inC = getDimSize(inRank - 3); + auto inH = getDimSize(inRank - 2); + auto inW = getDimSize(inRank - 1); + + auto factor = op.getUpscaleFactor(); + + + Value factorSquared = + rewriter.createOrFold(loc, factor, factor); + Value outC = + rewriter.createOrFold(loc, inC, factorSquared); + + Value outH = rewriter.createOrFold(loc, inH, factor); + Value outW = rewriter.createOrFold(loc, inW, factor); + + // Shape of 'A' in the comment at the top + SmallVector prePermuteShape; + prePermuteShape.reserve(nLeadingDims + 5); + + // Shape of 'B' in the comment at the top. + SmallVector postPermuteShape; + postPermuteShape.reserve(nLeadingDims + 5); + + SmallVector outShape; + outShape.reserve(nLeadingDims + 3); + + SmallVector permutation; + permutation.reserve(nLeadingDims + 5); + + for (unsigned i = 0; i < nLeadingDims; ++i) { + auto dimensionAttr = rewriter.getI64IntegerAttr(i); + Value dimensionValue = rewriter.create(loc, dimensionAttr); + Value leadingDimSize = + rewriter.createOrFold(loc, inValue, dimensionValue); + prePermuteShape.push_back(leadingDimSize); + postPermuteShape.push_back(leadingDimSize); + outShape.push_back(leadingDimSize); + permutation.push_back(dimensionValue); + + } + + const auto inOptionalDType = inType.getOptionalDtype(); + + auto getTypeFromShape = [inOptionalDType](auto &&vals) { + // Get a vector of integers from a vector of Values. 
+ auto getIntShape = [](auto &&vals) { + SmallVector shape; + shape.reserve(vals.size()); + for (auto v : vals) { + int64_t cst_val; + if (matchPattern(v, m_TorchConstantInt(&cst_val))) { + shape.push_back(cst_val); + } else { + shape.push_back(kUnknownSize); + } + } + return shape; + }; + + const auto intShape = getIntShape(vals); + return ValueTensorType::get(vals[0].getContext(), + llvm::ArrayRef(intShape), inOptionalDType); + }; + + prePermuteShape.insert(prePermuteShape.end(), + {outC, factor, factor, inH, inW}); + + postPermuteShape.insert(postPermuteShape.end(), + {outC, inH, factor, inW, factor}); + + outShape.insert(outShape.end(), {outC, outH, outW}); + + SmallVector permutationTail{0, 3, 1, 4, 2}; + for (uint64_t d : permutationTail) { + permutation.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(nLeadingDims + d))); + } + + auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); + + Value shapeA = + rewriter.create(loc, listType, prePermuteShape); + + Value A = rewriter.create( + loc, getTypeFromShape(prePermuteShape), inValue, shapeA); + + Value permuteDimsOrder = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + permutation); + + Value B = rewriter.create( + loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder); + + Value outShapeList = + rewriter.create(loc, listType, outShape); + + rewriter.replaceOpWithNewOp(op, op.getType(), B, + outShapeList); + return success(); + } +}; +} // namespace + // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) -static Value getRelu6Results(PatternRewriter &rewriter, Location loc, - Value input) { +static Value +getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { BaseTensorType inputType = input.getType().cast(); Value relu = rewriter.create(loc, inputType, input); @@ -4780,8 +4954,7 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "all dimensions must be constant ints"); - std::sort(dimensions.begin(), dimensions.end()); - std::reverse(dimensions.begin(), dimensions.end()); + std::sort(dimensions.rbegin(), dimensions.rend()); if (dimensions.size() == 0) { rewriter.replaceOp(op, input); @@ -5526,6 +5699,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 5cdcfbc5d23..38198a913d2 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -391,6 +391,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](AtenMatmulOp op) { diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 067cc410ddd..78e6cfc1c47 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -206,7 +206,7 @@ bool Torch::isViewLikeOp(Operation *op) { TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, - AtenViewAsComplexOp, 
AtenViewAsRealOp>(op); + AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 78251b3862e..570ae185268 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -943,6 +943,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", "IscloseStaticModule_basic", "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", @@ -1352,6 +1354,8 @@ } LTC_XFAIL_SET = { + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 4450237bec4..8dadc8ac974 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -416,6 +416,20 @@ def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype) +def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[int]: + + assert len(self) >= 3, "input must be at least rank-3 in pixel_shuffle" + upscale_factor_squared = upscale_factor * upscale_factor + assert self[-3] % (upscale_factor_squared) == 0, "number of input channels must be divisible by upscale_factor^2 in pixel_shuffle" + + out = self[0:-3] + out.append(self[-3] // upscale_factor_squared) + out.append(self[-2] * upscale_factor) + out.append(self[-1] * upscale_factor) + return out + + + def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -1440,6 +1454,7 @@ def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1545,6 +1560,11 @@ def aten〇adaptive_avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_ self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 1, 1)], upscale_factor = 2)) +def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2])) def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int: self_rank, self_dtype = self_rank_dtype @@ -1906,6 +1926,7 @@ def aten〇permute〡dtype(self_rank_dtype: 
Tuple[int, int], dims: List[int]) -> self_rank, self_dtype = self_rank_dtype return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3666,6 +3687,8 @@ def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int] a_rank, a_dtype = a_rank_dtype return a_dtype + + # ============================================================================== # Main # ============================================================================== diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index d4ac74fe53b..ee118cea059 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -482,6 +482,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") + emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 431c5efec68..9eb1a8986d4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -592,6 +592,7 @@ def PermuteModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) + # ============================================================================== @@ -655,6 +656,39 @@ def TransposeIntNegDimsModule_basic(module, tu: TestUtils): # ============================================================================== +class PixelShuffleModuleStaticRank4Float32(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 18, 2, 2], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 3) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank4Float32()) +def PixelShuffleModuleStaticRank4Float32_basic(module, tu: TestUtils): + module.forward(tu.rand(3,18,2,2)) + + +# ============================================================================== + + +class PixelShuffleModuleStaticRank3Int64(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([12, 2, 3], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.pixel_shuffle(x, 2) + +@register_test_case(module_factory=lambda: PixelShuffleModuleStaticRank3Int64()) +def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils): + module.forward(tu.randint(12, 2, 3, low = 0, high = 100)) + + + + class TensorsConcatModule(torch.nn.Module): def __init__(self):
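As a sanity check outside the patch itself, the reshape/permute/reshape sequence used by DecomposeAtenPixelShuffleOp can be reproduced in plain PyTorch and compared against aten::pixel_shuffle on the same static shapes exercised by the new e2e tests. This is only an illustrative sketch (the helper name pixel_shuffle_by_reshape is made up here), not part of the change:

import torch

def pixel_shuffle_by_reshape(x: torch.Tensor, r: int) -> torch.Tensor:
    # x has shape (*leading_dims, C*r*r, H, W); N = len(leading_dims).
    *leading, crr, h, w = x.shape
    assert len(x.shape) >= 3 and crr % (r * r) == 0
    c = crr // (r * r)
    n = len(leading)
    # 'A' in the decomposition comment: split the channel dim into (C, r, r).
    a = x.reshape(*leading, c, r, r, h, w)
    # 'B': permutation (0, ..., N, N+3, N+1, N+4, N+2) interleaves r with H and W.
    b = a.permute(*range(n), n, n + 3, n + 1, n + 4, n + 2)
    # Final reshape collapses (H, r) -> H*r and (W, r) -> W*r.
    return b.reshape(*leading, c, h * r, w * r)

for shape, r in [((3, 18, 2, 2), 3), ((12, 2, 3), 2)]:
    x = torch.rand(shape)
    expected = torch.ops.aten.pixel_shuffle(x, r)
    assert torch.equal(pixel_shuffle_by_reshape(x, r), expected)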