From 8c454e3e989c547582d981ec00b3e31ab237ed82 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Fri, 22 Jul 2022 17:30:09 +0800
Subject: [PATCH 1/3] [MHLO] Add [un]squeeze op patterns

---
 lib/Conversion/TorchToMhlo/ViewLikeOps.cpp | 179 ++++++++++++++++++++-
 test/Conversion/TorchToMhlo/view_like.mlir | 165 +++++++++++++++++++
 2 files changed, 341 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
index 0ecd96bf6293..83d09623cc2d 100644
--- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -36,8 +36,18 @@ static constexpr size_t kMhloDimSizeBits = 64;
 
 namespace {
 
+SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
+  SmallVector<size_t> posDims;
+  posDims.reserve(rank);
+  std::transform(
+      dims.begin(), dims.end(), std::back_inserter(posDims),
+      [rank](int64_t d) -> size_t { return toPositiveDim(d, rank); });
+  return posDims;
+}
+
 SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
-                                          Operation *op, Value value) {
+                                          Operation *op, Value value,
+                                          ArrayRef<int64_t> inpDims) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
     op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
@@ -49,10 +59,11 @@ SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
     return {};
   }
 
+  auto dims = toPositiveDims(inpDims, rank);
   SmallVector<Value, 4> dimSizes;
-  dimSizes.reserve(rank);
+  dimSizes.reserve(dims.size());
   auto loc = op->getLoc();
-  for (auto d = 0; d < rank; ++d) {
+  for (auto d : dims) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getIntegerType(kMhloDimSizeBits),
         rewriter.create<tensor::DimOp>(loc, value, d)));
@@ -60,6 +71,21 @@ SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
   return dimSizes;
 }
 
+SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                          Operation *op, Value value) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
+    return {};
+  }
+
+  auto rank = valueTy.getRank();
+  // Get int vector [0, 1, ..., rank-1]
+  std::vector<int64_t> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  return getDimSizesOfTensor(rewriter, op, value, dims);
+}
+
 // A dimension index from torch.dialect might be outside the range [0, dimSize].
 // The function is used to normalize the input index into the range.
 Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
@@ -316,6 +342,150 @@ bool ConvertAtenViewOp::getAtenViewOpSizes(
   return getListConstructElements(adaptor.shape(), dimSizes);
 }
 
+llvm::Optional<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                      Value tensor,
+                                      ArrayRef<int64_t> inputUnsqzDims) {
+  // Returns a new tensor with dims of size 1 inserted at the specified
+  // position.
+  //
+  // The position indices (must be high to low dimension number of the returned
+  // tensor) are specified with unsqzDims. Indices must be in-order, and in
+  // range of tensor rank. Thus, unsqueezing a rank 1 tensor with {0, 2}, {0, 1,
+  // 3}, or {0, 1, 2} are all valid dimension sets, but {0, 3} and {2} are not.
+  auto dimSizes = getDimSizesOfTensor(rewriter, op, tensor);
+  auto rank = dimSizes.size();
+  size_t newRank = rank + inputUnsqzDims.size();
+  auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
+  for (size_t k = 0; k < unsqzDims.size(); ++k) {
+    if (k > 0 && unsqzDims[k] <= unsqzDims[k - 1]) {
+      op->emitOpError("Unsqueeze dimensions must be specified in order.");
+      return llvm::None;
+    }
+  }
+
+  auto loc = op->getLoc();
+  auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  auto one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+  for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
+    if (j < unsqzDims.size() && unsqzDims[j] == k) {
+      newDimSizes.push_back(one);
+      newShape.push_back(1);
+      j++;
+    } else {
+      newDimSizes.push_back(dimSizes[i]);
+      newShape.push_back(oldShape[i]);
+      i++;
+    }
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
+      .getResult();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
+    AtenSqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in MHLO");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0) {
+    return rewriter.notifyMatchFailure(
+        op, "The rank of tensor must be greater than 0");
+  }
+
+  SmallVector<int64_t> dims;
+  dims.reserve(rank);
+  for (int r = 0; r < rank; ++r) {
+    auto dSize = selfTy.getShape()[r];
+    if (dSize == ShapedType::kDynamicSize) {
+      return rewriter.notifyMatchFailure(
+          op, "The size of the dimension being squeezed can't be unknown");
+    }
+    if (dSize != 1) {
+      dims.push_back(r);
+    }
+  }
+
+  auto newDimSizes = getDimSizesOfTensor(rewriter, op, self, dims);
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
+    AtenSqueezeDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in MHLO");
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dim is currently supported");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0) {
+    return rewriter.notifyMatchFailure(
+        op, "The rank of tensor must be greater than 0");
+  }
+
+  dim = toPositiveDim(dim, rank);
+  if (selfTy.getShape()[dim] != 1) {
+    if (selfTy.getShape()[dim] == ShapedType::kDynamicSize) {
+      return rewriter.notifyMatchFailure(
+          op, "The size of the dimension being squeezed can't be unknown");
+    } else {
+      rewriter.replaceOp(op, adaptor.self());
+      return success();
+    }
+  }
+
+  SmallVector<int64_t> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  dims.erase(dims.begin() + dim);
+  auto newDimSizes = getDimSizesOfTensor(rewriter, op, self, dims);
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
+    AtenUnsqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("Only tensor types are currently supported");
+  }
+
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return op->emitError("dim must be a Scalar constant");
+
+  auto unsqzTensor = unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
+  rewriter.replaceOp(op, *unsqzTensor);
+  return success();
+}
 } // namespace
 
 void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
@@ -327,6 +497,9 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
     target.addIllegalOp<AtenOp>();                                             \
     patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
+  INSERT_ATENOP_PATTERN(AtenSqueezeOp);
+  INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
+  INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir
index 2e6394a76192..db04f201ea19 100644
--- a/test/Conversion/TorchToMhlo/view_like.mlir
+++ b/test/Conversion/TorchToMhlo/view_like.mlir
@@ -414,3 +414,168 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
   %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
   return %1 : !torch.vtensor<[],f32>
 }
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
+// CHECK:         %[[INT0:.*]] = torch.constant.int 0
+// CHECK:         %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
+// CHECK:         return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
+func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
+  return %0 : !torch.vtensor<[2,1,2,1,2],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$1(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor<?x1x?x1x?xf32>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[C4:.*]] = arith.constant 4 : index
+// CHECK:         %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK:         %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64>
+// CHECK:         %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,?,1,?],f32>
+func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,?,1,?],f32>
+  return %0 : !torch.vtensor<[?,?,1,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$from_end(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor<?x1x?x1x?xf32>
+// CHECK:         %[[INT:.*]]-2 = torch.constant.int -2
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[C4:.*]] = arith.constant 4 : index
+// CHECK:         %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
+// CHECK:         %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK:         %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64>
+// CHECK:         %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,1,?,?],f32>
+func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
+  %int-2 = torch.constant.int -2
+  %0 = torch.aten.squeeze.dim %arg0, %int-2 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?],f32>
+  return %0 : !torch.vtensor<[?,1,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze$static(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C4:.*]] = arith.constant 4 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xi64>
+// CHECK:         %[[T8:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T7]]) : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
+// CHECK:         %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
+// CHECK:         return %[[T9]] : !torch.vtensor<[2,2,2],f32>
+func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
+  %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[2,1,2,1,2],f32> -> !torch.vtensor<[2,2,2],f32>
+  return %0 : !torch.vtensor<[2,2,2],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$0(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %[[INT0:.*]] = torch.constant.int 0
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T9:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64>
+// CHECK:         %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[1,?,?,?,?],f32>
+func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[1,?,?,?,?],f32>
+  return %0 : !torch.vtensor<[1,?,?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$1(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T9:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64>
+// CHECK:         %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,1,?,?,?],f32>
+func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,1,?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$from_end(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %[[INT:.*]]-2 = torch.constant.int -2
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[C1_I64]], %[[T8]] : tensor<5xi64>
+// CHECK:         %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,?,?,1,?],f32>
+func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
+  %int-2 = torch.constant.int -2
+  %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,1,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,1,?],f32>
+}

From c9daade9e9a895ca171a369c3faa120d55a051a0 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Mon, 25 Jul 2022 13:42:21 +0800
Subject: [PATCH 2/3] Conform to llvm coding standard

---
 lib/Conversion/TorchToMhlo/ViewLikeOps.cpp | 131 ++++++++++++---------
 1 file changed, 76 insertions(+), 55 deletions(-)

diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
index 83d09623cc2d..e3eb341707e9 100644
--- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -45,23 +45,23 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
   return posDims;
 }
 
-SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
-                                          Operation *op, Value value,
-                                          ArrayRef<int64_t> inpDims) {
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
+                    ArrayRef<int64_t> inpDims) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
-    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
-    return {};
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
   }
 
   auto rank = valueTy.getRank();
-  if (rank == 0) {
-    return {};
-  }
-
   auto dims = toPositiveDims(inpDims, rank);
   SmallVector<Value, 4> dimSizes;
   dimSizes.reserve(dims.size());
+
+  if (rank == 0) {
+    return dimSizes;
+  }
   auto loc = op->getLoc();
   for (auto d : dims) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
@@ -71,12 +71,12 @@ SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
   return dimSizes;
 }
 
-SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
-                                          Operation *op, Value value) {
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
-    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
-    return {};
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
   }
 
   auto rank = valueTy.getRank();
@@ -166,10 +166,11 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
 
 // Get a dynamic slice of the tensor from startIndex to endIndex with stride
 // step on the specified dimension. The input startIndex(default to 0),
 // endIndex(default to dimSize), and step(default to 1) can be optional.
-Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
-                      llvm::Optional<Value> startIndexOpt,
-                      llvm::Optional<Value> endIndexOpt,
-                      llvm::Optional<Value> stepOpt, int64_t dim) {
+FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
+                                 Value input,
+                                 llvm::Optional<Value> startIndexOpt,
+                                 llvm::Optional<Value> endIndexOpt,
+                                 llvm::Optional<Value> stepOpt, int64_t dim) {
   auto loc = op->getLoc();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
   auto rank = inputTy.getRank();
@@ -200,8 +201,13 @@ Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
   normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
   step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
 #endif
-  auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
+  FailureOr<SmallVector<Value, 4>> dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, input);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
 
+  auto dimSizes = *dimSizesInfo;
   return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
                                  normEndIndex, step, dim, dimSizes);
 }
@@ -223,11 +229,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in MHLO Rsub");
+    return op.emitError("only ranked tensor types are supported");
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
     return rewriter.notifyMatchFailure(
-        op, "Only constant dim is currently supported");
+        op, "only constant dim is currently supported");
 
   auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
     if (val.getType().isa<Torch::NoneType>()) {
@@ -241,7 +247,12 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   llvm::Optional<Value> end = getOptionalVal(adaptor.end());
   llvm::Optional<Value> step = getOptionalVal(adaptor.step());
 
-  Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  FailureOr<Value> slicedInfo =
+      getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  if (failed(slicedInfo))
+    return op.emitError("cannot create a dynamic slice");
+
+  auto sliced = *slicedInfo;
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
       op, getTypeConverter()->convertType(op.getType()), sliced);
 
   return success();
@@ -342,9 +353,9 @@ bool ConvertAtenViewOp::getAtenViewOpSizes(
   return getListConstructElements(adaptor.shape(), dimSizes);
 }
 
-llvm::Optional<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
-                                      Value tensor,
-                                      ArrayRef<int64_t> inputUnsqzDims) {
+FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                 Value tensor,
+                                 ArrayRef<int64_t> inputUnsqzDims) {
   // Returns a new tensor with dims of size 1 inserted at the specified
   // position.
   //
   // The position indices (must be high to low dimension number of the returned
   // tensor) are specified with unsqzDims. Indices must be in-order, and in
   // range of tensor rank. Thus, unsqueezing a rank 1 tensor with {0, 2}, {0, 1,
   // 3}, or {0, 1, 2} are all valid dimension sets, but {0, 3} and {2} are not.
-  auto dimSizes = getDimSizesOfTensor(rewriter, op, tensor);
+  auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
   auto rank = dimSizes.size();
   size_t newRank = rank + inputUnsqzDims.size();
   auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
-  for (size_t k = 0; k < unsqzDims.size(); ++k) {
-    if (k > 0 && unsqzDims[k] <= unsqzDims[k - 1]) {
-      op->emitOpError("Unsqueeze dimensions must be specified in order.");
-      return llvm::None;
-    }
-  }
+  for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k)
+    if (k > 0 && unsqzDims[k] <= unsqzDims[k - 1])
+      return rewriter.notifyMatchFailure(
+          op, "unsqueeze dimensions must be specified in order");
 
   auto loc = op->getLoc();
   auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
@@ -399,28 +413,29 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in MHLO");
+    return op.emitError("only ranked tensor types are supported");
 
   auto rank = selfTy.getRank();
-  if (rank == 0) {
+  if (rank == 0)
    return rewriter.notifyMatchFailure(
        op, "The rank of tensor must be greater than 0");
-  }
 
   SmallVector<int64_t> dims;
   dims.reserve(rank);
   for (int r = 0; r < rank; ++r) {
     auto dSize = selfTy.getShape()[r];
-    if (dSize == ShapedType::kDynamicSize) {
+    if (dSize == ShapedType::kDynamicSize)
       return rewriter.notifyMatchFailure(
-          op, "The size of the dimension being squeezed can't be unknown");
-    }
-    if (dSize != 1) {
+          op, "the size of the dimension being squeezed can't be unknown");
+    if (dSize != 1)
       dims.push_back(r);
-    }
   }
 
-  auto newDimSizes = getDimSizesOfTensor(rewriter, op, self, dims);
+  auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
   auto mhloShape =
       rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
   rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
@@ -435,33 +450,35 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in MHLO");
+    return op.emitError("only ranked tensor types are supported");
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
     return rewriter.notifyMatchFailure(
-        op, "Only constant dim is currently supported");
+        op, "only constant dim is currently supported");
 
   auto rank = selfTy.getRank();
-  if (rank == 0) {
+  if (rank == 0)
     return rewriter.notifyMatchFailure(
-        op, "The rank of tensor must be greater than 0");
-  }
+        op, "the rank of tensor must be greater than 0");
 
   dim = toPositiveDim(dim, rank);
   if (selfTy.getShape()[dim] != 1) {
-    if (selfTy.getShape()[dim] == ShapedType::kDynamicSize) {
+    if (selfTy.getShape()[dim] == ShapedType::kDynamicSize)
      return rewriter.notifyMatchFailure(
-          op, "The size of the dimension being squeezed can't be unknown");
-    } else {
-      rewriter.replaceOp(op, adaptor.self());
-      return success();
-    }
+          op, "the size of the dimension being squeezed can't be unknown");
+
+    rewriter.replaceOp(op, adaptor.self());
+    return success();
   }
 
   SmallVector<int64_t> dims(rank);
   std::iota(dims.begin(), dims.end(), 0);
   dims.erase(dims.begin() + dim);
-  auto newDimSizes = getDimSizesOfTensor(rewriter, op, self, dims);
+  auto newDimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, self, dims);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
   auto mhloShape =
       rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
   rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
@@ -475,15 +492,19 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
   if (!selfType) {
-    return op.emitError("Only tensor types are currently supported");
+    return op.emitError("only tensor types are currently supported");
   }
 
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
     return op->emitError("dim must be a Scalar constant");
 
-  auto unsqzTensor = unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
-  rewriter.replaceOp(op, *unsqzTensor);
+  auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
+  if (failed(unsqzTensorInfo))
+    return rewriter.notifyMatchFailure(op,
+                                       "failed to create unsqueezed tensor");
+
+  rewriter.replaceOp(op, *unsqzTensorInfo);
   return success();
 }
 } // namespace

From 99e158d7a5ccf125f1affaab8cae87568a55674a Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Mon, 25 Jul 2022 13:53:45 +0800
Subject: [PATCH 3/3] minor update

---
 lib/Conversion/TorchToMhlo/ViewLikeOps.cpp | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
index e3eb341707e9..18d21c193ba6 100644
--- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -59,9 +59,6 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
   SmallVector<Value, 4> dimSizes;
   dimSizes.reserve(dims.size());
 
-  if (rank == 0) {
-    return dimSizes;
-  }
   auto loc = op->getLoc();
   for (auto d : dims) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
@@ -247,14 +244,14 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   llvm::Optional<Value> end = getOptionalVal(adaptor.end());
   llvm::Optional<Value> step = getOptionalVal(adaptor.step());
 
-  FailureOr<Value> slicedInfo =
+  FailureOr<Value> sliceInfo =
       getDynamicSlice(rewriter, op, self, start, end, step, dim);
-  if (failed(slicedInfo))
+  if (failed(sliceInfo))
     return op.emitError("cannot create a dynamic slice");
 
-  auto sliced = *slicedInfo;
+  auto slice = *sliceInfo;
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()), sliced);
+      op, getTypeConverter()->convertType(op.getType()), slice);
 
   return success();
 }
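
Addendum (illustrative, not part of the patch series): the heart of unsqueezeTensor is a single
interleaving walk that builds the output shape, emitting a 1 at each requested position and the
next input dimension otherwise. The standalone C++ sketch below re-implements that walk with
plain STL types so it can be compiled and checked in isolation. The helper name unsqueezeShape
and the sample shapes are invented for this example; it assumes unsqzDims already holds positive,
strictly increasing indices, which is what toPositiveDims plus the in-order check guarantee in
the patch.

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Build the unsqueezed shape: walk the output positions once (k), emitting a
  // unit dim at each requested position (j) and the next input dim otherwise (i).
  // Mirrors the (k, i, j) loop in unsqueezeTensor; -1 stands in for a dynamic
  // dimension size, analogous to ShapedType::kDynamicSize.
  std::vector<std::int64_t>
  unsqueezeShape(const std::vector<std::int64_t> &oldShape,
                 const std::vector<std::size_t> &unsqzDims) {
    std::size_t newRank = oldShape.size() + unsqzDims.size();
    std::vector<std::int64_t> newShape;
    newShape.reserve(newRank);
    for (std::size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
      if (j < unsqzDims.size() && unsqzDims[j] == k) {
        newShape.push_back(1); // inserted unit dimension
        ++j;
      } else {
        newShape.push_back(oldShape[i]); // carried-over input dimension
        ++i;
      }
    }
    return newShape;
  }

  int main() {
    // Unsqueezing a 2x3 tensor at positions {0, 2} yields 1x2x1x3.
    assert((unsqueezeShape({2, 3}, {0, 2}) ==
            std::vector<std::int64_t>{1, 2, 1, 3}));
    // A dynamic dim (-1, i.e. '?') is carried through untouched:
    // ?x4 unsqueezed at {1} becomes ?x1x4.
    assert((unsqueezeShape({-1, 4}, {1}) ==
            std::vector<std::int64_t>{-1, 1, 4}));
    return 0;
  }

In the patch the same walk produces both the static result type (newShape) and the runtime shape
operand (newDimSizes, assembled from tensor.dim queries), which is what lets one
mhlo.dynamic_reshape implement both squeeze and unsqueeze even for fully dynamic inputs.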