From 413e0801de61ce34c3bd2242a39a566a7e7702ff Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Mon, 25 Jul 2022 23:28:48 +0800 Subject: [PATCH] [MHLO] Add [un]squeeze op patterns (#1099) * [MHLO] Add [un]squeeze op patterns * Conform to llvm coding standard * minor update See RFC https://github.com/llvm/torch-mlir/issues/999 Co-authored-by: Bairen Yi yibairen.byron@bytedance.com Co-authored-by: Jiawei Wu xremold@gmail.com Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com --- lib/Conversion/TorchToMhlo/ViewLikeOps.cpp | 229 +++++++++++++++++++-- test/Conversion/TorchToMhlo/view_like.mlir | 165 +++++++++++++++ 2 files changed, 375 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp index 0ecd96bf6293..18d21c193ba6 100644 --- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp @@ -36,23 +36,31 @@ static constexpr size_t kMhloDimSizeBits = 64; namespace { -SmallVector getDimSizesOfTensor(PatternRewriter &rewriter, - Operation *op, Value value) { +SmallVector toPositiveDims(ArrayRef dims, int64_t rank) { + SmallVector posDims; + posDims.reserve(rank); + std::transform( + dims.begin(), dims.end(), std::back_inserter(posDims), + [rank](int64_t d) -> size_t { return toPositiveDim(d, rank); }); + return posDims; +} + +FailureOr> +getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, + ArrayRef inpDims) { auto valueTy = value.getType().dyn_cast(); if (!valueTy) { - op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor"); - return {}; + return rewriter.notifyMatchFailure( + op, "getDimSizesOfTensor(): the input is not a ranked tensor"); } auto rank = valueTy.getRank(); - if (rank == 0) { - return {}; - } - + auto dims = toPositiveDims(inpDims, rank); SmallVector dimSizes; - dimSizes.reserve(rank); + dimSizes.reserve(dims.size()); + auto loc = op->getLoc(); - for (auto d = 0; d < rank; ++d) { + for (auto d : dims) { dimSizes.emplace_back(rewriter.create( loc, rewriter.getIntegerType(kMhloDimSizeBits), rewriter.create(loc, value, d))); @@ -60,6 +68,21 @@ SmallVector getDimSizesOfTensor(PatternRewriter &rewriter, return dimSizes; } +FailureOr> +getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { + auto valueTy = value.getType().dyn_cast(); + if (!valueTy) { + return rewriter.notifyMatchFailure( + op, "getDimSizesOfTensor(): the input is not a ranked tensor"); + } + + auto rank = valueTy.getRank(); + // Get int vector [0, 1, ..., rank-1] + std::vector dims(rank); + std::iota(dims.begin(), dims.end(), 0); + return getDimSizesOfTensor(rewriter, op, value, dims); +} + // A dimension index from torch.dialect might outside the range [0, dimSize]. // The function is used to normalize the input index into the range. Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, @@ -140,10 +163,11 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, // Get a dynamic slice of the tensor from startIndex to endIndex with stride // step on the specifed dimension. The input startIndex(default to 0), // endIndex(default to dimSize), and step(default to 1) can be optional. -Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input, - llvm::Optional startIndexOpt, - llvm::Optional endIndexOpt, - llvm::Optional stepOpt, int64_t dim) { +FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, + Value input, + llvm::Optional startIndexOpt, + llvm::Optional endIndexOpt, + llvm::Optional stepOpt, int64_t dim) { auto loc = op->getLoc(); auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); @@ -174,8 +198,13 @@ Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input, normEndIndex = rewriter.create(loc, i32Type, normEndIndex); step = rewriter.create(loc, i32Type, step); #endif - auto dimSizes = getDimSizesOfTensor(rewriter, op, input); + FailureOr> dimSizesInfo = + getDimSizesOfTensor(rewriter, op, input); + if (failed(dimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + auto dimSizes = *dimSizesInfo; return getDynamicSliceInternal(rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes); } @@ -197,11 +226,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("Only ranked tensor types supported in MHLO Rsub"); + return op.emitError("only ranked tensor types are supported"); int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( - op, "Only constant dim is currently supported"); + op, "only constant dim is currently supported"); auto getOptionalVal = [&](Value val) -> llvm::Optional { if (val.getType().isa()) { @@ -215,9 +244,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::Optional end = getOptionalVal(adaptor.end()); llvm::Optional step = getOptionalVal(adaptor.step()); - Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim); + FailureOr sliceInfo = + getDynamicSlice(rewriter, op, self, start, end, step, dim); + if (failed(sliceInfo)) + return op.emitError("can not create a dynmaic slice"); + + auto slice = *sliceInfo; rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), sliced); + op, getTypeConverter()->convertType(op.getType()), slice); return success(); } @@ -316,6 +350,160 @@ bool ConvertAtenViewOp::getAtenViewOpSizes( return getListConstructElements(adaptor.shape(), dimSizes); } +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, + ArrayRef inputUnsqzDims) { + // Returns a new tensor with dims of size 1 inserted at the specified + // position. + // + // The position indices (must be high to low dimension number of the returned + // tensor) are specified with unsqzDims. Indices must be in-order, and in + // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, + // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. + auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor); + if (failed(dimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + + auto dimSizes = *dimSizesInfo; + auto rank = dimSizes.size(); + size_t newRank = rank + inputUnsqzDims.size(); + auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank); + for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k) + if (k > 1 && unsqzDims[k] <= unsqzDims[k - 1]) + return rewriter.notifyMatchFailure( + op, "unsqueeze dimensions must be specified in order"); + + auto loc = op->getLoc(); + auto rankTy = tensor.getType().dyn_cast(); + auto oldShape = rankTy.getShape(); + Type intType = rewriter.getIntegerType(kMhloDimSizeBits); + auto one = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + + std::vector newDimSizes; + std::vector newShape; + newDimSizes.reserve(newRank); + newShape.reserve(newRank); + for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) { + if (j < unsqzDims.size() && unsqzDims[j] == k) { + newDimSizes.push_back(one); + newShape.push_back(1); + j++; + } else { + newDimSizes.push_back(dimSizes[i]); + newShape.push_back(oldShape[i]); + i++; + } + } + + auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); + auto mhloShape = rewriter.create(loc, newDimSizes); + return rewriter.create(loc, outTy, tensor, mhloShape) + .getResult(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSqueezeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + + auto rank = selfTy.getRank(); + if (rank == 0) + return rewriter.notifyMatchFailure( + op, "The rank of tensor must be greater than 0"); + + SmallVector dims; + dims.reserve(rank); + for (int r = 0; r < rank; ++r) { + auto dSize = selfTy.getShape()[r]; + if (dSize == ShapedType::kDynamicSize) + return rewriter.notifyMatchFailure( + op, "the size of the dimension being squeezed can't be unknown"); + if (dSize != 1) + dims.push_back(r); + } + + auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims); + if (failed(newDimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + auto newDimSizes = *newDimSizesInfo; + auto mhloShape = + rewriter.create(op.getLoc(), newDimSizes); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, mhloShape); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSqueezeDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + + auto rank = selfTy.getRank(); + if (rank == 0) + return rewriter.notifyMatchFailure( + op, "the rank of tensor must be greater than 0"); + + dim = toPositiveDim(dim, rank); + if (selfTy.getShape()[dim] != 1) { + if (selfTy.getShape()[dim] == ShapedType::kDynamicSize) + return rewriter.notifyMatchFailure( + op, "the size of the dimension being squeezed is can't be unknown"); + + rewriter.replaceOp(op, adaptor.self()); + return success(); + } + + SmallVector dims(rank); + std::iota(dims.begin(), dims.end(), 0); + dims.erase(dims.begin() + dim); + auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims); + if (failed(newDimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + auto newDimSizes = *newDimSizesInfo; + auto mhloShape = + rewriter.create(op.getLoc(), newDimSizes); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, mhloShape); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnsqueezeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) { + return op.emitError("only tensor types are currently supported"); + } + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return op->emitError("dim must be a Scalar constant"); + + auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, adaptor.self(), {dim}); + if (failed(unsqzTensorInfo)) + return rewriter.notifyMatchFailure(op, + "failed to create unsqueezed tensor"); + + rewriter.replaceOp(op, *unsqzTensorInfo); + return success(); +} } // namespace void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( @@ -327,6 +515,9 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenSqueezeOp); + INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index 2e6394a76192..db04f201ea19 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -414,3 +414,168 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> return %1 : !torch.vtensor<[],f32> } +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32> +// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32> +func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32> + return %0 : !torch.vtensor<[2,1,2,1,2],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> +// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,1,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?,1,?],f32> +func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,?,1,?],f32> + return %0 : !torch.vtensor<[?,?,1,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$from_end( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> +// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?],f32> +func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { + %int-2 = torch.constant.int -2 + %0 = torch.aten.squeeze.dim %arg0, %int-2 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?],f32> + return %0 : !torch.vtensor<[?,1,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xi64> +// CHECK: %[[T8:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T7]]) : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> +// CHECK: return %[[T9]] : !torch.vtensor<[2,2,2],f32> +func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { + %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[2,1,2,1,2],f32> -> !torch.vtensor<[2,2,2],f32> + return %0 : !torch.vtensor<[2,2,2],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$0( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[1,?,?,?,?],f32> +func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[1,?,?,?,?],f32> + return %0 : !torch.vtensor<[1,?,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$1( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?,?],f32> +func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?,?],f32> + return %0 : !torch.vtensor<[?,1,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.unsqueeze$from_end( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[C1_I64]], %[[T8]] : tensor<5xi64> +// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> +// CHECK: return %[[T11]] : !torch.vtensor<[?,?,?,1,?],f32> +func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { + %int-2 = torch.constant.int -2 + %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,1,?],f32> + return %0 : !torch.vtensor<[?,?,?,1,?],f32> +}