Skip to content

Commit

Permalink
[MHLO] Init MHLO gather op patterns
Browse files Browse the repository at this point in the history
See RFC #999

Co-authored-by: Bairen Yi yibairen.byron@bytedance.com
Co-authored-by: Jiawei Wu xremold@gmail.com
Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com
Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com
Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com
  • Loading branch information
Tanyo Kwok committed Jul 25, 2022
1 parent b80ce79 commit 91e3206
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 2 deletions.
1 change: 1 addition & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
BasicOp.cpp
GatherOp.cpp
ViewLikeOps.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
194 changes: 194 additions & 0 deletions lib/Conversion/TorchToMhlo/GatherOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif

namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
Value input, Value indices, int64_t axis) {
auto loc = op->getLoc();
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));

// sliceSizes
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
auto inputRank = inputRankTy.getRank();
SmallVector<Value, 4> sliceSizes;
sliceSizes.reserve(inputRank);
for (int64_t r = 0; r < inputRank; ++r) {
if (r == axis) {
sliceSizes.push_back(one);
} else {
sliceSizes.push_back(rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, input, r)));
}
}
auto sliceSizesTensor =
rewriter.create<tensor::FromElementsOp>(loc, sliceSizes);

// offsetDims
SmallVector<int64_t, 4> offsetDims;
offsetDims.reserve(inputRank);
for (int64_t r = 0; r < axis; ++r) {
offsetDims.push_back(r);
}
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
auto indicesRank = indicesRankTy.getRank();
for (int64_t r = axis + 1; r < inputRank; ++r) {
offsetDims.push_back(r + indicesRank - 1);
}

// collapsedSliceDims
SmallVector<int64_t, 4> collapsedSliceDims(1, axis);
// startIndexMap
SmallVector<int64_t, 4> startIndexMap(1, axis);
// indexVecDim
int64_t indexVecDim = indicesRank;
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedSliceDims,
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim);

// outputShape = input.shape[:axis] + indices.shape +
// input.shape[axis + 1:]
auto inputShape = inputRankTy.getShape();
auto indicesShape = indicesRankTy.getShape();
SmallVector<int64_t, 4> outputShape(inputShape.begin(),
inputShape.begin() + axis);
outputShape.insert(outputShape.end(), indicesShape.begin(),
indicesShape.end());
outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
inputShape.end());

// create output tensor type
auto outputTy =
RankedTensorType::get(outputShape, inputRankTy.getElementType());
return rewriter
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr)
.getResult();
}

template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
// – If specified, the entries at padding_idx do not contribute to the gradient;
// therefore, the embedding vector at padding_idx is not updated during training,
// i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional)
// – If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False.
// sparse (bool, optional)
// – If True, gradient w.r.t. weight matrix will be a sparse tensor.
template <>
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
AtenEmbeddingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto weight = adaptor.weight();
auto weightTy = weight.getType().template cast<RankedTensorType>();
if (!weightTy)
return op.emitError("only ranked tensor types are supported");

int64_t padding_idx;
if (!matchPattern(op.padding_idx(), m_TorchConstantInt(&padding_idx)))
return rewriter.notifyMatchFailure(
op, "only constant padding_idx is currently supported");

bool scale_grad_by_freq;
if (!matchPattern(op.scale_grad_by_freq(),
m_TorchConstantBool(&scale_grad_by_freq)))
return rewriter.notifyMatchFailure(
op, "only constant scale_grad_by_freq is currently supported");
if (scale_grad_by_freq)
return rewriter.notifyMatchFailure(
op, "scale gradients is currently not supported");
bool sparse;
if (!matchPattern(op.sparse(), m_TorchConstantBool(&sparse)))
return rewriter.notifyMatchFailure(
op, "only constant sparse is currently supported");
if (sparse)
return rewriter.notifyMatchFailure(
op, "sparse gradients is currently not supported");

Value output =
gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");

Value output =
gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim);

rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);

return success();
}
} // namespace

void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
#undef INSERT_ATENOP_PATTERN
}
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToMhlo/PopulatePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir
Expand Down
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
target);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
target);

torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
Expand Down
66 changes: 66 additions & 0 deletions test/Conversion/TorchToMhlo/gather.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.index_select$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4],f32> -> tensor<?x4xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.index_select %arg0, %int0, %arg1 : !torch.vtensor<[?,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,4],f32>
return %0 : !torch.vtensor<[2,4],f32>
}

// CHECK-LABEL: func.func @torch.aten.embedding$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],si64> -> tensor<?xi64>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,?],f32>
return %ret: !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.embedding$rank_two_indices(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?,1,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,1], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,1,?],f32>
return %ret: !torch.vtensor<[?,1,?],f32>
}

0 comments on commit 91e3206

Please sign in to comment.