diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp index 18d21c193ba6..0ecd96bf6293 100644 --- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp @@ -36,31 +36,23 @@ static constexpr size_t kMhloDimSizeBits = 64; namespace { -SmallVector toPositiveDims(ArrayRef dims, int64_t rank) { - SmallVector posDims; - posDims.reserve(rank); - std::transform( - dims.begin(), dims.end(), std::back_inserter(posDims), - [rank](int64_t d) -> size_t { return toPositiveDim(d, rank); }); - return posDims; -} - -FailureOr> -getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, - ArrayRef inpDims) { +SmallVector getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value) { auto valueTy = value.getType().dyn_cast(); if (!valueTy) { - return rewriter.notifyMatchFailure( - op, "getDimSizesOfTensor(): the input is not a ranked tensor"); + op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor"); + return {}; } auto rank = valueTy.getRank(); - auto dims = toPositiveDims(inpDims, rank); - SmallVector dimSizes; - dimSizes.reserve(dims.size()); + if (rank == 0) { + return {}; + } + SmallVector dimSizes; + dimSizes.reserve(rank); auto loc = op->getLoc(); - for (auto d : dims) { + for (auto d = 0; d < rank; ++d) { dimSizes.emplace_back(rewriter.create( loc, rewriter.getIntegerType(kMhloDimSizeBits), rewriter.create(loc, value, d))); @@ -68,21 +60,6 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, return dimSizes; } -FailureOr> -getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { - auto valueTy = value.getType().dyn_cast(); - if (!valueTy) { - return rewriter.notifyMatchFailure( - op, "getDimSizesOfTensor(): the input is not a ranked tensor"); - } - - auto rank = valueTy.getRank(); - // Get int vector [0, 1, ..., rank-1] - std::vector dims(rank); - std::iota(dims.begin(), dims.end(), 0); - return getDimSizesOfTensor(rewriter, op, value, dims); -} - // A dimension index from torch.dialect might outside the range [0, dimSize]. // The function is used to normalize the input index into the range. Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, @@ -163,11 +140,10 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, // Get a dynamic slice of the tensor from startIndex to endIndex with stride // step on the specifed dimension. The input startIndex(default to 0), // endIndex(default to dimSize), and step(default to 1) can be optional. -FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, - Value input, - llvm::Optional startIndexOpt, - llvm::Optional endIndexOpt, - llvm::Optional stepOpt, int64_t dim) { +Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input, + llvm::Optional startIndexOpt, + llvm::Optional endIndexOpt, + llvm::Optional stepOpt, int64_t dim) { auto loc = op->getLoc(); auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); @@ -198,13 +174,8 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, normEndIndex = rewriter.create(loc, i32Type, normEndIndex); step = rewriter.create(loc, i32Type, step); #endif - FailureOr> dimSizesInfo = - getDimSizesOfTensor(rewriter, op, input); - if (failed(dimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); + auto dimSizes = getDimSizesOfTensor(rewriter, op, input); - auto dimSizes = *dimSizesInfo; return getDynamicSliceInternal(rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes); } @@ -226,11 +197,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) - return op.emitError("only ranked tensor types are supported"); + return op.emitError("Only ranked tensor types supported in MHLO Rsub"); int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( - op, "only constant dim is currently supported"); + op, "Only constant dim is currently supported"); auto getOptionalVal = [&](Value val) -> llvm::Optional { if (val.getType().isa()) { @@ -244,14 +215,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::Optional end = getOptionalVal(adaptor.end()); llvm::Optional step = getOptionalVal(adaptor.step()); - FailureOr sliceInfo = - getDynamicSlice(rewriter, op, self, start, end, step, dim); - if (failed(sliceInfo)) - return op.emitError("can not create a dynmaic slice"); - - auto slice = *sliceInfo; + Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim); rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), slice); + op, getTypeConverter()->convertType(op.getType()), sliced); return success(); } @@ -350,160 +316,6 @@ bool ConvertAtenViewOp::getAtenViewOpSizes( return getListConstructElements(adaptor.shape(), dimSizes); } -FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, - ArrayRef inputUnsqzDims) { - // Returns a new tensor with dims of size 1 inserted at the specified - // position. - // - // The position indices (must be high to low dimension number of the returned - // tensor) are specified with unsqzDims. Indices must be in-order, and in - // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, - // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. - auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor); - if (failed(dimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - - auto dimSizes = *dimSizesInfo; - auto rank = dimSizes.size(); - size_t newRank = rank + inputUnsqzDims.size(); - auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank); - for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k) - if (k > 1 && unsqzDims[k] <= unsqzDims[k - 1]) - return rewriter.notifyMatchFailure( - op, "unsqueeze dimensions must be specified in order"); - - auto loc = op->getLoc(); - auto rankTy = tensor.getType().dyn_cast(); - auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(kMhloDimSizeBits); - auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); - - std::vector newDimSizes; - std::vector newShape; - newDimSizes.reserve(newRank); - newShape.reserve(newRank); - for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) { - if (j < unsqzDims.size() && unsqzDims[j] == k) { - newDimSizes.push_back(one); - newShape.push_back(1); - j++; - } else { - newDimSizes.push_back(dimSizes[i]); - newShape.push_back(oldShape[i]); - i++; - } - } - - auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto mhloShape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, mhloShape) - .getResult(); -} - -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenSqueezeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.self(); - auto selfTy = self.getType().template cast(); - if (!selfTy) - return op.emitError("only ranked tensor types are supported"); - - auto rank = selfTy.getRank(); - if (rank == 0) - return rewriter.notifyMatchFailure( - op, "The rank of tensor must be greater than 0"); - - SmallVector dims; - dims.reserve(rank); - for (int r = 0; r < rank; ++r) { - auto dSize = selfTy.getShape()[r]; - if (dSize == ShapedType::kDynamicSize) - return rewriter.notifyMatchFailure( - op, "the size of the dimension being squeezed can't be unknown"); - if (dSize != 1) - dims.push_back(r); - } - - auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims); - if (failed(newDimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - auto newDimSizes = *newDimSizesInfo; - auto mhloShape = - rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, mhloShape); - return success(); -} - -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenSqueezeDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.self(); - auto selfTy = self.getType().template cast(); - if (!selfTy) - return op.emitError("only ranked tensor types are supported"); - int64_t dim; - if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure( - op, "only constant dim is currently supported"); - - auto rank = selfTy.getRank(); - if (rank == 0) - return rewriter.notifyMatchFailure( - op, "the rank of tensor must be greater than 0"); - - dim = toPositiveDim(dim, rank); - if (selfTy.getShape()[dim] != 1) { - if (selfTy.getShape()[dim] == ShapedType::kDynamicSize) - return rewriter.notifyMatchFailure( - op, "the size of the dimension being squeezed is can't be unknown"); - - rewriter.replaceOp(op, adaptor.self()); - return success(); - } - - SmallVector dims(rank); - std::iota(dims.begin(), dims.end(), 0); - dims.erase(dims.begin() + dim); - auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims); - if (failed(newDimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - auto newDimSizes = *newDimSizesInfo; - auto mhloShape = - rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, mhloShape); - return success(); -} - -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenUnsqueezeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.self().getType().dyn_cast(); - if (!selfType) { - return op.emitError("only tensor types are currently supported"); - } - - int64_t dim; - if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return op->emitError("dim must be a Scalar constant"); - - auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, adaptor.self(), {dim}); - if (failed(unsqzTensorInfo)) - return rewriter.notifyMatchFailure(op, - "failed to create unsqueezed tensor"); - - rewriter.replaceOp(op, *unsqzTensorInfo); - return success(); -} } // namespace void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( @@ -515,9 +327,6 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenSqueezeOp); - INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index db04f201ea19..2e6394a76192 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -414,168 +414,3 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> return %1 : !torch.vtensor<[],f32> } -// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32> -// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32> -func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { - %int0 = torch.constant.int 0 - %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32> - return %0 : !torch.vtensor<[2,1,2,1,2],f32> -} - -// CHECK-LABEL: func.func @torch.aten.squeeze.dim$1( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,1,?],f32> -// CHECK: return %[[T11]] : !torch.vtensor<[?,?,1,?],f32> -func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { - %int1 = torch.constant.int 1 - %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,?,1,?],f32> - return %0 : !torch.vtensor<[?,?,1,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.squeeze.dim$from_end( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor -// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?],f32> -// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?],f32> -func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { - %int-2 = torch.constant.int -2 - %0 = torch.aten.squeeze.dim %arg0, %int-2 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?],f32> - return %0 : !torch.vtensor<[?,1,?,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.squeeze$static( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T7]]) : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> -// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> -// CHECK: return %[[T9]] : !torch.vtensor<[2,2,2],f32> -func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { - %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[2,1,2,1,2],f32> -> !torch.vtensor<[2,2,2],f32> - return %0 : !torch.vtensor<[2,2,2],f32> -} - -// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$0( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T9:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> -// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> -// CHECK: return %[[T11]] : !torch.vtensor<[1,?,?,?,?],f32> -func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { - %int0 = torch.constant.int 0 - %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[1,?,?,?,?],f32> - return %0 : !torch.vtensor<[1,?,?,?,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$1( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor -// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> -// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?,?],f32> -func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { - %int1 = torch.constant.int 1 - %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?,?],f32> - return %0 : !torch.vtensor<[?,1,?,?,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.unsqueeze$from_end( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 -// CHECK: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[C1_I64]], %[[T8]] : tensor<5xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor, tensor<5xi64>) -> tensor -// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> -// CHECK: return %[[T11]] : !torch.vtensor<[?,?,?,1,?],f32> -func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { - %int-2 = torch.constant.int -2 - %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,1,?],f32> - return %0 : !torch.vtensor<[?,?,?,1,?],f32> -}