[MHLO] Eliminate explicit dynamic output shape generation when converting AtenSliceTensorOp #1245

Merged 2 commits on Aug 19, 2022
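This PR changes the AtenSliceTensorOp lowering so the converted result type is threaded down into the slice helpers and mhlo.real_dynamic_slice is created with that type directly, instead of slicing into a fully dynamic shape and then casting with mhlo.convert. The updated FileCheck tests below show the effect; here is a minimal before/after sketch in MLIR (value names are illustrative, types taken from the static test case):

  // Before: slice with a dynamic dim, then a convert to the known result type.
  %slice = mhlo.real_dynamic_slice %input, %starts, %ends, %strides : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x65x256xf32>
  %out = mhlo.convert(%slice) : (tensor<?x65x256xf32>) -> tensor<2x65x256xf32>

  // After: the slice is emitted with the converted result type; no convert needed.
  %out = mhlo.real_dynamic_slice %input, %starts, %ends, %strides : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>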
73 changes: 31 additions & 42 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
@@ -53,8 +53,8 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
}

Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
-Value input, Value startIndex, Value endIndex,
-Value step, size_t dimIndex,
+Type outTy, Value input, Value startIndex,
+Value endIndex, Value step, size_t dimIndex,
ArrayRef<Value> dimSizes) {
auto loc = op->getLoc();
// startIndex & endIndex have been normalized into range [0, dSize]
@@ -98,20 +98,15 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();

-auto inputShape = inputTy.getShape();
-SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
-sliceShape[dimIndex] = ShapedType::kDynamicSize;
-auto sliceoutputTy =
-    RankedTensorType::get(sliceShape, inputTy.getElementType());
return rewriter.create<mhlo::RealDynamicSliceOp>(
-    loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
+    loc, outTy, input, startTensor, endTensor, stridesTensor);
}

// Get a dynamic slice of the tensor from startIndex to endIndex with stride
// step on the specified dimension. The inputs startIndex (default to 0),
// endIndex (default to dimSize), and step (default to 1) are optional.
FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
-Value input,
+Type outTy, Value input,
llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt, int64_t dim) {
@@ -152,7 +147,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
op, "failed to get dimension sizes of the input");

auto dimSizes = *dimSizesInfo;
-return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
+return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
normEndIndex, step, dim, dimSizes);
}
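For reference, getDynamicSlice above normalizes the optional start/end/step operands and calls getDynamicSliceInternal, which packs the per-dimension bounds into 1-D index tensors for mhlo.real_dynamic_slice. A hedged sketch of the emitted IR for a rank-3 input sliced on dim 0 (value names illustrative, shapes drawn from the CHECK patterns in the test file below):

  %starts = tensor.from_elements %start0, %c0, %c0 : tensor<3xi64>
  %ends = tensor.from_elements %end0, %dim1, %dim2 : tensor<3xi64>
  %strides = tensor.from_elements %step0, %c1, %c1 : tensor<3xi64>
  %slice = mhlo.real_dynamic_slice %input, %starts, %ends, %strides : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>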

@@ -174,6 +169,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
+auto outTy =
+    getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
@@ -192,29 +189,26 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
llvm::Optional<Value> step = getOptionalVal(adaptor.step());

FailureOr<Value> sliceInfo =
-    getDynamicSlice(rewriter, op, self, start, end, step, dim);
+    getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
if (failed(sliceInfo))
return op.emitError("cannot create a dynamic slice");

auto slice = *sliceInfo;
-rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
-    op, getTypeConverter()->convertType(op.getType()), slice);

+rewriter.replaceOp(op, slice);
return success();
}

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

-LogicalResult matchAndRewrite(
-    AtenOpT op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const override {
+LogicalResult
+matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                ConversionPatternRewriter &rewriter) const override {
auto rankType =
adaptor.self().getType().template dyn_cast<RankedTensorType>();
if (!rankType)
@@ -236,18 +230,19 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
return success();
}

-std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
return dSize;
});

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-// The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
-// One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
-// the range of i32(4GiB)
-std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+// The i64 calculation is much slower than i32 on some devices, such as
+// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
+// unlikely to exceed the range of i32(4GiB)
+std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
// dimSize: cast i64 -> i32
-dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
+dSize =
+    rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
return dSize;
});
#endif
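In other words, when TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 is set, each extracted i64 dimension size is narrowed to i32 before the target-shape tensor is assembled. At the IR level this is simply (a sketch, assuming every dimension size fits in i32; value names illustrative):

  %dimI32 = arith.trunci %dimI64 : i64 to i32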
@@ -272,28 +267,22 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
return success();
}

-bool getAtenViewOpSizes(
-    AtenOpT op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter,
-    SmallVector<Value, 4>& dimSizes) const;
+bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter,
+                        SmallVector<Value, 4> &dimSizes) const;
};

template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
-    AtenViewOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter,
-    SmallVector<Value, 4>& dimSizes) const {
+    AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
+    SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.size(), dimSizes);
}

template <>
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
-    AtenReshapeOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter,
-    SmallVector<Value, 4>& dimSizes) const {
+    AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
+    SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.shape(), dimSizes);
}

@@ -415,10 +404,10 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
#undef INSERT_ATENOP_PATTERN

-#define INSERT_VIEW_OP_PATTERN(AtenOp) \
-  target.addIllegalOp<AtenOp>(); \
+#define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
+  target.addIllegalOp<AtenOp>();                                               \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
}
38 changes: 16 additions & 22 deletions test/Conversion/TorchToMhlo/view_like.mlir
@@ -43,9 +43,8 @@
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
-// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor<?x?x?xf32>
-// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
-// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: return %[[T30]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@@ -96,10 +95,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x65x256xf32>
-// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x65x256xf32>) -> tensor<2x65x256xf32>
-// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
-// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32>
+// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
+// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
+// CHECK: return %[[T30]] : !torch.vtensor<[2,65,256],f32>
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@@ -151,10 +149,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
-// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x?x?xf32>) -> tensor<?x1x?xf32>
-// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
-// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32>
+// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
+// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
+// CHECK: return %[[T30]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@@ -206,10 +203,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
-// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32>
-// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
-// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32>
+// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
+// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
+// CHECK: return %[[T30]] : !torch.vtensor<[4,1,256],f32>
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@@ -247,9 +243,8 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
-// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor<?x?x?xf32>
-// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
-// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: return %[[T16]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
@@ -285,10 +280,9 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64>
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
-// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32>
-// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
-// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32>
+// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
+// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
+// CHECK: return %[[T16]] : !torch.vtensor<[4,33,256],f32>
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2