Skip to content

Commit

Permalink
[MHLO] Eliminate explicit dynamic output shape generating in converti…
Browse files Browse the repository at this point in the history
…ng AtenSliceTensorOp (#1245)

[MHLO] Eliminate explicit dynamic output shape generation in converting AtenSliceTensorOp
  • Loading branch information
Vremold authored Aug 19, 2022
1 parent 0d1aa43 commit 7bd173a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 64 deletions.
73 changes: 31 additions & 42 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
}

Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
Value input, Value startIndex, Value endIndex,
Value step, size_t dimIndex,
Type outTy, Value input, Value startIndex,
Value endIndex, Value step, size_t dimIndex,
ArrayRef<Value> dimSizes) {
auto loc = op->getLoc();
  // startIndex & endIndex have been normalized into the range [0, dSize]
Expand Down Expand Up @@ -98,20 +98,15 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();

auto inputShape = inputTy.getShape();
SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
sliceShape[dimIndex] = ShapedType::kDynamicSize;
auto sliceoutputTy =
RankedTensorType::get(sliceShape, inputTy.getElementType());
return rewriter.create<mhlo::RealDynamicSliceOp>(
loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
loc, outTy, input, startTensor, endTensor, stridesTensor);
}

// Get a dynamic slice of the tensor from startIndex to endIndex with stride
// step on the specified dimension. The inputs startIndex (defaults to 0),
// endIndex (defaults to dimSize), and step (defaults to 1) are optional.
FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
Value input,
Type outTy, Value input,
llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt, int64_t dim) {
Expand Down Expand Up @@ -152,7 +147,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
op, "failed to get dimension sizes of the input");

auto dimSizes = *dimSizesInfo;
return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
normEndIndex, step, dim, dimSizes);
}

Expand All @@ -174,6 +169,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
Expand All @@ -192,29 +189,26 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
llvm::Optional<Value> step = getOptionalVal(adaptor.step());

FailureOr<Value> sliceInfo =
getDynamicSlice(rewriter, op, self, start, end, step, dim);
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
if (failed(sliceInfo))
    return op.emitError("cannot create a dynamic slice");

auto slice = *sliceInfo;
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), slice);

rewriter.replaceOp(op, slice);
return success();
}

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
public:
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult matchAndRewrite(
AtenOpT op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rankType =
adaptor.self().getType().template dyn_cast<RankedTensorType>();
if (!rankType)
Expand All @@ -236,18 +230,19 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
return success();
}

std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
return dSize;
});

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
// The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
// One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
// the range of i32(4GiB)
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
// unlikely to exceed the range of i32(4GiB)
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
// dimSize: cast i64 -> i32
dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
dSize =
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
return dSize;
});
#endif
Expand All @@ -272,28 +267,22 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
return success();
}

bool getAtenViewOpSizes(
AtenOpT op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const;
bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const;
};

template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
AtenViewOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const {
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.size(), dimSizes);
}

template <>
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
AtenReshapeOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter,
SmallVector<Value, 4>& dimSizes) const {
AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.shape(), dimSizes);
}

Expand Down Expand Up @@ -415,10 +404,10 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
}
38 changes: 16 additions & 22 deletions test/Conversion/TorchToMhlo/view_like.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor<?x?x?xf32>
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T30]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
Expand Down Expand Up @@ -96,10 +95,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x65x256xf32>
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x65x256xf32>) -> tensor<2x65x256xf32>
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
// CHECK: return %[[T30]] : !torch.vtensor<[2,65,256],f32>
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
Expand Down Expand Up @@ -151,10 +149,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x?x?xf32>) -> tensor<?x1x?xf32>
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T30]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
Expand Down Expand Up @@ -206,10 +203,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32>
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
// CHECK: return %[[T30]] : !torch.vtensor<[4,1,256],f32>
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
Expand Down Expand Up @@ -247,9 +243,8 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor<?x?x?xf32>
// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T16]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
Expand Down Expand Up @@ -285,10 +280,9 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64>
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32>
// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32>
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
// CHECK: return %[[T16]] : !torch.vtensor<[4,33,256],f32>
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
Expand Down

0 comments on commit 7bd173a

Please sign in to comment.