diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index e1f544293045..130ec544aaaf 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo TorchToMhlo.cpp BasicOp.cpp + GatherOp.cpp ViewLikeOps.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Conversion/TorchToMhlo/GatherOp.cpp b/lib/Conversion/TorchToMhlo/GatherOp.cpp new file mode 100644 index 000000000000..05b38b9cca3d --- /dev/null +++ b/lib/Conversion/TorchToMhlo/GatherOp.cpp @@ -0,0 +1,194 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 +static constexpr size_t kMhloDimSizeBits = 32; +#else +static constexpr size_t kMhloDimSizeBits = 64; +#endif + +namespace { +Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, + Value input, Value indices, int64_t axis) { + auto loc = op->getLoc(); + Type intType = rewriter.getIntegerType(kMhloDimSizeBits); + Value one = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + + // sliceSizes + auto inputRankTy = input.getType().dyn_cast(); + auto inputRank = inputRankTy.getRank(); + SmallVector sliceSizes; + sliceSizes.reserve(inputRank); + for (int64_t r = 0; r < inputRank; ++r) { + if (r == axis) { + sliceSizes.push_back(one); + } else { + sliceSizes.push_back(rewriter.create( + loc, intType, rewriter.create(loc, input, r))); + } + } + auto sliceSizesTensor = + rewriter.create(loc, sliceSizes); + + // offsetDims + SmallVector offsetDims; + offsetDims.reserve(inputRank); + for (int64_t r = 0; r < axis; ++r) { + offsetDims.push_back(r); + } + auto indicesRankTy = indices.getType().dyn_cast(); + auto indicesRank = indicesRankTy.getRank(); + for (int64_t r = axis + 1; r < inputRank; ++r) { + offsetDims.push_back(r + indicesRank - 1); + } + + // collapsedSliceDims + SmallVector collapsedSliceDims(1, axis); + // startIndexMap + SmallVector startIndexMap(1, axis); + // indexVecDim + int64_t indexVecDim = indicesRank; + auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/offsetDims, + /*collapsedSliceDims=*/collapsedSliceDims, + /*startIndexMap=*/startIndexMap, + /*indexVecDim=*/indexVecDim); + + // outputShape = input.shape[:axis] + indices.shape + + // input.shape[axis + 1:] + auto inputShape = inputRankTy.getShape(); + auto indicesShape = indicesRankTy.getShape(); + SmallVector outputShape(inputShape.begin(), + inputShape.begin() + axis); + outputShape.insert(outputShape.end(), indicesShape.begin(), + indicesShape.end()); + outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1, + inputShape.end()); + + // create output tensor type + auto outputTy = + RankedTensorType::get(outputShape, inputRankTy.getElementType()); + return rewriter + .create(loc, outputTy, input, indices, + sliceSizesTensor, dimsAttr) + .getResult(); +} + +template +class ConvertAtenOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html +// padding_idx (int, optional) +// – If specified, the entries at padding_idx do not contribute to the gradient; +// therefore, the embedding vector at padding_idx is not updated during training, +// i.e. it remains as a fixed “pad”. +// scale_grad_by_freq (boolean, optional) +// – If given, this will scale gradients by the inverse of frequency of the +// words in the mini-batch. Default False. +// sparse (bool, optional) +// – If True, gradient w.r.t. weight matrix will be a sparse tensor. +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenEmbeddingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto weight = adaptor.weight(); + auto weightTy = weight.getType().template cast(); + if (!weightTy) + return op.emitError("only ranked tensor types are supported"); + + int64_t padding_idx; + if (!matchPattern(op.padding_idx(), m_TorchConstantInt(&padding_idx))) + return rewriter.notifyMatchFailure( + op, "only constant padding_idx is currently supported"); + + bool scale_grad_by_freq; + if (!matchPattern(op.scale_grad_by_freq(), + m_TorchConstantBool(&scale_grad_by_freq))) + return rewriter.notifyMatchFailure( + op, "only constant scale_grad_by_freq is currently supported"); + if (scale_grad_by_freq) + return rewriter.notifyMatchFailure( + op, "scale gradients is currently not supported"); + bool sparse; + if (!matchPattern(op.sparse(), m_TorchConstantBool(&sparse))) + return rewriter.notifyMatchFailure( + op, "only constant sparse is currently supported"); + if (sparse) + return rewriter.notifyMatchFailure( + op, "sparse gradients is currently not supported"); + + Value output = + gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output); + + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexSelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + + Value output = + gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output); + + return success(); +} +} // namespace + +void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); +#undef INSERT_ATENOP_PATTERN +} diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h index 97bb8602882d..35bca4019700 100644 --- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -22,7 +22,9 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); - +void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target); } // namespace torch_to_mhlo } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index 2a052e17a0ec..890b426909c2 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -53,7 +53,9 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { target); torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns, target); - + torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, + target); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); diff --git a/test/Conversion/TorchToMhlo/gather.mlir b/test/Conversion/TorchToMhlo/gather.mlir new file mode 100644 index 000000000000..a20b32d4994d --- /dev/null +++ b/test/Conversion/TorchToMhlo/gather.mlir @@ -0,0 +1,66 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.index_select$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32> +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> +func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.index_select %arg0, %int0, %arg1 : !torch.vtensor<[?,4],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// CHECK-LABEL: func.func @torch.aten.embedding$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],si64> -> tensor +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> { + %false = torch.constant.bool false + %int-1 = torch.constant.int -1 + %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,?],f32> + return %ret: !torch.vtensor<[?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.embedding$rank_two_indices( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32> +// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> +func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> { + %false = torch.constant.bool false + %int-1 = torch.constant.int -1 + %ret = torch.aten.embedding %weight, %indices, %int-1, %false, %false : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,1], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,1,?],f32> + return %ret: !torch.vtensor<[?,1,?],f32> +} +