From d86732fa1027fdda4e262ade3edd9f2622a488b9 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Tue, 2 Aug 2022 11:04:35 +0800 Subject: [PATCH 1/4] [MHLO] Init MHLO matmul op patterns See RFC https://github.com/llvm/torch-mlir/issues/999 Co-authored-by: Bairen Yi yibairen.byron@bytedance.com Co-authored-by: Jiawei Wu xremold@gmail.com Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com --- lib/Conversion/TorchToMhlo/CMakeLists.txt | 1 + lib/Conversion/TorchToMhlo/MatmulOp.cpp | 652 ++++++++++++++++++ lib/Conversion/TorchToMhlo/PopulatePatterns.h | 4 + lib/Conversion/TorchToMhlo/TorchToMhlo.cpp | 5 +- test/Conversion/TorchToMhlo/matmul.mlir | 358 ++++++++++ 5 files changed, 1019 insertions(+), 1 deletion(-) create mode 100644 lib/Conversion/TorchToMhlo/MatmulOp.cpp create mode 100644 test/Conversion/TorchToMhlo/matmul.mlir diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 3c036e7ef8d8..859551a6a9cb 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MhloLegalizeUtils.cpp BasicOp.cpp GatherOp.cpp + MatmulOp.cpp ViewLikeOps.cpp ReductionOp.cpp diff --git a/lib/Conversion/TorchToMhlo/MatmulOp.cpp b/lib/Conversion/TorchToMhlo/MatmulOp.cpp new file mode 100644 index 000000000000..c73dd6d998a2 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/MatmulOp.cpp @@ -0,0 +1,652 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace mlir { +namespace mhlo { +FailureOr getZeroRankTensor(PatternRewriter &rewriter, Operation *op, + Value tensor) { + auto rankTy = tensor.getType().dyn_cast(); + if (!rankTy) + return rewriter.notifyMatchFailure( + op, "can not reshape a tensor that is not ranked to 0-rank"); + + auto shape = rankTy.getShape(); + if (!(shape.size() == 1 && shape[0] == 1)) + return rewriter.notifyMatchFailure(op, "the shape must equal to [1]"); + + return rewriter + .create( + op->getLoc(), + RankedTensorType::get(ArrayRef{}, rankTy.getElementType()), + tensor) + .getResult(); +} + +Value getReshapedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, + ArrayRef shape, ArrayRef dimSizes) { + // create mhlo::DynamicReshapeOp + auto loc = op->getLoc(); + auto tensorTy = tensor.getType().dyn_cast(); + auto outRankTy = RankedTensorType::get(shape, tensorTy.getElementType()); + Value mhloShape = rewriter.create(loc, dimSizes); + return rewriter.create(loc, outRankTy, tensor, + mhloShape); +} + +Value getExpandedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, + ArrayRef expandDimSizes, int64_t expandPos) { + if (expandDimSizes.size() == 0) { + return tensor; + } + 
+ auto tensorTy = tensor.getType().dyn_cast(); + auto dimSizes = *getDimSizesOfTensor(rewriter, op, tensor); + int64_t rank = dimSizes.size(); + expandPos = (expandPos + rank) % rank; + + std::vector newDimSizes; + std::vector newShape; + for (int64_t k = 0; k < rank; ++k) { + if (k == expandPos) { + newDimSizes.insert(newDimSizes.end(), expandDimSizes.begin(), + expandDimSizes.end()); + for (size_t j = 0; j < expandDimSizes.size(); ++j) { + newShape.push_back(ShapedType::kDynamicSize); + } + } else { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(tensorTy.getShape()[k]); + } + } + + return getReshapedTensor(rewriter, op, tensor, newShape, newDimSizes); +} + +Value getProductOfDimSizes(PatternRewriter &rewriter, Operation *op, + ArrayRef dimSizes) { + auto loc = op->getLoc(); + Type intTy = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); + auto prod = + rewriter.create(loc, rewriter.getIntegerAttr(intTy, 1)) + .getResult(); + + for (auto &d : dimSizes) { + prod = rewriter.create(loc, prod, d).getResult(); + } + return prod; +} + +FailureOr>> +getCollapsedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, + ArrayRef inpCollapDims) { + // Ref to XLA:Collapse: + // https://www.tensorflow.org/xla/operation_semantics#collapse However we use + // high to low dimension indices. + // + // Collapse replaces the given subset of the operand's dimensions by a single + // dimension. The input arguments are an arbitrary array of type T and a + // compile-time-constant vector of dimension indices. The dimension indices + // must be an in-order (high to low dimension numbers), consecutive subset of + // T's dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension + // sets, but {1, 0} or {0, 2} are not. 
+ auto nCollaps = inpCollapDims.size(); + std::vector collapDimSizes; + if (nCollaps == 0) { + return std::make_tuple(tensor, collapDimSizes); + } + + // CHECK the input collapse dimensions are in-order, otherwise throw exception + auto tensorTy = tensor.getType().dyn_cast(); + size_t rank = tensorTy.getRank(); + auto collapDims = toPositiveDims(inpCollapDims, rank); + for (size_t k = 1; k < nCollaps; ++k) + if (collapDims[k] != collapDims[k - 1] + 1) + return rewriter.notifyMatchFailure( + op, "collapse dimensions are not in consecutive order"); + + // get original tensor shape in mlir standard dialect + auto dimSizes = *getDimSizesOfTensor(rewriter, op, tensor); + + // calculate the collapse new_dim, which build the graph in mlir standard + // dialect + for (auto k : collapDims) { + auto dsize = dimSizes[k]; + collapDimSizes.push_back(dsize); + } + + // gather the new dim size values + SmallVector newDimSizes; + SmallVector newShape; + for (size_t k = 0; k < collapDims[0]; ++k) { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(tensorTy.getShape()[k]); + } + int64_t collapDimVal = 1; + for (size_t k = collapDims[0]; k < collapDims[nCollaps - 1] + 1; ++k) { + auto dsize = tensorTy.getShape()[k]; + if (dsize == ShapedType::kDynamicSize) { + collapDimVal = ShapedType::kDynamicSize; + break; + } + collapDimVal *= dsize; + } + newDimSizes.push_back(getProductOfDimSizes(rewriter, op, collapDimSizes)); + newShape.push_back(collapDimVal); + for (size_t k = collapDims[nCollaps - 1] + 1; k < rank; ++k) { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(tensorTy.getShape()[k]); + } + + return std::make_tuple( + getReshapedTensor(rewriter, op, tensor, newShape, newDimSizes), + collapDimSizes); +} + +Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, + ArrayRef shape, ArrayRef dimSizes, + ArrayRef broadcastDims) { + auto tensorTy = tensor.getType().dyn_cast(); + auto loc = op->getLoc(); + Value mhloShape = rewriter.create(loc, 
dimSizes); + + RankedTensorType outTy = + RankedTensorType::get(shape, tensorTy.getElementType()); + + RankedTensorType attrTy = + RankedTensorType::get({static_cast(broadcastDims.size())}, + rewriter.getIntegerType(64)); + auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); + + auto broadcast = rewriter.create( + loc, outTy, tensor, mhloShape, broadcastAttr); + return broadcast; +} + +Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, + ArrayRef inpTransDims) { + auto inputTy = input.getType().dyn_cast(); + auto rank = inputTy.getRank(); + auto transDims = toPositiveDims(inpTransDims, rank); + auto inpShape = inputTy.getShape(); + std::vector newShape; + newShape.reserve(rank); + + for (auto d : transDims) { + newShape.push_back(inpShape[d]); + } + + auto attrTy = RankedTensorType::get({static_cast(transDims.size())}, + rewriter.getIntegerType(64)); + auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); + + auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); + auto result = rewriter.create(op->getLoc(), outTy, input, + permuteAttr); + return result.getResult(); +} + +FailureOr getDotProduct(PatternRewriter &rewriter, Operation *op, + Value lhs, Value rhs, int64_t rank) { + if (rank < 2) + return rewriter.notifyMatchFailure( + op, "the input of DotProduct must has rank >= 2"); + + std::vector batchDims; + for (int64_t r = 0; r < rank - 2; ++r) { + batchDims.push_back(r); + } + auto lhsTy = lhs.getType().dyn_cast(); + auto rhsTy = rhs.getType().dyn_cast(); + + auto lhsShape = lhsTy.getShape(); + auto rhsShape = rhsTy.getShape(); + + // lhsShape[b, m, n], rhsShape[b', n', k] -> resultShape[b, m, k], + // assert b == b' and n == n', but we could only verify it at runtime + std::vector resultShape(lhsShape.begin(), lhsShape.end()); + resultShape[rank - 1] = rhsShape[rank - 1]; + + auto loc = op->getLoc(); + auto resultTy = RankedTensorType::get(resultShape, lhsTy.getElementType()); + auto 
dotDimAttr = mhlo::DotDimensionNumbersAttr::get( + op->getContext(), batchDims, batchDims, {rank - 1}, {rank - 2}); + auto result = rewriter.create( + loc, resultTy, lhs, rhs, dotDimAttr, /*precision_config*/ nullptr); + return result.getResult(); +} + +FailureOr getBmmDotProduct(PatternRewriter &rewriter, Operation *op, + Value inpLhs, Value inpRhs) { + Value lhs = inpLhs; + Value rhs = inpRhs; + auto lhsRankTy = inpLhs.getType().dyn_cast(); + auto rhsRankTy = inpRhs.getType().dyn_cast(); + + auto lhsRank = lhsRankTy.getRank(); + auto rhsRank = rhsRankTy.getRank(); + if (lhsRank < 2 || rhsRank < 2) + return rewriter.notifyMatchFailure( + op, "the input of batch-matmul must has rank >= 2"); + + // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be + // broadcastable). + auto maxRank = std::max(lhsRank, rhsRank); + auto minRank = std::min(lhsRank, rhsRank); + if (maxRank != minRank) { + auto leadingRank = maxRank - minRank; + auto leadingDims = llvm::to_vector<4>(llvm::seq(0, leadingRank)); + auto broadcastDims = + llvm::to_vector<4>(llvm::seq(leadingRank, maxRank)); + auto lhsShape = lhsRankTy.getShape(); + auto rhsShape = rhsRankTy.getShape(); + if (lhsRank < rhsRank) { + std::vector newShape(rhsShape.begin(), + rhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); + auto newDimSizes = *getDimSizesOfTensor(rewriter, op, rhs, leadingDims); + auto lhsDimSizes = *getDimSizesOfTensor(rewriter, op, lhs); + newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), + lhsDimSizes.end()); + lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, + broadcastDims); + } else { + std::vector newShape(lhsShape.begin(), + lhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); + auto newDimSizes = *getDimSizesOfTensor(rewriter, op, lhs, leadingDims); + auto rhsDimSizes = *getDimSizesOfTensor(rewriter, op, rhs); + newDimSizes.insert(newDimSizes.end(), 
rhsDimSizes.begin(), + rhsDimSizes.end()); + rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, + broadcastDims); + } + } + + // [?, ?, m, n] x [?, n, k] ==> batch_matmul([m,n], [n,k]) + return getDotProduct(rewriter, op, lhs, rhs, /*rank*/ maxRank); +} + +FailureOr getMmDotProduct(PatternRewriter &rewriter, Operation *op, + Value inpLhs, Value inpRhs) { + auto lhsRankTy = inpLhs.getType().dyn_cast(); + auto rhsRankTy = inpRhs.getType().dyn_cast(); + + auto lhsRank = lhsRankTy.getRank(); + auto rhsRank = rhsRankTy.getRank(); + if (lhsRank < 2) + return rewriter.notifyMatchFailure( + op, "the left hand-side input of matmul must has rank >= 2"); + if (rhsRank != 2) + return rewriter.notifyMatchFailure( + op, "the right hand-side input of matmul must has rank == 2"); + + Value lhs = inpLhs; + Value rhs = inpRhs; + // [?, m, n] x [n, k] ==> [?xm, n] x [n, k] + std::vector collapDimSizes; + if (lhsRank > 2) { + std::vector collapDims; + for (int64_t d = 0; d < lhsRank - 1; ++d) { + collapDims.push_back(d); + } + auto collapDimSizesInfo = getCollapsedTensor(rewriter, op, lhs, collapDims); + if (failed(collapDimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to construct matrix-matrix multiply"); + std::tie(lhs, collapDimSizes) = *collapDimSizesInfo; + } + auto result = getDotProduct(rewriter, op, lhs, rhs, /*rank*/ 2); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to construct matrix-matrix multiply"); + + return getExpandedTensor(rewriter, op, *result, collapDimSizes, + /*expandPos*/ 0); +} + +FailureOr getMvDotProduct(PatternRewriter &rewriter, Operation *op, + Value inpLhs, Value inpRhs) { + auto lhsRankTy = inpLhs.getType().dyn_cast(); + auto rhsRankTy = inpRhs.getType().dyn_cast(); + + auto lhsRank = lhsRankTy.getRank(); + auto rhsRank = rhsRankTy.getRank(); + + if (rhsRank != 1) + return rewriter.notifyMatchFailure( + op, "the right hand-side input of matmul must has rank == 1"); + if (lhsRank < 2) + 
return rewriter.notifyMatchFailure( + op, "the left hand-side input of matmul must has rank >= 2"); + + auto unsqzRhsInfo = mhlo::unsqueezeTensor(rewriter, op, inpRhs, {1}); + if (failed(unsqzRhsInfo)) + return rewriter.notifyMatchFailure( + op, "failed to unsqueeze right hand-side input to rank 2"); + + auto unsqzRhs = *unsqzRhsInfo; + auto product = getMmDotProduct(rewriter, op, inpLhs, unsqzRhs); + if (failed(product)) + return rewriter.notifyMatchFailure( + op, "failed to construct matrix-vector multiply"); + Value result = *product; + std::vector collapDimSizes; + auto collapDimSizesInfo = getCollapsedTensor(rewriter, op, result, {-2, -1}); + if (failed(collapDimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to construct matrix-vector multiply"); + std::tie(result, collapDimSizes) = *collapDimSizesInfo; + return result; +} +} // namespace mhlo +} // namespace mlir + +namespace { +// Perform the basic n-dim matmul operation encompassing the handling of +// broadcasting and dynamic shape propagation. +// All PyTorch ops that leverage matrix multiplication will derive this and +// implement their specialized input processing (e.g transpose), and output +// processing, e.g. GEMM or fully connected bias handling. +template +class ConvertAtenMatmulBaseOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + // Each variant must implement corresponding parameter parsing options. + // Maintain separate input read functions for each variant because it is not + // necessarily true with all variants that the first two operands are the lhs + // and rhs. 
+ virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const { + return rewriter.notifyMatchFailure( + op, + "unimplemented matrix multiplication variant input parsing function"); + } + LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &lhs, + Value &rhs, Value &output) const { + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + auto lhsElemTy = lhsTy.getElementType(); + auto rhsElemTy = rhsTy.getElementType(); + + if (lhsElemTy != rhsElemTy) + return op.emitError("matmul: input datatypes mismatched"); + if (lhsRank < 1 || rhsRank < 1) { + return op.emitError("matmul: inputs can't be 0-rank"); + } + + FailureOr product; + if (rhsRank == 1) { + if (lhsRank == 1) { + // If both tensors are 1-dimensional, the dot product (scalar) is + // returned. + auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); + product = mhlo::getMvDotProduct(rewriter, op, *unsqzLhs, rhs); + product = mhlo::getZeroRankTensor(rewriter, op, *product); + } else { + // If the first argument is 2-dimensional and the second argument is + // 1-dimensional, the matrix-vector product is returned. + // NB: if lhsRank > 2 reshape it to rank 2. + product = mhlo::getMvDotProduct(rewriter, op, lhs, rhs); + } + } else if (rhsRank == 2) { + if (lhsRank == 1) { + // If the first argument is 1-dimensional, a 1 is prepended to its + // dimension for the purpose of the batched matrix multiply and removed + // after. 
+ auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); + product = mhlo::getMmDotProduct(rewriter, op, *unsqzLhs, rhs); + auto collapDimSizesInfo = + mhlo::getCollapsedTensor(rewriter, op, *product, {-2, -1}); + if (failed(collapDimSizesInfo)) + return op.emitError("failed to construct matrix-vector multiply"); + + std::vector collapDimSizes; + std::tie(product, collapDimSizes) = *collapDimSizesInfo; + } else { + // If both arguments are 2-dimensional, the matrix-matrix product is + // returned. NB: if lhsRank > 2 reshape it to rank 2. + product = mhlo::getMmDotProduct(rewriter, op, lhs, rhs); + } + } else { + // rhsRank > 2 + if (lhsRank == 1) { + // If the first argument is 1-dimensional, a 1 is prepended to its + // dimension for the purpose of the batched matrix multiply and removed + // after. + auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); + product = mhlo::getBmmDotProduct(rewriter, op, *unsqzLhs, rhs); + auto collapDimSizesInfo = + mhlo::getCollapsedTensor(rewriter, op, *product, {-2, -1}); + if (failed(collapDimSizesInfo)) + return op.emitError("failed to construct matrix-vector multiply"); + + std::vector collapDimSizes; + std::tie(product, collapDimSizes) = *collapDimSizesInfo; + } else { + product = mhlo::getBmmDotProduct(rewriter, op, lhs, rhs); + } + } + if (failed(product)) + return op.emitError("matmul: conversion failed"); + output = *product; + return success(); + } + + // The default version just reads two inputs, computes output and returns it. + // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc. 
+ virtual LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs, rhs; + if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) + return op.emitError("failed to read matmul inputs"); + + Value output; + + if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output))) + return op.emitError("failed to perform matmul operation"); + + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(), + output); + + return success(); + } +}; + +// Legalizes the torch.matmul op for general n-dim matmul. +template +class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.self(); + auto lhsTy = lhs.getType().dyn_cast(); + + rhs = adaptor.other(); + auto rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + return success(); + } +}; + +// Implements handling of aten.mm and aten.bmm ops. +template +class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.self(); + auto lhsTy = lhs.getType().dyn_cast(); + + rhs = adaptor.mat2(); + auto rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + if (isa(op)) { + // Mm takes two 2D tensors. 
+ if (lhsRank != 2 || rhsRank != 2) + return op.emitError("aten.mm called but matrix rank != 2"); + } else if (isa(op)) { + // Bmm takes two 3D tensors. + if (lhsRank != 3 || rhsRank != 3) + return op.emitError("aten.bmm called but matrix rank != 3"); + } + + return success(); + } +}; + +// Implements handling of aten.linear op. +template +class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { +public: + using ConvertAtenMatmulBaseOp::ConvertAtenMatmulBaseOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Value &lhs, Value &rhs) const override { + lhs = adaptor.input(); + auto lhsTy = lhs.getType().dyn_cast(); + + rhs = adaptor.weight(); + auto rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy || !rhsTy) + return op.emitError( + "only ranked tensor types are supported in MHLO matmul"); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + if (lhsRank != 2 && lhsRank != 3) + return op.emitError("aten.Linear called but input rank not 2 or 3"); + if (rhsRank != 2 && rhsRank != 3) + return op.emitError("aten.Linear called but weight rank not 2 or 3"); + + return success(); + } + // Override the default rewriter to perform RHS transpose and bias addition + // as well. + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs, rhs; + + if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) + return op.emitError("failed to read matmul inputs"); + + // The aten.Linear op has a bias tensor that is added to the matmul + // output. + auto bias = adaptor.bias(); + auto biasTy = bias.getType(); + + // MHLO does not mandate that elementwise op tensors need to be ranked. 
+ if (!biasTy.template isa() && + !biasTy.template isa()) + return op.emitError("only ranked tensor types are supported in MHLO " + "matmul for bias tensor"); + + // weight.T + auto weightT = mhlo::getPermutedTensor(rewriter, op, rhs, {1, 0}); + auto product = mhlo::getMmDotProduct(rewriter, op, lhs, weightT); + if (failed(product)) + return op.emitError("failed to perform matmul operation"); + + Value matmulOutput = *product; + Value matmulPlusBias = matmulOutput; + + if (!biasTy.template isa()) { + // Bias addition broadcasts to the matmul output shape. + matmulPlusBias = rewriter + .create( + op->getLoc(), matmulOutput.getType(), + matmulOutput, bias, nullptr) + .getResult(); + } + + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + matmulPlusBias); + return success(); + } +}; + +} // namespace + +void mlir::torch::torch_to_mhlo::populateMatmulOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + +#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); +#undef INSERT_MATMUL_ATENOP_PATTERN + +#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); +#undef INSERT_MM_ATENOP_PATTERN + +#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); +#undef INSERT_LINEAR_ATENOP_PATTERN +} diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h index c84cec6380c6..bded499f051c 100644 --- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -28,6 +28,10 @@ void 
populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); +void populateMatmulOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target); + } // namespace torch_to_mhlo } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index 5007a8a26d33..bb5a90e05781 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -11,6 +11,7 @@ #include "../PassDetail.h" #include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" @@ -60,6 +61,8 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { target); torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter, patterns, target); + torch_to_mhlo::populateMatmulOpPatternsAndLegality(typeConverter, patterns, + target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -73,4 +76,4 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { std::unique_ptr> mlir::torch::createConvertTorchToMhloPass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/test/Conversion/TorchToMhlo/matmul.mlir b/test/Conversion/TorchToMhlo/matmul.mlir new file mode 100644 index 000000000000..de2ef0ad3c3b --- /dev/null +++ b/test/Conversion/TorchToMhlo/matmul.mlir @@ -0,0 +1,358 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.mm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK: %[[T0:.*]] = 
torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> +func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32> -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<3x?xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: 
!torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<10x3x5xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[10,3,5],f32> +func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[10,4,5],f32> -> !torch.vtensor<[10,3,5],f32> + return %0 : !torch.vtensor<[10,3,5],f32> +} + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,4],f32>, 
!torch.vtensor<[?,4,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> +func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { + %0 = torch.aten.matmul 
%arg0, %arg1 : !torch.vtensor<[256,120],f32>, !torch.vtensor<[4,120,256],f32> -> !torch.vtensor<[4,256,256],f32> + return %0 : !torch.vtensor<[4,256,256],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x?x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x?x256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T8:.*]] = arith.muli %[[C1_I64]], %[[T3]] : i64 +// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 +// CHECK: %[[T10:.*]] = tensor.from_elements %[[T9]], %[[T7]] : tensor<2xi64> +// CHECK: %[[T11:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<4x?x256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T12:.*]] = "mhlo.dot_general"(%[[T11]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x?xf32>) -> tensor +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T13:.*]] = tensor.dim %[[T12]], %[[C0_0]] : tensor +// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T15:.*]] = tensor.dim %[[T12]], %[[C1_1]] : tensor +// CHECK: %[[T16:.*]] = arith.index_cast 
%[[T15]] : index to i64 +// CHECK: %[[T17:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T16]] : tensor<3xi64> +// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T12]], %[[T17]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T19:.*]] = mhlo.convert(%[[T18]]) : (tensor) -> tensor<4x?x?xf32> +// CHECK: %[[T20:.*]] = torch_c.from_builtin_tensor %[[T19]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> +// CHECK: return %[[T20]] : !torch.vtensor<[4,?,?],f32> +func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[4,?,?],f32> + return %0 : !torch.vtensor<[4,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$3dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]], %[[C1_I64]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<1x?x256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<1x?x256xf32> +// CHECK: %[[T9:.*]] = arith.index_cast 
%[[T8]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T10:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<1x?x256xf32> +// CHECK: %[[T11:.*]] = arith.index_cast %[[T10]] : index to i64 +// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[T12:.*]] = arith.muli %[[C1_I64_1]], %[[T7]] : i64 +// CHECK: %[[T13:.*]] = arith.muli %[[T12]], %[[T9]] : i64 +// CHECK: %[[T14:.*]] = tensor.from_elements %[[T13]], %[[T11]] : tensor<2xi64> +// CHECK: %[[T15:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T14]]) : (tensor<1x?x256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T16:.*]] = "mhlo.dot_general"(%[[T15]], %[[T5]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x1xf32>) -> tensor +// CHECK: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK: %[[T17:.*]] = tensor.dim %[[T16]], %[[C0_2]] : tensor +// CHECK: %[[T18:.*]] = arith.index_cast %[[T17]] : index to i64 +// CHECK: %[[C1_3:.*]] = arith.constant 1 : index +// CHECK: %[[T19:.*]] = tensor.dim %[[T16]], %[[C1_3]] : tensor +// CHECK: %[[T20:.*]] = arith.index_cast %[[T19]] : index to i64 +// CHECK: %[[T21:.*]] = tensor.from_elements %[[T7]], %[[T9]], %[[T20]] : tensor<3xi64> +// CHECK: %[[T22:.*]] = "mhlo.dynamic_reshape"(%[[T16]], %[[T21]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[C0_4:.*]] = arith.constant 0 : index +// CHECK: %[[T23:.*]] = tensor.dim %[[T22]], %[[C0_4]] : tensor +// CHECK: %[[T24:.*]] = arith.index_cast %[[T23]] : index to i64 +// CHECK: %[[C1_5:.*]] = arith.constant 1 : index +// CHECK: %[[T25:.*]] = tensor.dim %[[T22]], %[[C1_5]] : tensor +// CHECK: %[[T26:.*]] = arith.index_cast %[[T25]] : index to i64 +// CHECK: %[[C2_6:.*]] = arith.constant 2 : index +// CHECK: %[[T27:.*]] = tensor.dim %[[T22]], %[[C2_6]] : tensor +// CHECK: %[[T28:.*]] = arith.index_cast %[[T27]] : index to i64 +// CHECK: %[[C1_I64_7:.*]] = arith.constant 1 : i64 +// CHECK: %[[T29:.*]] = arith.muli %[[C1_I64_7]], %[[T26]] : i64 +// CHECK: %[[T30:.*]] = arith.muli %[[T29]], 
%[[T28]] : i64 +// CHECK: %[[T31:.*]] = tensor.from_elements %[[T24]], %[[T30]] : tensor<2xi64> +// CHECK: %[[T32:.*]] = "mhlo.dynamic_reshape"(%[[T22]], %[[T31]]) : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T33:.*]] = mhlo.convert(%[[T32]]) : (tensor) -> tensor<1x?xf32> +// CHECK: %[[T34:.*]] = torch_c.from_builtin_tensor %[[T33]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> +// CHECK: return %[[T34]] : !torch.vtensor<[1,?],f32> +func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[1,?],f32> + return %0 : !torch.vtensor<[1,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx3d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[T8:.*]] = tensor.dim %[[T5]], %[[C0_1]] : tensor<1x256xf32> +// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index 
+// CHECK: %[[T10:.*]] = tensor.dim %[[T5]], %[[C1]] : tensor<1x256xf32> +// CHECK: %[[T11:.*]] = arith.index_cast %[[T10]] : index to i64 +// CHECK: %[[T12:.*]] = tensor.from_elements %[[T7]], %[[T9]], %[[T11]] : tensor<3xi64> +// CHECK: %[[T13:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T5]], %[[T12]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T14:.*]] = "mhlo.dot_general"(%[[T13]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK: %[[T15:.*]] = tensor.dim %[[T14]], %[[C0_2]] : tensor +// CHECK: %[[T16:.*]] = arith.index_cast %[[T15]] : index to i64 +// CHECK: %[[C1_3:.*]] = arith.constant 1 : index +// CHECK: %[[T17:.*]] = tensor.dim %[[T14]], %[[C1_3]] : tensor +// CHECK: %[[T18:.*]] = arith.index_cast %[[T17]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T19:.*]] = tensor.dim %[[T14]], %[[C2]] : tensor +// CHECK: %[[T20:.*]] = arith.index_cast %[[T19]] : index to i64 +// CHECK: %[[C1_I64_4:.*]] = arith.constant 1 : i64 +// CHECK: %[[T21:.*]] = arith.muli %[[C1_I64_4]], %[[T18]] : i64 +// CHECK: %[[T22:.*]] = arith.muli %[[T21]], %[[T20]] : i64 +// CHECK: %[[T23:.*]] = tensor.from_elements %[[T16]], %[[T22]] : tensor<2xi64> +// CHECK: %[[T24:.*]] = "mhlo.dynamic_reshape"(%[[T14]], %[[T23]]) : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T25:.*]] = mhlo.convert %[[T24]] : tensor +// CHECK: %[[T26:.*]] = torch_c.from_builtin_tensor %[[T25]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T26]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[?,256,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// CHECK-LABEL: func.func 
@torch.aten.matmul$2dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]], %[[C1_I64]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> +// CHECK: %[[T6:.*]] = "mhlo.dot_general"(%[[T0]], %[[T5]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x1xf32>) -> tensor +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T6]], %[[C0_0]] : tensor +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T9:.*]] = tensor.dim %[[T6]], %[[C1]] : tensor +// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : index to i64 +// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[T11:.*]] = arith.muli %[[C1_I64_1]], %[[T8]] : i64 +// CHECK: %[[T12:.*]] = arith.muli %[[T11]], %[[T10]] : i64 +// CHECK: %[[T13:.*]] = tensor.from_elements %[[T12]] : tensor<1xi64> +// CHECK: %[[T14:.*]] = "mhlo.dynamic_reshape"(%[[T6]], %[[T13]]) : (tensor, tensor<1xi64>) -> tensor +// CHECK: %[[T15:.*]] = mhlo.convert %[[T14]] : tensor +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { + %0 = 
torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx2d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T6:.*]] = "mhlo.dot_general"(%[[T5]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x256xf32>, tensor<256x?xf32>) -> tensor<1x?xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T7:.*]] = tensor.dim %[[T6]], %[[C0_0]] : tensor<1x?xf32> +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T9:.*]] = tensor.dim %[[T6]], %[[C1]] : tensor<1x?xf32> +// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : index to i64 +// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[T11:.*]] = arith.muli %[[C1_I64_1]], %[[T8]] : i64 +// CHECK: %[[T12:.*]] = arith.muli %[[T11]], %[[T10]] : i64 +// CHECK: %[[T13:.*]] = tensor.from_elements %[[T12]] : tensor<1xi64> +// CHECK: %[[T14:.*]] = "mhlo.dynamic_reshape"(%[[T6]], %[[T13]]) : (tensor<1x?xf32>, tensor<1xi64>) -> tensor +// CHECK: %[[T15:.*]] = mhlo.convert %[[T14]] : tensor +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor 
%[[T15]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T7]], %[[C1_I64_1]] : tensor<2xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T8]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T5]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK: %[[T11:.*]] = tensor.dim %[[T10]], %[[C0_2]] : tensor<1x1xf32> +// CHECK: %[[T12:.*]] = arith.index_cast 
%[[T11]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T13:.*]] = tensor.dim %[[T10]], %[[C1]] : tensor<1x1xf32> +// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 +// CHECK: %[[C1_I64_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[T15:.*]] = arith.muli %[[C1_I64_3]], %[[T12]] : i64 +// CHECK: %[[T16:.*]] = arith.muli %[[T15]], %[[T14]] : i64 +// CHECK: %[[T17:.*]] = tensor.from_elements %[[T16]] : tensor<1xi64> +// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T10]], %[[T17]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T19:.*]] = "mhlo.reshape"(%[[T18]]) : (tensor<1xf32>) -> tensor +// CHECK: %[[T20:.*]] = mhlo.convert %[[T19]] : tensor +// CHECK: %[[T21:.*]] = torch_c.from_builtin_tensor %[[T20]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[T21]] : !torch.vtensor<[],f32> +func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// CHECK-LABEL: func.func @torch.aten.matmul$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : 
index to i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T8:.*]] = arith.muli %[[C1_I64]], %[[T3]] : i64 +// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 +// CHECK: %[[T10:.*]] = tensor.from_elements %[[T9]], %[[T7]] : tensor<2xi64> +// CHECK: %[[T11:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T12:.*]] = "mhlo.dot_general"(%[[T11]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x256xf32>) -> tensor +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T13:.*]] = tensor.dim %[[T12]], %[[C0_0]] : tensor +// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T15:.*]] = tensor.dim %[[T12]], %[[C1_1]] : tensor +// CHECK: %[[T16:.*]] = arith.index_cast %[[T15]] : index to i64 +// CHECK: %[[T17:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T16]] : tensor<3xi64> +// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T12]], %[[T17]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T19:.*]] = mhlo.convert %[[T18]] : tensor +// CHECK: %[[T20:.*]] = torch_c.from_builtin_tensor %[[T19]] : tensor -> !torch.vtensor<[?,?,256],f32> +// CHECK: return %[[T20]] : !torch.vtensor<[?,?,256],f32> +func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[?,?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,?,256],f32> + return %1 : !torch.vtensor<[?,?,256],f32> +} + +// CHECK-LABEL: func.func @torch.aten.mm$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// 
CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x256xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,256],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> +func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,256],f32> + return %1 : !torch.vtensor<[?,256],f32> +} From f70ea508ab4807920ee72f6e171c7b4a594e7bcc Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 3 Aug 2022 11:57:24 +0800 Subject: [PATCH 2/4] Rename to Linear --- lib/Conversion/TorchToMhlo/CMakeLists.txt | 2 +- lib/Conversion/TorchToMhlo/{MatmulOp.cpp => Linear.cpp} | 2 +- lib/Conversion/TorchToMhlo/PopulatePatterns.h | 2 +- lib/Conversion/TorchToMhlo/TorchToMhlo.cpp | 2 +- test/Conversion/TorchToMhlo/{matmul.mlir => linear.mlir} | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename lib/Conversion/TorchToMhlo/{MatmulOp.cpp => Linear.cpp} (99%) rename test/Conversion/TorchToMhlo/{matmul.mlir => linear.mlir} (100%) diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 859551a6a9cb..47126ab8d7c3 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MhloLegalizeUtils.cpp BasicOp.cpp GatherOp.cpp - MatmulOp.cpp + Linear.cpp ViewLikeOps.cpp ReductionOp.cpp diff --git a/lib/Conversion/TorchToMhlo/MatmulOp.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp similarity index 99% rename from lib/Conversion/TorchToMhlo/MatmulOp.cpp rename to lib/Conversion/TorchToMhlo/Linear.cpp index 
c73dd6d998a2..e7809b28bbfa 100644 --- a/lib/Conversion/TorchToMhlo/MatmulOp.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -626,7 +626,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { } // namespace -void mlir::torch::torch_to_mhlo::populateMatmulOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h index bded499f051c..2ff569cd0cce 100644 --- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -28,7 +28,7 @@ void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); -void populateMatmulOpPatternsAndLegality(TypeConverter &typeConverter, +void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index bb5a90e05781..a8314c7ccc90 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -61,7 +61,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { target); torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter, patterns, target); - torch_to_mhlo::populateMatmulOpPatternsAndLegality(typeConverter, patterns, + torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, diff --git a/test/Conversion/TorchToMhlo/matmul.mlir b/test/Conversion/TorchToMhlo/linear.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/matmul.mlir rename to 
test/Conversion/TorchToMhlo/linear.mlir From b15d128ff51be8e0a18d3b963f9e122f506cae66 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 3 Aug 2022 16:36:18 +0800 Subject: [PATCH 3/4] use mhlo::dot and mhlo::dot_general directly --- lib/Conversion/TorchToMhlo/Linear.cpp | 459 ++++++------------------ test/Conversion/TorchToMhlo/linear.mlir | 306 ++++++---------- 2 files changed, 214 insertions(+), 551 deletions(-) diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index e7809b28bbfa..d9475e8efa8b 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -26,147 +26,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -namespace mlir { -namespace mhlo { -FailureOr getZeroRankTensor(PatternRewriter &rewriter, Operation *op, - Value tensor) { - auto rankTy = tensor.getType().dyn_cast(); - if (!rankTy) - return rewriter.notifyMatchFailure( - op, "can not reshape a tensor that is not ranked to 0-rank"); - - auto shape = rankTy.getShape(); - if (!(shape.size() == 1 && shape[0] == 1)) - return rewriter.notifyMatchFailure(op, "the shape must equal to [1]"); - - return rewriter - .create( - op->getLoc(), - RankedTensorType::get(ArrayRef{}, rankTy.getElementType()), - tensor) - .getResult(); -} - -Value getReshapedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, - ArrayRef shape, ArrayRef dimSizes) { - // create mhlo::DynamicReshapeOp - auto loc = op->getLoc(); - auto tensorTy = tensor.getType().dyn_cast(); - auto outRankTy = RankedTensorType::get(shape, tensorTy.getElementType()); - Value mhloShape = rewriter.create(loc, dimSizes); - return rewriter.create(loc, outRankTy, tensor, - mhloShape); -} - -Value getExpandedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, - ArrayRef expandDimSizes, int64_t expandPos) { - if (expandDimSizes.size() == 0) { - return tensor; - } - - auto tensorTy = tensor.getType().dyn_cast(); - auto 
dimSizes = *getDimSizesOfTensor(rewriter, op, tensor); - int64_t rank = dimSizes.size(); - expandPos = (expandPos + rank) % rank; - - std::vector newDimSizes; - std::vector newShape; - for (int64_t k = 0; k < rank; ++k) { - if (k == expandPos) { - newDimSizes.insert(newDimSizes.end(), expandDimSizes.begin(), - expandDimSizes.end()); - for (size_t j = 0; j < expandDimSizes.size(); ++j) { - newShape.push_back(ShapedType::kDynamicSize); - } - } else { - newDimSizes.push_back(dimSizes[k]); - newShape.push_back(tensorTy.getShape()[k]); - } - } - - return getReshapedTensor(rewriter, op, tensor, newShape, newDimSizes); -} - -Value getProductOfDimSizes(PatternRewriter &rewriter, Operation *op, - ArrayRef dimSizes) { - auto loc = op->getLoc(); - Type intTy = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); - auto prod = - rewriter.create(loc, rewriter.getIntegerAttr(intTy, 1)) - .getResult(); - - for (auto &d : dimSizes) { - prod = rewriter.create(loc, prod, d).getResult(); - } - return prod; -} - -FailureOr>> -getCollapsedTensor(PatternRewriter &rewriter, Operation *op, Value tensor, - ArrayRef inpCollapDims) { - // Ref to XLA:Collapse: - // https://www.tensorflow.org/xla/operation_semantics#collapse However we use - // high to low dimension indices. - // - // Collapse replaces the given subset of the operand's dimensions by a single - // dimension. The input arguments are an arbitrary array of type T and a - // compile-time-constant vector of dimension indices. The dimension indices - // must be an in-order (high to low dimension numbers), consecutive subset of - // T's dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension - // sets, but {1, 0} or {0, 2} are not. 
- auto nCollaps = inpCollapDims.size(); - std::vector collapDimSizes; - if (nCollaps == 0) { - return std::make_tuple(tensor, collapDimSizes); - } - - // CHECK the input collapse dimensions are in-order, otherwise throw exception - auto tensorTy = tensor.getType().dyn_cast(); - size_t rank = tensorTy.getRank(); - auto collapDims = toPositiveDims(inpCollapDims, rank); - for (size_t k = 1; k < nCollaps; ++k) - if (collapDims[k] != collapDims[k - 1] + 1) - return rewriter.notifyMatchFailure( - op, "collapse dimensions are not in consecutive order"); - - // get original tensor shape in mlir standard dialect - auto dimSizes = *getDimSizesOfTensor(rewriter, op, tensor); - - // calculate the collapse new_dim, which build the graph in mlir standard - // dialect - for (auto k : collapDims) { - auto dsize = dimSizes[k]; - collapDimSizes.push_back(dsize); - } - - // gather the new dim size values - SmallVector newDimSizes; - SmallVector newShape; - for (size_t k = 0; k < collapDims[0]; ++k) { - newDimSizes.push_back(dimSizes[k]); - newShape.push_back(tensorTy.getShape()[k]); - } - int64_t collapDimVal = 1; - for (size_t k = collapDims[0]; k < collapDims[nCollaps - 1] + 1; ++k) { - auto dsize = tensorTy.getShape()[k]; - if (dsize == ShapedType::kDynamicSize) { - collapDimVal = ShapedType::kDynamicSize; - break; - } - collapDimVal *= dsize; - } - newDimSizes.push_back(getProductOfDimSizes(rewriter, op, collapDimSizes)); - newShape.push_back(collapDimVal); - for (size_t k = collapDims[nCollaps - 1] + 1; k < rank; ++k) { - newDimSizes.push_back(dimSizes[k]); - newShape.push_back(tensorTy.getShape()[k]); - } - - return std::make_tuple( - getReshapedTensor(rewriter, op, tensor, newShape, newDimSizes), - collapDimSizes); -} - +namespace { Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef shape, ArrayRef dimSizes, ArrayRef broadcastDims) { @@ -191,7 +51,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, ArrayRef 
inpTransDims) { auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); - auto transDims = toPositiveDims(inpTransDims, rank); + auto transDims = mhlo::toPositiveDims(inpTransDims, rank); auto inpShape = inputTy.getShape(); std::vector newShape; newShape.reserve(rank); @@ -210,38 +70,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, return result.getResult(); } -FailureOr getDotProduct(PatternRewriter &rewriter, Operation *op, - Value lhs, Value rhs, int64_t rank) { - if (rank < 2) - return rewriter.notifyMatchFailure( - op, "the input of DotProduct must has rank >= 2"); - - std::vector batchDims; - for (int64_t r = 0; r < rank - 2; ++r) { - batchDims.push_back(r); - } - auto lhsTy = lhs.getType().dyn_cast(); - auto rhsTy = rhs.getType().dyn_cast(); - - auto lhsShape = lhsTy.getShape(); - auto rhsShape = rhsTy.getShape(); - - // lhsShape[b, m, n], rhsShape[b', n', k] -> resultShape[b, m, k], - // assert b == b' and n == n', but we could only verify it at runtime - std::vector resultShape(lhsShape.begin(), lhsShape.end()); - resultShape[rank - 1] = rhsShape[rank - 1]; - - auto loc = op->getLoc(); - auto resultTy = RankedTensorType::get(resultShape, lhsTy.getElementType()); - auto dotDimAttr = mhlo::DotDimensionNumbersAttr::get( - op->getContext(), batchDims, batchDims, {rank - 1}, {rank - 2}); - auto result = rewriter.create( - loc, resultTy, lhs, rhs, dotDimAttr, /*precision_config*/ nullptr); - return result.getResult(); -} - -FailureOr getBmmDotProduct(PatternRewriter &rewriter, Operation *op, - Value inpLhs, Value inpRhs) { +void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, + Value &inpRhs, int64_t leadingRank) { Value lhs = inpLhs; Value rhs = inpRhs; auto lhsRankTy = inpLhs.getType().dyn_cast(); @@ -249,124 +79,43 @@ FailureOr getBmmDotProduct(PatternRewriter &rewriter, Operation *op, auto lhsRank = lhsRankTy.getRank(); auto rhsRank = rhsRankTy.getRank(); - if (lhsRank < 2 || 
rhsRank < 2) - return rewriter.notifyMatchFailure( - op, "the input of batch-matmul must has rank >= 2"); // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be // broadcastable). - auto maxRank = std::max(lhsRank, rhsRank); auto minRank = std::min(lhsRank, rhsRank); - if (maxRank != minRank) { - auto leadingRank = maxRank - minRank; - auto leadingDims = llvm::to_vector<4>(llvm::seq(0, leadingRank)); - auto broadcastDims = - llvm::to_vector<4>(llvm::seq(leadingRank, maxRank)); - auto lhsShape = lhsRankTy.getShape(); - auto rhsShape = rhsRankTy.getShape(); - if (lhsRank < rhsRank) { - std::vector newShape(rhsShape.begin(), - rhsShape.begin() + leadingRank); - newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = *getDimSizesOfTensor(rewriter, op, rhs, leadingDims); - auto lhsDimSizes = *getDimSizesOfTensor(rewriter, op, lhs); - newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), - lhsDimSizes.end()); - lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, - broadcastDims); - } else { - std::vector newShape(lhsShape.begin(), - lhsShape.begin() + leadingRank); - newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = *getDimSizesOfTensor(rewriter, op, lhs, leadingDims); - auto rhsDimSizes = *getDimSizesOfTensor(rewriter, op, rhs); - newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), - rhsDimSizes.end()); - rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, - broadcastDims); - } - } - - // [?, ?, m, n] x [?, n, k] ==> batch_matmul([m,n], [n,k]) - return getDotProduct(rewriter, op, lhs, rhs, /*rank*/ maxRank); -} - -FailureOr getMmDotProduct(PatternRewriter &rewriter, Operation *op, - Value inpLhs, Value inpRhs) { - auto lhsRankTy = inpLhs.getType().dyn_cast(); - auto rhsRankTy = inpRhs.getType().dyn_cast(); - - auto lhsRank = lhsRankTy.getRank(); - auto rhsRank = rhsRankTy.getRank(); - if (lhsRank < 2) - return rewriter.notifyMatchFailure( 
- op, "the left hand-side input of matmul must has rank >= 2"); - if (rhsRank != 2) - return rewriter.notifyMatchFailure( - op, "the right hand-side input of matmul must has rank == 2"); - - Value lhs = inpLhs; - Value rhs = inpRhs; - // [?, m, n] x [n, k] ==> [?xm, n] x [n, k] - std::vector collapDimSizes; - if (lhsRank > 2) { - std::vector collapDims; - for (int64_t d = 0; d < lhsRank - 1; ++d) { - collapDims.push_back(d); - } - auto collapDimSizesInfo = getCollapsedTensor(rewriter, op, lhs, collapDims); - if (failed(collapDimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to construct matrix-matrix multiply"); - std::tie(lhs, collapDimSizes) = *collapDimSizesInfo; + auto leadingDims = llvm::to_vector<4>(llvm::seq(0, leadingRank)); + auto broadcastDims = llvm::to_vector<4>( + llvm::seq(leadingRank, minRank + leadingRank)); + auto lhsShape = lhsRankTy.getShape(); + auto rhsShape = rhsRankTy.getShape(); + if (lhsRank < rhsRank) { + std::vector newShape(rhsShape.begin(), + rhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); + auto newDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims); + auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs); + newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), + lhsDimSizes.end()); + lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, + broadcastDims); + } else { + std::vector newShape(lhsShape.begin(), + lhsShape.begin() + leadingRank); + newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); + auto newDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims); + auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs); + newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), + rhsDimSizes.end()); + rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, + broadcastDims); } - auto result = getDotProduct(rewriter, op, lhs, rhs, /*rank*/ 2); - if (failed(result)) - 
return rewriter.notifyMatchFailure( - op, "failed to construct matrix-matrix multiply"); - return getExpandedTensor(rewriter, op, *result, collapDimSizes, - /*expandPos*/ 0); + inpLhs = lhs; + inpRhs = rhs; } -FailureOr getMvDotProduct(PatternRewriter &rewriter, Operation *op, - Value inpLhs, Value inpRhs) { - auto lhsRankTy = inpLhs.getType().dyn_cast(); - auto rhsRankTy = inpRhs.getType().dyn_cast(); - - auto lhsRank = lhsRankTy.getRank(); - auto rhsRank = rhsRankTy.getRank(); - - if (rhsRank != 1) - return rewriter.notifyMatchFailure( - op, "the right hand-side input of matmul must has rank == 1"); - if (lhsRank < 2) - return rewriter.notifyMatchFailure( - op, "the left hand-side input of matmul must has rank >= 2"); - - auto unsqzRhsInfo = mhlo::unsqueezeTensor(rewriter, op, inpRhs, {1}); - if (failed(unsqzRhsInfo)) - return rewriter.notifyMatchFailure( - op, "failed to unsqueeze right hand-side input to rank 2"); - - auto unsqzRhs = *unsqzRhsInfo; - auto product = getMmDotProduct(rewriter, op, inpLhs, unsqzRhs); - if (failed(product)) - return rewriter.notifyMatchFailure( - op, "failed to construct matrix-vector multiply"); - Value result = *product; - std::vector collapDimSizes; - auto collapDimSizesInfo = getCollapsedTensor(rewriter, op, result, {-2, -1}); - if (failed(collapDimSizesInfo)) - return rewriter.notifyMatchFailure( - op, "failed to construct matrix-vector multiply"); - std::tie(result, collapDimSizes) = *collapDimSizesInfo; - return result; -} -} // namespace mhlo -} // namespace mlir - -namespace { // Perform the basic n-dim matmul operation encompassing the handling of // broadcasting and dynamic shape propagation. 
// All PyTorch ops that leverage matrix multiplication will derive this and @@ -405,61 +154,48 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return op.emitError("matmul: inputs can't be 0-rank"); } - FailureOr product; - if (rhsRank == 1) { - if (lhsRank == 1) { - // If both tensors are 1-dimensional, the dot product (scalar) is - // returned. - auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); - product = mhlo::getMvDotProduct(rewriter, op, *unsqzLhs, rhs); - product = mhlo::getZeroRankTensor(rewriter, op, *product); - } else { - // If the first argument is 2-dimensional and the second argument is - // 1-dimensional, the matrix-vector product is returned. - // NB: if lhsRank > 2 reshape it to rank 2. - product = mhlo::getMvDotProduct(rewriter, op, lhs, rhs); - } - } else if (rhsRank == 2) { - if (lhsRank == 1) { - // If the first argument is 1-dimensional, a 1 is prepended to its - // dimension for the purpose of the batched matrix multiply and removed - // after. - auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); - product = mhlo::getMmDotProduct(rewriter, op, *unsqzLhs, rhs); - auto collapDimSizesInfo = - mhlo::getCollapsedTensor(rewriter, op, *product, {-2, -1}); - if (failed(collapDimSizesInfo)) - return op.emitError("failed to construct matrix-vector multiply"); - - std::vector collapDimSizes; - std::tie(product, collapDimSizes) = *collapDimSizesInfo; - } else { - // If both arguments are 2-dimensional, the matrix-matrix product is - // returned. NB: if lhsRank > 2 reshape it to rank 2. 
- product = mhlo::getMmDotProduct(rewriter, op, lhs, rhs); - } + if (lhsRank <= 2 && rhsRank <= 2) { + output = rewriter.create(op->getLoc(), lhs, rhs, nullptr); + return success(); + } + + int64_t nBatchDims; + if (rhsRank <= 2) { + auto leadingRank = lhsRank - 2; + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + nBatchDims = leadingRank; + } else if (lhsRank <= 2) { + auto leadingRank = rhsRank - 2; + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + nBatchDims = leadingRank; } else { - // rhsRank > 2 - if (lhsRank == 1) { - // If the first argument is 1-dimensional, a 1 is prepended to its - // dimension for the purpose of the batched matrix multiply and removed - // after. - auto unsqzLhs = mhlo::unsqueezeTensor(rewriter, op, lhs, {0}); - product = mhlo::getBmmDotProduct(rewriter, op, *unsqzLhs, rhs); - auto collapDimSizesInfo = - mhlo::getCollapsedTensor(rewriter, op, *product, {-2, -1}); - if (failed(collapDimSizesInfo)) - return op.emitError("failed to construct matrix-vector multiply"); - - std::vector collapDimSizes; - std::tie(product, collapDimSizes) = *collapDimSizesInfo; - } else { - product = mhlo::getBmmDotProduct(rewriter, op, lhs, rhs); - } + assert(rhsRank > 2 && lhsRank > 2); + auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank); + nBatchDims = std::max(lhsRank - 2, rhsRank - 2); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); } - if (failed(product)) - return op.emitError("matmul: conversion failed"); - output = *product; + auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + auto lhsContractingDim = nBatchDims + 1; + auto rhsContractingDim = nBatchDims; + if (lhsRank == 1) + lhsContractingDim = nBatchDims; + + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhsBatchingDimensions=*/batchDims, + /*rhsBatchingDimensions=*/batchDims, + /*lhsContractingDimensions=*/{lhsContractingDim}, + 
/*rhsContractingDimensions=*/{rhsContractingDim}); + auto resultTy = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + output = rewriter + .create(op->getLoc(), resultTy, lhs, rhs, + dotDimensionNumbers, nullptr) + .getResult(); + return success(); } @@ -598,28 +334,45 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { "matmul for bias tensor"); // weight.T - auto weightT = mhlo::getPermutedTensor(rewriter, op, rhs, {1, 0}); - auto product = mhlo::getMmDotProduct(rewriter, op, lhs, weightT); - if (failed(product)) - return op.emitError("failed to perform matmul operation"); + rhs = getPermutedTensor(rewriter, op, rhs, {1, 0}); - Value matmulOutput = *product; - Value matmulPlusBias = matmulOutput; + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), + rhsTy.getRank() - lhsTy.getRank()); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); + auto nBatchDims = resultRank - 2; + auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + auto lhsContractingDim = nBatchDims + 1; + auto rhsContractingDim = nBatchDims; + + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhsBatchingDimensions=*/batchDims, + /*rhsBatchingDimensions=*/batchDims, + /*lhsContractingDimensions=*/{lhsContractingDim}, + /*rhsContractingDimensions=*/{rhsContractingDim}); + + auto resultTy = + OpConversionPattern::getTypeConverter()->convertType( + op.getType()); + Value matmulOutput = rewriter.create( + op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr); + + Value matmulPlusBias = matmulOutput; if (!biasTy.template isa()) { // Bias addition broadcasts to the matmul output shape. 
- matmulPlusBias = rewriter - .create( - op->getLoc(), matmulOutput.getType(), - matmulOutput, bias, nullptr) - .getResult(); + matmulPlusBias = + rewriter + .create(op->getLoc(), resultTy, + matmulOutput, bias, nullptr) + .getResult(); } - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - matmulPlusBias); + rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); return success(); } }; diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir index de2ef0ad3c3b..18ea97654ebe 100644 --- a/test/Conversion/TorchToMhlo/linear.mlir +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -4,7 +4,7 @@ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> // CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> @@ -13,11 +13,13 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: ! 
return %0 : !torch.vtensor<[2,3],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<3x?xf32>) -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<3x?xf32>) -> tensor // CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> @@ -26,32 +28,60 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: return %0 : !torch.vtensor<[?,?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.bmm$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> -// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<10x3x5xf32> -// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> -// CHECK: return %[[T4]] : !torch.vtensor<[10,3,5],f32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32> +// CHECK: %[[T3:.*]] = arith.index_cast 
%[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[10,4,5],f32> -> !torch.vtensor<[10,3,5],f32> return %0 : !torch.vtensor<[10,3,5],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.bmm$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor -// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor -// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor -// CHECK: %[[T4:.*]] = 
torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?],f32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,4],f32>, !torch.vtensor<[?,4,?],f32> -> !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> @@ -76,6 +106,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, return %0 
: !torch.vtensor<[4,256,256],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> @@ -83,230 +115,116 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x?x256xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x?x256xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T8:.*]] = arith.muli %[[C1_I64]], %[[T3]] : i64 -// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 -// CHECK: %[[T10:.*]] = tensor.from_elements %[[T9]], %[[T7]] : tensor<2xi64> -// CHECK: %[[T11:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<4x?x256xf32>, tensor<2xi64>) -> tensor -// CHECK: %[[T12:.*]] = "mhlo.dot_general"(%[[T11]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x?xf32>) -> tensor -// CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T13:.*]] = tensor.dim %[[T12]], %[[C0_0]] : tensor -// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 -// CHECK: %[[C1_1:.*]] = arith.constant 1 : 
index -// CHECK: %[[T15:.*]] = tensor.dim %[[T12]], %[[C1_1]] : tensor -// CHECK: %[[T16:.*]] = arith.index_cast %[[T15]] : index to i64 -// CHECK: %[[T17:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T16]] : tensor<3xi64> -// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T12]], %[[T17]]) : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T19:.*]] = mhlo.convert(%[[T18]]) : (tensor) -> tensor<4x?x?xf32> -// CHECK: %[[T20:.*]] = torch_c.from_builtin_tensor %[[T19]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> -// CHECK: return %[[T20]] : !torch.vtensor<[4,?,?],f32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[4,?,?],f32> return %0 : !torch.vtensor<[4,?,?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$3dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // 
CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]], %[[C1_I64]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> // CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<1x?x256xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<1x?x256xf32> -// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T10:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<1x?x256xf32> -// CHECK: %[[T11:.*]] = arith.index_cast %[[T10]] : index to i64 -// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[T12:.*]] = arith.muli %[[C1_I64_1]], %[[T7]] : i64 -// CHECK: %[[T13:.*]] = arith.muli %[[T12]], %[[T9]] : i64 -// CHECK: %[[T14:.*]] = tensor.from_elements %[[T13]], %[[T11]] : tensor<2xi64> -// CHECK: %[[T15:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T14]]) : (tensor<1x?x256xf32>, tensor<2xi64>) -> tensor -// CHECK: %[[T16:.*]] = "mhlo.dot_general"(%[[T15]], %[[T5]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x1xf32>) -> tensor -// CHECK: %[[C0_2:.*]] = arith.constant 0 : index -// CHECK: %[[T17:.*]] = tensor.dim %[[T16]], %[[C0_2]] : tensor -// CHECK: %[[T18:.*]] = arith.index_cast %[[T17]] : index to i64 -// CHECK: %[[C1_3:.*]] = arith.constant 1 : index -// CHECK: %[[T19:.*]] = tensor.dim %[[T16]], %[[C1_3]] : tensor -// CHECK: %[[T20:.*]] = arith.index_cast %[[T19]] : index to i64 -// CHECK: %[[T21:.*]] = tensor.from_elements 
%[[T7]], %[[T9]], %[[T20]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = "mhlo.dynamic_reshape"(%[[T16]], %[[T21]]) : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[C0_4:.*]] = arith.constant 0 : index -// CHECK: %[[T23:.*]] = tensor.dim %[[T22]], %[[C0_4]] : tensor -// CHECK: %[[T24:.*]] = arith.index_cast %[[T23]] : index to i64 -// CHECK: %[[C1_5:.*]] = arith.constant 1 : index -// CHECK: %[[T25:.*]] = tensor.dim %[[T22]], %[[C1_5]] : tensor -// CHECK: %[[T26:.*]] = arith.index_cast %[[T25]] : index to i64 -// CHECK: %[[C2_6:.*]] = arith.constant 2 : index -// CHECK: %[[T27:.*]] = tensor.dim %[[T22]], %[[C2_6]] : tensor -// CHECK: %[[T28:.*]] = arith.index_cast %[[T27]] : index to i64 -// CHECK: %[[C1_I64_7:.*]] = arith.constant 1 : i64 -// CHECK: %[[T29:.*]] = arith.muli %[[C1_I64_7]], %[[T26]] : i64 -// CHECK: %[[T30:.*]] = arith.muli %[[T29]], %[[T28]] : i64 -// CHECK: %[[T31:.*]] = tensor.from_elements %[[T24]], %[[T30]] : tensor<2xi64> -// CHECK: %[[T32:.*]] = "mhlo.dynamic_reshape"(%[[T22]], %[[T31]]) : (tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T33:.*]] = mhlo.convert(%[[T32]]) : (tensor) -> tensor<1x?xf32> -// CHECK: %[[T34:.*]] = torch_c.from_builtin_tensor %[[T33]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> -// CHECK: return %[[T34]] : !torch.vtensor<[1,?],f32> +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32> +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> 
!torch.vtensor<[1,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[1,?],f32> return %0 : !torch.vtensor<[1,?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$1dx3d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> // CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[T8:.*]] = tensor.dim %[[T5]], %[[C0_1]] : tensor<1x256xf32> -// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T10:.*]] = tensor.dim %[[T5]], %[[C1]] : tensor<1x256xf32> -// CHECK: %[[T11:.*]] = arith.index_cast %[[T10]] : index to i64 -// CHECK: %[[T12:.*]] = tensor.from_elements %[[T7]], %[[T9]], %[[T11]] : tensor<3xi64> -// CHECK: %[[T13:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T5]], 
%[[T12]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x256xf32>, tensor<3xi64>) -> tensor -// CHECK: %[[T14:.*]] = "mhlo.dot_general"(%[[T13]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor -// CHECK: %[[C0_2:.*]] = arith.constant 0 : index -// CHECK: %[[T15:.*]] = tensor.dim %[[T14]], %[[C0_2]] : tensor -// CHECK: %[[T16:.*]] = arith.index_cast %[[T15]] : index to i64 -// CHECK: %[[C1_3:.*]] = arith.constant 1 : index -// CHECK: %[[T17:.*]] = tensor.dim %[[T14]], %[[C1_3]] : tensor -// CHECK: %[[T18:.*]] = arith.index_cast %[[T17]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T19:.*]] = tensor.dim %[[T14]], %[[C2]] : tensor -// CHECK: %[[T20:.*]] = arith.index_cast %[[T19]] : index to i64 -// CHECK: %[[C1_I64_4:.*]] = arith.constant 1 : i64 -// CHECK: %[[T21:.*]] = arith.muli %[[C1_I64_4]], %[[T18]] : i64 -// CHECK: %[[T22:.*]] = arith.muli %[[T21]], %[[T20]] : i64 -// CHECK: %[[T23:.*]] = tensor.from_elements %[[T16]], %[[T22]] : tensor<2xi64> -// CHECK: %[[T24:.*]] = "mhlo.dynamic_reshape"(%[[T14]], %[[T23]]) : (tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T25:.*]] = mhlo.convert %[[T24]] : tensor -// CHECK: %[[T26:.*]] = torch_c.from_builtin_tensor %[[T25]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[T26]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> 
!torch.vtensor<[?,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[?,256,?],f32> -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$2dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]], %[[C1_I64]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[T6:.*]] = "mhlo.dot_general"(%[[T0]], %[[T5]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x1xf32>) -> tensor -// CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T6]], %[[C0_0]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T9:.*]] = tensor.dim %[[T6]], %[[C1]] : tensor -// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : index to i64 -// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[T11:.*]] = arith.muli %[[C1_I64_1]], %[[T8]] : i64 -// CHECK: %[[T12:.*]] = arith.muli %[[T11]], %[[T10]] : i64 -// CHECK: %[[T13:.*]] = tensor.from_elements %[[T12]] : tensor<1xi64> -// CHECK: %[[T14:.*]] = 
"mhlo.dynamic_reshape"(%[[T6]], %[[T13]]) : (tensor, tensor<1xi64>) -> tensor -// CHECK: %[[T15:.*]] = mhlo.convert %[[T14]] : tensor -// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor -> !torch.vtensor<[?],f32> -// CHECK: return %[[T16]] : !torch.vtensor<[?],f32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$1dx2d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: %[[T6:.*]] = "mhlo.dot_general"(%[[T5]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x256xf32>, tensor<256x?xf32>) -> tensor<1x?xf32> -// CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T7:.*]] = tensor.dim %[[T6]], %[[C0_0]] : tensor<1x?xf32> -// 
CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T9:.*]] = tensor.dim %[[T6]], %[[C1]] : tensor<1x?xf32> -// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : index to i64 -// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[T11:.*]] = arith.muli %[[C1_I64_1]], %[[T8]] : i64 -// CHECK: %[[T12:.*]] = arith.muli %[[T11]], %[[T10]] : i64 -// CHECK: %[[T13:.*]] = tensor.from_elements %[[T12]] : tensor<1xi64> -// CHECK: %[[T14:.*]] = "mhlo.dynamic_reshape"(%[[T6]], %[[T13]]) : (tensor<1x?xf32>, tensor<1xi64>) -> tensor<?xf32> -// CHECK: %[[T15:.*]] = mhlo.convert %[[T14]] : tensor<?xf32> -// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<?xf32> -> !torch.vtensor<[?],f32> -// CHECK: return %[[T16]] : !torch.vtensor<[?],f32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$1dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// 
CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C1_I64_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T7]], %[[C1_I64_1]] : tensor<2xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_reshape"(%[[T1]], %[[T8]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T5]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[C0_2:.*]] = arith.constant 0 : index -// CHECK: %[[T11:.*]] = tensor.dim %[[T10]], %[[C0_2]] : tensor<1x1xf32> -// CHECK: %[[T12:.*]] = arith.index_cast %[[T11]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T13:.*]] = tensor.dim %[[T10]], %[[C1]] : tensor<1x1xf32> -// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 -// CHECK: %[[C1_I64_3:.*]] = arith.constant 1 : i64 -// CHECK: %[[T15:.*]] = arith.muli %[[C1_I64_3]], %[[T12]] : i64 -// CHECK: %[[T16:.*]] = arith.muli %[[T15]], %[[T14]] : i64 -// CHECK: %[[T17:.*]] = tensor.from_elements %[[T16]] : tensor<1xi64> -// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T10]], %[[T17]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T19:.*]] = "mhlo.reshape"(%[[T18]]) : (tensor<1xf32>) -> tensor<f32> -// CHECK: %[[T20:.*]] = mhlo.convert %[[T19]] : tensor<f32> -// CHECK: %[[T21:.*]] = torch_c.from_builtin_tensor %[[T20]] : tensor<f32> -> !torch.vtensor<[],f32> -// CHECK: return %[[T21]] : !torch.vtensor<[],f32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32> +// CHECK: 
%[[T3:.*]] = mhlo.convert %[[T2]] : tensor<f32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[],f32> func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.matmul$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32> @@ -314,40 +232,31 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x256xf32> +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x256xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T8:.*]] = arith.muli %[[C1_I64]], %[[T3]] : i64 -// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 -// CHECK: %[[T10:.*]] = tensor.from_elements %[[T9]], %[[T7]] : tensor<2xi64> -// CHECK: %[[T11:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<?x?x256xf32>, tensor<2xi64>) -> tensor<?x256xf32> -// CHECK: %[[T12:.*]] = 
"mhlo.dot_general"(%[[T11]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor<256x256xf32>) -> tensor -// CHECK: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK: %[[T13:.*]] = tensor.dim %[[T12]], %[[C0_0]] : tensor -// CHECK: %[[T14:.*]] = arith.index_cast %[[T13]] : index to i64 -// CHECK: %[[C1_1:.*]] = arith.constant 1 : index -// CHECK: %[[T15:.*]] = tensor.dim %[[T12]], %[[C1_1]] : tensor -// CHECK: %[[T16:.*]] = arith.index_cast %[[T15]] : index to i64 -// CHECK: %[[T17:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T16]] : tensor<3xi64> -// CHECK: %[[T18:.*]] = "mhlo.dynamic_reshape"(%[[T12]], %[[T17]]) : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T19:.*]] = mhlo.convert %[[T18]] : tensor -// CHECK: %[[T20:.*]] = torch_c.from_builtin_tensor %[[T19]] : tensor -> !torch.vtensor<[?,?,256],f32> -// CHECK: return %[[T20]] : !torch.vtensor<[?,?,256],f32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> %1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[?,?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,?,256],f32> return %1 : !torch.vtensor<[?,?,256],f32> } +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$proj( // CHECK-SAME: %[[ARG0:.*]]: 
!torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32> // CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot_general"(%[[T0]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32> // CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x256xf32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> @@ -356,3 +265,4 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,256],f32> return %1 : !torch.vtensor<[?,256],f32> } + From 19f6e52f3fa160e9556db30a001d1806ebb4adfd Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Thu, 4 Aug 2022 09:27:57 +0800 Subject: [PATCH 4/4] rebase & add to bazel --- utils/bazel/torch-mlir-overlay/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 2ce5e9b7d9f6..d28545e951cb 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -450,6 +450,7 @@ cc_library( "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp", "lib/Conversion/TorchToMhlo/BasicOp.cpp", "lib/Conversion/TorchToMhlo/GatherOp.cpp", + "lib/Conversion/TorchToMhlo/Linear.cpp", "lib/Conversion/TorchToMhlo/ViewLikeOps.cpp", "lib/Conversion/TorchToMhlo/ReductionOp.cpp", "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h",