From 500d11c7e52409fbdaeb53c742795f89ed6134d9 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 12 Aug 2022 11:07:22 +0000 Subject: [PATCH] BladeDISC patch 20230623 * Rewrite mhlo with stablehlo after rebase. * Fix BAZEL building error of multiple definition. * Fix float width * Fix divide_floor & export promoteTypes api (#9) * To comply with the old pytorch versions * Add native_dropout_backward & native_layer_norm_backward decomposition (#15) * Add native_dropout and related ops pattern (#1211) * [MHLO] fix dot general contract * Fix batch_norm, div.Tensor_mode and folder (#21) * Reimplement linear lowering * Reimplement 2-D rhs for mutmul * Add torchdynamo * Decompose torch.slice_scatter (#1622) * Fix i64 torch.tensor dtype * Add more mhlo basic converters * Alleviate softmax datatype check (#24) * Fix decompose native_batch_norm (#27) * Support group_norm lowering (#25) * Decompose torch.ones/zeros (#28) * Fix softmax output type * Fix gather * Fix some decompose patterns * Not check assert at runtime (#31) * Fix bool tensor attr conversion bug (#32) * Fix mlirDenseElementsAttrBoolGet --- .../TorchToStablehlo/StablehloLegalizeUtils.h | 3 +- .../Dialect/Torch/IR/GeneratedTorchOps.td | 96 ++- .../Dialect/Torch/Utils/TorchUpstream.h | 1 + .../TorchConversion/IR/TorchConversionOps.td | 4 + lib/Conversion/TorchToArith/TorchToArith.cpp | 9 +- lib/Conversion/TorchToStablehlo/Basic.cpp | 48 +- lib/Conversion/TorchToStablehlo/Linear.cpp | 117 +++- lib/Conversion/TorchToStablehlo/Pooling.cpp | 7 +- lib/Conversion/TorchToStablehlo/Reduction.cpp | 109 +++- .../StablehloLegalizeUtils.cpp | 9 +- lib/Conversion/TorchToStablehlo/ViewLike.cpp | 30 +- lib/Dialect/Torch/IR/TorchTypes.cpp | 4 + .../Torch/Transforms/DecomposeComplexOps.cpp | 555 ++++++++++++++++-- .../Torch/Transforms/InlineGlobalSlots.cpp | 3 +- lib/Dialect/Torch/Utils/TorchUpstream.cpp | 2 +- .../TorchConversion/IR/TorchConversionOps.cpp | 49 ++ .../jit_ir/build_tools/torch_ods_gen.py | 9 + .../importer/jit_ir/csrc/class_annotator.cpp | 6 +- .../jit_ir/csrc/function_importer.cpp | 2 +- .../importer/jit_ir/csrc/ivalue_importer.cpp | 48 +- .../importer/jit_ir/csrc/node_importer.cpp | 2 +- .../jit_ir/csrc/torch_to_mlir_utils.cpp | 31 +- .../jit_ir/csrc/torch_to_mlir_utils.h | 7 + test/Conversion/TorchToMhlo/dropout.mlir | 47 ++ test/Conversion/TorchToStablehlo/linear.mlir | 57 ++ utils/bazel/torch-mlir-overlay/BUILD.bazel | 15 +- 26 files changed, 1128 insertions(+), 142 deletions(-) create mode 100644 test/Conversion/TorchToMhlo/dropout.mlir diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 6d31d267ac0..e8d57b7f6a7 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -45,7 +45,8 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarValue, Type dtype); -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); +Value promoteType(PatternRewriter &rewriter, Location loc, Value input, + TensorType outType); Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, TensorType outType); diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 3d7bee6d118..cb34784b999 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4929,6 +4929,65 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ }]; } +def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, + Torch_BoolType:$use_input_stats, + Torch_FloatType:$momentum, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + +def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + Torch_IntType:$num_groups, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenGroupNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -7233,9 +7292,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [ } def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ + Pure, AllowsTypeRefinement, HasValueSemantics, - ReadOnly + ReadOnly, ]> { let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins @@ -7742,53 +7802,53 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [ }]; } -def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ +def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> Tensor`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, + AnyTorchOptionalListOfTorchIntType:$dim, Torch_BoolType:$keepdim ); let results = (outs - AnyTorchTensorType:$values, - AnyTorchTensorType:$indices + AnyTorchTensorType:$results ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 2); + ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenMaxDimOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 2); + void AtenAmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ +def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`"; + let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$dim, + Torch_IntType:$dim, Torch_BoolType:$keepdim ); let results = (outs - AnyTorchTensorType:$result + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); } - void AtenAmaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMaxDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); } }]; } diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index efb114fbfa1..6db44fe64ca 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -170,6 +170,7 @@ enum ReductionType { MAX, MEAN, MIN, SUM, PROD }; ReductionType get_reduction_enum(const llvm::StringRef &reduce); +ScalarType promoteTypes(ScalarType a, ScalarType b); } // namespace torch_upstream } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td index a02d7b46b60..cafafff67de 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", let assemblyFormat = [{ $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result)) }]; + let hasCanonicalizer = 1; let hasVerifier = 1; } @@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso let assemblyFormat = [{ $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result)) }]; + let hasCanonicalizer = 1; let hasVerifier = 1; } @@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [ @@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [ diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 9e3cc2f7537..1c393714f27 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -429,13 +429,20 @@ class ConvertTorchToArith : public ConvertTorchToArithBase target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + // target.addIllegalOp(); + // patterns.add>(typeConverter, + // context); + target.addIllegalOp(); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); patterns.add>( typeConverter, context); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 6ed3e5d7dc3..a23c206a128 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -32,6 +32,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::TorchConversion; LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, @@ -148,7 +149,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern { auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - self = hlo::promoteType(rewriter, self, outType); + self = hlo::promoteType(rewriter, op.getLoc(), self, outType); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -253,8 +254,8 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { ->convertType(op.getType()) .template cast(); - lhs = hlo::promoteType(rewriter, lhs, outTy); - rhs = hlo::promoteType(rewriter, rhs, outTy); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -300,8 +301,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } } - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, @@ -354,8 +355,8 @@ class ConvertAtenMulDivOp : public OpConversionPattern { outElemTy); } DenseIntElementsAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -427,7 +428,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { } // TODO: what is the PyTorch default type promotion? - rhs = hlo::promoteType(rewriter, rhs, lhsTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; @@ -494,8 +495,10 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType); - Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType); + Value lhs = + hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType); + Value rhs = + hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, @@ -610,8 +613,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()).cast(); // promote self and other types - self = hlo::promoteType(rewriter, self, outType); - other = hlo::promoteType(rewriter, other, outType); + self = hlo::promoteType(rewriter, op.getLoc(), self, outType); + other = hlo::promoteType(rewriter, op.getLoc(), other, outType); if (failed( broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) @@ -807,8 +810,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } DenseIntElementsAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -1212,7 +1215,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Promote type for (auto &v : builtinTensors) { - v = hlo::promoteType(rewriter, v, outType); + v = hlo::promoteType(rewriter, op.getLoc(), v, outType); } rewriter.replaceOpWithNewOp( @@ -1404,8 +1407,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outTy = this->getTypeConverter()->convertType(op.getType()).cast(); - lhs = hlo::promoteType(rewriter, lhs, outTy); - rhs = hlo::promoteType(rewriter, rhs, outTy); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -1547,12 +1550,10 @@ class ConvertRuntimeAssertOp : public OpConversionPattern { matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool condition; - if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: condition must be a constant"); - } - if (!condition) { - return op->emitError("condition must be true"); + if (matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) { + if (!condition) { + return op->emitError("condition must be true"); + } } rewriter.eraseOp(op); return success(); @@ -1679,7 +1680,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 0786151cb21..0a403da9bb9 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -73,50 +73,57 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, } RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, - Value &lhs, Value &rhs, + Value &lhs, Value &rhs, int64_t nBatchDims, int64_t lhsResultDim, int64_t rhsResultDim, int64_t lhsContractingDim, int64_t rhsContractingDim) { auto lhsTy = lhs.getType().dyn_cast(); auto rhsTy = rhs.getType().dyn_cast(); + bool shouldCastLhs = false; + bool shouldCastRhs = false; auto oldLhsShape = lhsTy.getShape(); auto oldRhsShape = rhsTy.getShape(); SmallVector lhsShape; SmallVector rhsShape; + SmallVector outShape; + lhsShape.append(oldLhsShape.begin(), oldLhsShape.end()); rhsShape.append(oldRhsShape.begin(), oldRhsShape.end()); + // set batch dims + for (auto k = 0; k < nBatchDims; ++k) { + if (lhsShape[k] == ShapedType::kDynamic && rhsShape[k] >= 0) { + lhsShape[k] = rhsShape[k]; + shouldCastLhs = true; + } + if (rhsShape[k] == ShapedType::kDynamic && lhsShape[k] >= 0) { + rhsShape[k] = lhsShape[k]; + shouldCastRhs = true; + } + outShape.push_back(lhsShape[k]); + } + // set contracting dims auto lhsContractingDimSize = lhsShape[lhsContractingDim]; auto rhsContractingDimSize = rhsShape[rhsContractingDim]; if (lhsContractingDimSize != rhsContractingDimSize) { if (lhsContractingDimSize == ShapedType::kDynamic && rhsContractingDimSize >= 0) { lhsShape[lhsContractingDim] = rhsContractingDimSize; - auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); - lhs = rewriter.create(op->getLoc(), newRankTy, lhs); + shouldCastLhs = true; } else if (rhsContractingDimSize == ShapedType::kDynamic && lhsContractingDimSize >= 0) { rhsShape[rhsContractingDim] = lhsContractingDimSize; - auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); - rhs = rewriter.create(op->getLoc(), newRankTy, rhs); + shouldCastRhs = true; } } - SmallVector outShape; - // set batch dims, will skip invalid dimensions - for (int64_t k = 0; k < static_cast(lhsShape.size()); ++k) { - if (k == lhsResultDim || k == lhsContractingDim) - continue; - outShape.push_back(lhsShape[k]); + if (shouldCastLhs) { + auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); + lhs = rewriter.create(op->getLoc(), newRankTy, lhs); } - for (int64_t k = 0, b = 0; k < static_cast(rhsShape.size()); ++k) { - if (b >= static_cast(outShape.size())) - break; - if (k == rhsResultDim || k == rhsContractingDim) - continue; - if (outShape[b] == ShapedType::kDynamic && rhsShape[k] >= 0) { - outShape[b] = rhsShape[k]; - } - b++; + + if (shouldCastRhs) { + auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); + rhs = rewriter.create(op->getLoc(), newRankTy, rhs); } // set result dimensions @@ -226,9 +233,57 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { return success(); } + // For lhsRank > 2 and rhsRrank == 2, use dynamicReshapeOp to convert lhs to + // 2-D Tensor. Then use Dot op to perform matrix * vector/matrix, and + // finnally reshape to right shape. It avoids extra data copy and compute. + + auto loc = op->getLoc(); const auto &options = ConvertAtenOp::getOptions(); + if (rhsRank == 2) { + SmallVector resultDims; + auto dotLhsTy = RankedTensorType::get( + {lhsRank == 2 ? lhsTy.getShape()[0] : ShapedType::kDynamic, + lhsTy.getShape()[lhsRank - 1]}, + lhsElemTy); + Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); + Value numel = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + // reshape lhs to 2-D tensor and record output shape + for (int i = 0; i < lhsRank - 1; ++i) { + // May use create or fold. + Value dimValue = rewriter.create(loc, lhs, i); + resultDims.push_back(dimValue); + numel = rewriter.createOrFold( + loc, numel, + rewriter.create(loc, intType, dimValue)); + } + Value lhsLastRankDim = rewriter.create( + loc, intType, rewriter.create(loc, lhs, lhsRank - 1)); + resultDims.push_back(rewriter.create(loc, rhs, 1)); + Value reshapeDim = + rewriter + .create( + op->getLoc(), ValueRange{numel, lhsLastRankDim}) + .getResult(); + lhs = rewriter.create(loc, dotLhsTy, lhs, + reshapeDim); + auto outTy = + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); + auto dotType = RankedTensorType::get( + {lhsRank == 1 ? lhsTy.getShape()[0] : ShapedType::kDynamic, + rhsTy.getShape()[1]}, + outTy.template cast().getElementType()); + Value matmulOutput = + rewriter.create(loc, dotType, lhs, rhs, nullptr); + // reshape result to output shape + output = rewriter.create( + loc, outTy, matmulOutput, + rewriter.create(loc, resultDims)); + return success(); + } + int64_t nBatchDims; - if (rhsRank <= 2) { + if (rhsRank == 1) { auto leadingRank = lhsRank - 2; getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, options.dimSizeIndexBits); @@ -264,10 +319,10 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); auto outTy = - castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, - lhsContractingDim, rhsContractingDim); + castContractingDim(rewriter, op, lhs, rhs, nBatchDims, lhsResultDim, + rhsResultDim, lhsContractingDim, rhsContractingDim); output = rewriter - .create(op->getLoc(), outTy, lhs, rhs, + .create(loc, outTy, lhs, rhs, dotDimensionNumbers, nullptr) .getResult(); return success(); @@ -379,10 +434,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); - if (lhsRank != 2 && lhsRank != 3) - return op.emitError("aten.Linear called but input rank not 2 or 3"); - if (rhsRank != 2 && rhsRank != 3) - return op.emitError("aten.Linear called but weight rank not 2 or 3"); + if (lhsRank < 1) + return op.emitError("aten.Linear called but input rank 0"); + if (rhsRank != 2) + return op.emitError("aten.Linear called but weight rank not 2"); return success(); } @@ -428,8 +483,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto rhsContractingDim = nBatchDims; auto outTy = - castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, - lhsContractingDim, rhsContractingDim); + castContractingDim(rewriter, op, lhs, rhs, nBatchDims, lhsResultDim, + rhsResultDim, lhsContractingDim, rhsContractingDim); stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), @@ -785,7 +840,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { const auto &options = getOptions(); bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, options.dimSizeIndexBits); - bias = hlo::promoteType(rewriter, bias, outTy); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 4bfe6c6110e..5b5271a2c5e 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -16,13 +16,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include #include @@ -484,7 +484,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); - divisor = hlo::promoteType(rewriter, divisor, outTy); + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); @@ -494,7 +494,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Use another stablehlo.ReduceWindowOp to get the divisor Value windowSizeConst = hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); - windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); + windowSizeConst = + hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); const auto &options = getOptions(); auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index ce0d1f371cb..5fb7e6e6927 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -16,13 +16,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" using namespace mlir; using namespace mlir::torch; @@ -50,7 +50,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getInf( @@ -481,6 +481,109 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenAmaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in STABLEHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + // Currently, (u)int8 dtype is not supported + if (inputElemTy.isa() && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in convertion from " + "AtenSumDimIntListOp to STABLEHLO"); + } + + SmallVector inputDims; + SmallVector dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); + } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } + + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } + } + + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return failure(); + + llvm::sort(dims.begin(), dims.end()); + auto stablehloReduceOp = rewriter.create( + op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + + Region ®ion = stablehloReduceOp.getBody(); + Block &block = region.emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value addResult = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + rewriter.create(op->getLoc(), addResult); + } + + if (keepDim) { + const auto &options = getOptions(); + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + for (int64_t i : dims) { + outShapeVec[i] = one; + } + auto outShapeTensor = rewriter.create( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + stablehloReduceOp.getResult(0), outShapeTensor); + return success(); + } + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + stablehloReduceOp.getResults()); + return success(); +} +} // namespace + // AtenSumDimIntListOp namespace { template <> @@ -588,6 +691,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( stablehloReduceOp.getResult(0), outShapeTensor); return success(); } + rewriter.replaceOpWithNewOp(op, outTy, stablehloReduceOp.getResults()); return success(); @@ -836,6 +940,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 785ae50e6b0..a25a66bbb29 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -185,15 +185,14 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, dtype_tensor); } -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { - Operation *op = input.getDefiningOp(); - TensorType in_type = input.getType().dyn_cast(); +Value promoteType(PatternRewriter &rewriter, Location loc, Value input, + TensorType outType) { + TensorType in_type = input.getType().cast(); if (in_type.getElementType() != outType.getElementType()) { TensorType promotedType = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - return rewriter.create(op->getLoc(), promotedType, - input); + return rewriter.create(loc, promotedType, input); } return input; } diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6c8..fe86f52dbcb 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -22,7 +23,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include using namespace mlir; @@ -179,7 +179,12 @@ class ConvertAtenViewOp : public ConvertAtenOp { auto loc = op.getLoc(); auto newRank = dimSizes.size(); - if (newRank == 0 || rankType.getRank() == 0) { + auto outType = + OpConversionPattern::getTypeConverter()->convertType( + op.getType()); + bool isStaticShape = + outType.template dyn_cast().getNumDynamicDims() == 0; + if ((newRank == 0 || rankType.getRank() == 0) && isStaticShape) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -208,12 +213,15 @@ class ConvertAtenViewOp : public ConvertAtenOp { Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); - for (auto d : dimSizes) { - numel = rewriter.create(loc, numel, d); + auto rank = rankType.getRank(); + for (size_t d = 0; d < rank; ++d) { + Value dimSize = rewriter.create( + loc, intType, + rewriter.create(loc, adaptor.getSelf(), d)); + numel = rewriter.create(loc, numel, dimSize); } numel = rewriter.create(loc, rewriter.getIndexType(), numel); - if (dimSizes.size() == 0) { rewriter.replaceOpWithNewOp( op, @@ -237,7 +245,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &dimSizes) const; -}; +}; // namespace template <> bool ConvertAtenViewOp::getAtenViewOpSizes( @@ -366,7 +374,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "the size of the dimension being squeezed is can't be unknown"); - rewriter.replaceOp(op, adaptor.getSelf()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf()); return success(); } @@ -403,7 +412,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -413,8 +423,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); - - rewriter.replaceOp(op, *unsqzTensorInfo); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo); return success(); } diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 8eb844cbd00..9fae814a1cc 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -402,6 +402,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { if (auto floatType = dtype.dyn_cast()) { return dtype; } else if (auto integerType = dtype.dyn_cast()) { + if (integerType.isUnsignedInteger()) { + return IntegerType::get(context, integerType.getWidth(), + IntegerType::Unsigned); + } return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); } else if (auto complexType = dtype.dyn_cast()) { diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b0dce438e07..7024d74a837 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -418,9 +418,10 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - if (!op.getDtype().getType().isa()) - return rewriter.notifyMatchFailure( - op, "Unimplemented non-None dtype for softmax"); + // Do not need check dtype args here, since dtype have been infered in op.getType() + // if (!op.getDtype().getType().isa()) + // return rewriter.notifyMatchFailure( + // op, "Unimplemented non-None dtype for softmax"); BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) @@ -429,6 +430,12 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { Value result = getSoftmaxResult(op, self, tensorType, rewriter); if (!result) return failure(); + + auto outDtype = op.getType().cast().getDtype(); + if (outDtype != tensorType.getDtype()) { + result = convertTensorToDtype(rewriter, op.getLoc(), result, outDtype); + } + rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); @@ -474,7 +481,13 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); if (!result) return op.emitError("failed to get softmax result"); - rewriter.replaceOpWithNewOp(op, resultTensorType, + + auto outDtype = op.getType().cast().getDtype(); + if (outDtype != tensorType.getDtype()) { + result = convertTensorToDtype(rewriter, op.getLoc(), result, outDtype); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); } @@ -801,19 +814,21 @@ class DecomposeAtenLogSoftmaxIntOp LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - if (!op.getDtype().getType().isa()) - return rewriter.notifyMatchFailure( - op, "Unimplemented non-None dtype for log_softmax"); - BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); - Value logSoftmax = getLogSoftmaxResult(op, rewriter); - if (!logSoftmax) + Value result = getLogSoftmaxResult(op, rewriter); + if (!result) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); - rewriter.replaceOp(op, logSoftmax); + + auto outDtype = op.getType().cast().getDtype(); + if (outDtype != tensorType.getDtype()) { + result = convertTensorToDtype(rewriter, op.getLoc(), result, outDtype); + } + + rewriter.replaceOp(op, result); return success(); } }; @@ -835,11 +850,18 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern { if (halfToFloat) return rewriter.notifyMatchFailure( op, "halfToFloat is currently not supported."); - Value _logSoftmax = getLogSoftmaxResult(op, rewriter); - if (!_logSoftmax) + Value result = getLogSoftmaxResult(op, rewriter); + if (!result) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); - rewriter.replaceOp(op, _logSoftmax); + + BaseTensorType tensorType = op.getSelf().getType().cast(); + auto outDtype = op.getType().cast().getDtype(); + if (outDtype != tensorType.getDtype()) { + result = convertTensorToDtype(rewriter, op.getLoc(), result, outDtype); + } + + rewriter.replaceOp(op, result); return success(); } }; @@ -1141,11 +1163,14 @@ class DecomposeAtenRollOp : public OpRewritePattern { auto self = op.getSelf(); auto selfTy = self.getType().cast(); // roll(input, shift, dim) = cat({ - // slice(input, dim, -shift, none), - // slice(input, dim, 0, -shift)}, dim) + // slice(input, dim, (dimSize-shift)%dimSize, none), + // slice(input, dim, 0, (dimSize-shift)%dimSize}, dim) auto imitateRoll = [&](Value input, Value shift, Value dim, int64_t cstDim) { - Value negShift = rewriter.create(loc, shift); + Value dimSize = rewriter.create(loc, input, dim); + Value shiftPlus = rewriter.create(loc, dimSize, shift); + Value splitPos = + rewriter.create(loc, shiftPlus, dimSize); ArrayRef inputShape = selfTy.getSizes(); SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); @@ -1153,14 +1178,14 @@ class DecomposeAtenRollOp : public OpRewritePattern { Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), selfTy.getOptionalDtype()); Value slice0 = rewriter.create( - loc, sliceTy, input, dim, negShift, constNone, constOne); + loc, sliceTy, input, dim, splitPos, constNone, constOne); Value slice1 = rewriter.create( - loc, sliceTy, input, dim, constZero, negShift, constOne); + loc, sliceTy, input, dim, constZero, splitPos, constOne); Type listType = Torch::ListType::get(sliceTy); Value slices = rewriter.create( loc, listType, llvm::ArrayRef{slice0, slice1}); - return rewriter.create(loc, self.getType(), slices, dim); + return rewriter.create(loc, op.getType(), slices, dim); }; std::optional maybeRank = getTensorRank(self); if (!maybeRank) @@ -2125,6 +2150,67 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { }; } // namespace +// grad_output * mask * scale +namespace { +class DecomposeAtenNativeDropoutBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Value maskedGradOutput = rewriter.create( + loc, op.getType(), op.getGradOutput(), op.getMask()); + rewriter.replaceOpWithNewOp(op, op.getType(), + maskedGradOutput, op.getScale()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenNativeDropoutOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeDropoutOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value input = op.getInput(); + Value prob = op.getP(); + bool train = false; + if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) + return rewriter.notifyMatchFailure(op, "train must be a boolean constant"); + + BaseTensorType inputType = input.getType().cast(); + if (!train) { + // TODO(yancey.yx): supports inference mode + return op.emitError( + "native_dropout does not support argument train is false"); + } + if (!inputType.hasDtype() || !inputType.getDtype().isa()) + return rewriter.notifyMatchFailure( + op, "only support floating type input for training mode"); + Value noneVal = rewriter.create(loc); + Value floatOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = rewriter.create(loc, floatOne, prob); + Value boolMask = rewriter.create( + loc, inputType, input, oneMinusP, /*generator=*/noneVal); + Value maskedInput = + rewriter.create(loc, inputType, boolMask, input); + Value output = + rewriter.create(loc, inputType, maskedInput, oneMinusP); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + boolMask = rewriter.create( + loc, op.getResult(1).getType(), boolMask, one); + rewriter.replaceOp(op, {output, boolMask}); + return success(); + } +}; +} // namespace + // Decompose aten.var into: aten.var.dim op. namespace { class DecomposeAtenVarOp : public OpRewritePattern { @@ -2379,6 +2465,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, Value input, Value prob, Value &output) { auto inputType = input.getType().cast(); + auto inputDtype = inputType.getDtype(); auto probType = prob.getType().cast(); // Both the `input` and `prob` must be ranked tensors. if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() || @@ -2394,8 +2481,14 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, // Since the `aten.randLike` op expects float-type operand, create a // float-type tensor with the same shape as that of the `input`. + Type floatDtype = rewriter.getF64Type(); + if (inputDtype.isa() && + inputDtype.cast().getWidth() < 64) { + floatDtype = rewriter.getF32Type(); + } + Value floatTensor = - convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type()); + convertTensorToDtype(rewriter, loc, input, floatDtype); Value none = rewriter.create(loc); Value randomVal = rewriter.create( loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none, @@ -2409,7 +2502,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, // As the `output` is expected to be of the `input` type, convert the boolean // tensor `lessThanP` to a `input` type tensor. - output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype()); + output = convertTensorToDtype(rewriter, loc, lessThanP, inputDtype); return success(); } @@ -2542,6 +2635,214 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { return success(); } }; + +class DecomposeAtenGroupNormOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenGroupNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + Value input = op.getInput(); + auto inputTy = input.getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + ArrayRef inputSize = inputTy.getSizes(); + int64_t inputRank = inputTy.getSizes().size(); + if (inputRank != 4) { + return rewriter.notifyMatchFailure( + op, "group norm only support 4D input now."); + } + Value num_groups = op.getNumGroups(); + int64_t num_groups_int; + if (!matchPattern(num_groups, m_TorchConstantInt(&num_groups_int))) + return rewriter.notifyMatchFailure( + op, "non const num_groups for AtenGroupNormOp"); + + // reshape input -> [N, G, -1(G//C), H, W] + Value negOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value none = rewriter.create(loc); + SmallVector inputForNormTySize(inputRank + 1, kUnknownSize); + inputForNormTySize[1] = num_groups_int; + Type inputForNormTy = inputTy.getWithSizesAndDtype( + llvm::makeArrayRef(inputForNormTySize), inputTy.getDtype()); + SmallVector orginInputSize; + for (int i = 0; i < inputRank; ++i) { + Value index = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + orginInputSize.push_back( + rewriter.create(loc, input, index)); + } + SmallVector inputForNormSize{orginInputSize.begin(), + orginInputSize.end()}; + inputForNormSize.insert(inputForNormSize.begin() + 1, num_groups); + inputForNormSize[2] = negOne; + Value inputForNormSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), inputForNormSize); + Value reshapedInput = rewriter.create( + loc, inputForNormTy, input, inputForNormSizeList); + // only keep N, G, reduce G//C, H, W + int64_t axis = 2; + std::vector meanVarTySizes(inputForNormTySize.size(), 1); + for (int i = 0; i < axis; i++) + meanVarTySizes[i] = inputForNormTySize[i]; + auto meanVarTy = inputTy.getWithSizesAndDtype( + llvm::makeArrayRef(meanVarTySizes), inputTy.getDtype()); + SmallVector normalizedShapeSize{inputForNormSize.begin() + axis, + inputForNormSize.end()}; + auto normalizedSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), normalizedShapeSize); + + auto nativeLayerNorm = + rewriter + .create( + loc, inputForNormTy, meanVarTy, meanVarTy, reshapedInput, + normalizedSizeList, none, none, op.getEps()) + .getResult(0); + // rehshape back to origin shape + Value inputSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), orginInputSize); + Value originOutput = rewriter.create( + loc, op.getType(), nativeLayerNorm, inputSizeList); + // reshape weight and bias to [1, C, 1, 1] + Value weight = op.getWeight(); + Value bias = op.getBias(); + if (!weight.getType().isa() || + !bias.getType().isa()) { + SmallVector weightsAndBiasSize(inputRank - 1, one); + weightsAndBiasSize[0] = orginInputSize[1]; + + SmallVector weightsAndBiasTySize(inputRank - 1, kUnknownSize); + // weightsAndBiasTySize[1] = kUnknownSize; + + Value weightsAndBiasSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), weightsAndBiasSize); + if (!weight.getType().isa()) { + BaseTensorType weightType = weight.getType().cast(); + Type weightTy = weightType.getWithSizesAndDtype( + llvm::makeArrayRef(weightsAndBiasTySize), weightType.getDtype()); + weight = rewriter.create(loc, weightTy, weight, + weightsAndBiasSizeList); + originOutput = rewriter.create(loc, op.getType(), + originOutput, weight); + } + if (!bias.getType().isa()) { + BaseTensorType biasType = bias.getType().cast(); + Type biasTy = biasType.getWithSizesAndDtype( + llvm::makeArrayRef(weightsAndBiasTySize), biasType.getDtype()); + bias = rewriter.create(loc, biasTy, bias, + weightsAndBiasSizeList); + Value alpha = + rewriter.create(loc, rewriter.getF64FloatAttr(1)); + originOutput = rewriter.create( + loc, op.getType(), originOutput, bias, alpha); + } + } + rewriter.replaceOp(op, {originOutput}); + return success(); + } +}; + +} // namespace + +namespace { +class DecomposeAtenNativeLayerNormBackwardOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeLayerNormBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.getInput().getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputTy.getSizes().size(); + Value normalizedShape = op.getNormalizedShape(); + SmallVector normalizedShapeSizesTorchInt; + getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); + int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); + auto reduceDimInts = llvm::to_vector<4>(llvm::seq(axis, inputRank)); + auto outerDimInts = llvm::to_vector<4>(llvm::seq(0, axis)); + auto reducedTy = op.getResult(1).getType(); + auto sizeListType = ListType::get(IntType::get(context)); + + auto fromIntsToList = [&](ArrayRef dimInts) -> Value { + SmallVector dimVals; + dimVals.reserve(dimInts.size()); + std::transform(dimInts.begin(), dimInts.end(), + std::back_inserter(dimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value dimList = + rewriter.create(loc, sizeListType, dimVals); + return dimList; + }; + // build reduce & outer dims + auto reduceDimList = fromIntsToList(reduceDimInts); + auto outerDimList = fromIntsToList(outerDimInts); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value cstFalse = rewriter.create(loc, false); + Value none = rewriter.create(loc); + + // x_hat + Value inputSubMean = rewriter.create( + loc, inputTy, op.getInput(), op.getMean(), one); + Value xHat = + rewriter.create(loc, inputTy, inputSubMean, op.getRstd()); + + // grad(x_hat) + Value xHatGrad = op.getGradOut(); + Value weight = op.getWeight(); + Value wGrad = none; + if (!weight.getType().isa()) { + xHatGrad = rewriter.create(loc, xHatGrad.getType(), + xHatGrad, weight); + wGrad = rewriter.create( + loc, weight.getType(), + rewriter.create(loc, inputTy, op.getGradOut(), xHat), + outerDimList, cstFalse, none); + } + Value bias = op.getBias(); + Value bGrad = none; + if (!bias.getType().isa()) { + bGrad = rewriter.create( + loc, bias.getType(), op.getGradOut(), outerDimList, cstFalse, none); + } + + Value cstTrue = rewriter.create(loc, true); + // grad(mean) + Value meanGrad = rewriter.create( + loc, op.getMean().getType(), xHatGrad, reduceDimList, cstTrue, none); + // grad(rstd) + Value xHatGradMulXHat = + rewriter.create(loc, inputTy, xHatGrad, xHat); + Value rstdGrad0 = rewriter.create( + loc, op.getRstd().getType(), xHatGradMulXHat, reduceDimList, cstTrue, + none); + Value rstdGrad1 = + rewriter.create(loc, inputTy, xHat, rstdGrad0); + + // grad(input) + Value inner = + rewriter.create(loc, inputTy, xHatGrad, meanGrad, one); + inner = + rewriter.create(loc, inputTy, inner, rstdGrad1, one); + Value gradInput = + rewriter.create(loc, inputTy, op.getRstd(), inner); + + rewriter.replaceOp(op, {gradInput, wGrad, bGrad}); + + return success(); + } +}; } // namespace namespace { @@ -2620,7 +2921,6 @@ class DecomposeAtenNativeLayerNormOp rewriter.create(loc, out.getType(), out, bias, one); } rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); - return success(); } }; @@ -2705,6 +3005,24 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose constant tensor full like ops. +template +class DecomposeConstantTensorAllocOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Type resultDtype = op.getType().template cast().getDtype(); + Value constVal = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), + fillVal, resultDtype); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSize(), constVal, op.getDtype(), op.getLayout(), + op.getDevice(), op.getPinMemory()); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNativeBatchNormOp : public OpRewritePattern { @@ -2762,6 +3080,7 @@ class DecomposeAtenNativeBatchNormOp // to make it broadcast-compatible with (N, C, D?, H?, W?). // 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?) // 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?) + SmallVector runningStatsShape(inputRank, one); runningStatsShape[1] = numFeatures; Value runningStatsSizeList = rewriter.create( @@ -2773,11 +3092,29 @@ class DecomposeAtenNativeBatchNormOp Type reshapeType = ValueTensorType::get( context, llvm::ArrayRef(runningStatsShapeInt), dtype); - runningMean = rewriter.create(loc, reshapeType, runningMean, - runningStatsSizeList); - runningVar = rewriter.create(loc, reshapeType, runningVar, - runningStatsSizeList); - + auto convertRuningStat = [&](Value runningStat) -> Value { + Type runningStatDtype = + runningStat.getType().cast().getDtype(); + runningStat = rewriter.create( + loc, + ValueTensorType::get(context, + llvm::makeArrayRef(runningStatsShapeInt), + runningStatDtype), + runningStat, runningStatsSizeList); + + if (dtype != runningStatDtype) { + Value cstFalse = rewriter.create(loc, false); + Value none = rewriter.create(loc); + return rewriter.create( + loc, reshapeType, runningStat, + getDtypeIntValueForType(rewriter, loc, dtype), + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + } + return runningStat; + }; + runningMean = convertRuningStat(runningMean); + runningVar = convertRuningStat(runningVar); // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)). Value inputSubMean = rewriter.create( loc, input.getType(), input, runningMean, /*alpha=*/one); @@ -2785,7 +3122,7 @@ class DecomposeAtenNativeBatchNormOp loc, runningVar.getType(), runningVar, eps, /*alpha=*/one); Value invStd = rewriter.create(loc, varEps.getType(), varEps); Value normalizedInput = rewriter.create( - loc, inputSubMean.getType(), inputSubMean, invStd); + loc, op.getType(0), inputSubMean, invStd); // The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it // broadcast-compatible with (N, C, D?, H?, W?). @@ -2798,20 +3135,18 @@ class DecomposeAtenNativeBatchNormOp std::optional weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) return rewriter.notifyMatchFailure(op, "expected weight to be rank 1"); - weight = rewriter.create(loc, reshapeType, weight, - runningStatsSizeList); + weight = convertRuningStat(weight); batchNormOutput = rewriter.create( - loc, batchNormOutput.getType(), batchNormOutput, weight); + loc, op.getType(0), batchNormOutput, weight); } if (!bias.getType().isa()) { // Rank of `bias` must be exactly 1. std::optional biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); - bias = rewriter.create(loc, reshapeType, bias, - runningStatsSizeList); + bias = convertRuningStat(bias); batchNormOutput = rewriter.create( - loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one); + loc, op.getType(0), batchNormOutput, bias, /*alpha=*/one); } // The `mean` and `invstd` outputs are empty tensors in inference mode. @@ -2901,6 +3236,24 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose constant tensor like ops. +template +class DecomposeConstantTensorOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Type resultDtype = op.getType().template cast().getDtype(); + Value fillVal = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), + val, resultDtype); + rewriter.replaceOpWithNewOp(op, op.getType(), op.size(), fillVal, + op.getDtype(), op.getLayout(), op.getDevice(), + op.getPinMemory()); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.full` op into `aten.broadcastTo` class DecomposeAtenFullOp : public OpRewritePattern { @@ -3676,6 +4029,128 @@ class DecomposeAtenSelectScatterOp }; } // namespace +namespace { +// def slice_scatter(self, values, dim, start, end, step): +// size = self.size(dim) +// indices = torch.arange(size) +// shift_indices = indices - start +// mask = shift_indices % step == 0 +// start_mask = shift_indices >= 0 +// end_mask = shift_indices < end +// mask = mask * start_mask +// mask = mask * end_mask +// sizes = list(self.size()) +// rank = len(sizes) +// shape = [1] * rank +// shape[dim] = size +// mask = mask.view(shape) +// return torch.where(mask, values, self) +// +class DecomposeAtenSliceScatterOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSliceScatterOp op, + PatternRewriter &rewriter) const override { + int64_t inputRank = *getTensorRank(op.getSelf()); + int64_t dimInt = 0; + if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } else { + return rewriter.notifyMatchFailure(op, "dim must be constant"); + } + + auto getOptionalVal = [&](Value val, Value defVal) -> Value { + if (val.getType().isa()) { + return defVal; + } else { + return val; + } + }; + + Value one = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1)); + Value zero = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(0)); + Value none = rewriter.create(op.getLoc()); + Value dimSize = + rewriter.create(op.getLoc(), op.getSelf(), op.getDim()); + + Value start = getOptionalVal(op.getStart(), zero); + Value end = getOptionalVal(op.getEnd(), dimSize); + Value step = getOptionalVal(op.getStep(), one); + // Step 0. create indices + Type indicesType = ValueTensorType::get( + op.getContext(), ArrayRef{kUnknownSize}, + IntegerType::get(op.getContext(), 64, IntegerType::Signed)); + Value indices = rewriter.create( + op.getLoc(), indicesType, dimSize, none, none, none, none); + + // Step 1. make indices broadcastable to self's shape + SmallVector newIndicesShapeInt(inputRank, 1); + SmallVector newIndicesShape(inputRank, one); + newIndicesShape[dimInt] = dimSize; + newIndicesShapeInt[dimInt] = kUnknownSize; + Value newIndicesSizeList = rewriter.create( + op.getLoc(), ListType::get(IntType::get(op.getContext())), + newIndicesShape); + Type indicesDtype = indices.getType().cast().getDtype(); + Type newIndicesType = ValueTensorType::get( + op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype); + indices = rewriter.create(op.getLoc(), newIndicesType, + indices, newIndicesSizeList); + + // Step 2. calculate scatter indices mask + Type maskType = ValueTensorType::get( + op.getContext(), newIndicesType.cast().getSizes(), + IntegerType::get(op.getContext(), 1)); + auto shiftIndices = rewriter.create( + op.getLoc(), indices.getType(), indices, start, one); + auto stepRemainder = rewriter.create( + op.getLoc(), indices.getType(), shiftIndices, step); + Value mask = rewriter.create(op.getLoc(), maskType, + stepRemainder, zero); + auto maskStart = rewriter.create(op.getLoc(), maskType, + shiftIndices, zero); + auto maskEnd = + rewriter.create(op.getLoc(), maskType, indices, end); + mask = rewriter.create(op.getLoc(), maskType, mask, + maskStart); + mask = rewriter.create(op.getLoc(), maskType, mask, + maskEnd); + + // Step 3. make src broadcastable to self's shape + Value src = op.getSrc(); + BaseTensorType srcTensorType = src.getType().cast(); + if (!srcTensorType.hasSizes()) + return rewriter.notifyMatchFailure(op, "src tensor must have size"); + + ArrayRef srcShape = srcTensorType.getSizes(); + int64_t srcRank = srcShape.size(); + if (srcRank != inputRank) { + if (srcRank + 1 == inputRank) { + SmallVector sizes; + sizes.append(srcShape.begin(), srcShape.end()); + sizes.insert(sizes.begin() + dimInt, 1); + Type srcType = srcTensorType.getWithSizesAndDtype( + llvm::makeArrayRef(sizes), srcTensorType.getDtype()); + src = rewriter.create(op.getLoc(), srcType, src, + op.getDim()); + } else { + return rewriter.notifyMatchFailure(op, "src's rank doesn't match"); + } + } + + // Step 4. replace output = mask? src: self + rewriter.replaceOpWithNewOp(op, op.getType(), mask, + src, op.getSelf()); + return success(); + } +}; +} // namespace + namespace { class DecomposeAten_EmbeddingBagOp : public OpRewritePattern { @@ -4563,6 +5038,7 @@ class DecomposeComplexOpsPass void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + // The strings in the `legalOps` ArrayRef don't exist during the call to the // constructor `DecomposeComplexOpsPass`, so the creation of the // `legalOpsSet` must be delayed to when `runOnOperation` gets called. @@ -4579,6 +5055,10 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeConstantTensorAllocLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorAllocOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorAllocOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4701,6 +5181,11 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c9a..e94f9a774c3 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -126,6 +126,7 @@ class InlineGlobalSlotsAnalysisState : public AnalysisState { return setSafe(false); } + bool isUninitialized() const { return true; } /// This is an optimistic analysis. We start assuming everything is safe. bool isSafe = true; }; @@ -135,7 +136,6 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis { InlineGlobalSlotsAnalysis(DataFlowSolver &solver); LogicalResult initialize(Operation *top) override; LogicalResult visit(ProgramPoint point) override; - private: /// The local transfer function determining the safety of `value`. bool isValueSafeTransferFunction(Value value); @@ -287,6 +287,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue, namespace { class InlineGlobalSlotsPass : public InlineGlobalSlotsBase { + void runOnOperation() override { ModuleOp module = getOperation(); diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 2dce14ef964..384c15a119a 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -28,7 +28,7 @@ static inline bool isQIntType(ScalarType t) { // Type promotion related code are copied from // aten/src/ATen/native/TypeProperties.*. //===----------------------------------------------------------------------===// -static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { +ScalarType promoteTypes(ScalarType a, ScalarType b) { // This is generated according to NumPy's promote_types constexpr auto u1 = ScalarType::Byte; constexpr auto i1 = ScalarType::Char; diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index cb2c5384480..4e8b1ddb57d 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -56,6 +56,18 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes( return success(); } +void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) { + auto fromBuiltinTensorOp = + op.getOperand().getDefiningOp(); + if (!fromBuiltinTensorOp) + return rewriter.notifyMatchFailure(op, "operand not FromBuiltinTensorOp"); + rewriter.replaceOp(op, fromBuiltinTensorOp.getOperand()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // FromBuiltinTensorOp //===----------------------------------------------------------------------===// @@ -71,6 +83,17 @@ LogicalResult FromBuiltinTensorOp::verify() { return success(); } +void FromBuiltinTensorOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](FromBuiltinTensorOp op, PatternRewriter &rewriter) { + auto toBuiltinTensorOp = op.getOperand().getDefiningOp(); + if (!toBuiltinTensorOp) + return rewriter.notifyMatchFailure(op, "operand not ToBuiltinTensorOp"); + rewriter.replaceOp(op, toBuiltinTensorOp.getOperand()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // FromI64Op //===----------------------------------------------------------------------===// @@ -97,6 +120,32 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { } } +//===----------------------------------------------------------------------===// +// FromI1Op +//===----------------------------------------------------------------------===// + +OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// ToI1Op +//===----------------------------------------------------------------------===// + +OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + //===----------------------------------------------------------------------===// // ToF64Op //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 2bc06de1295..9494d6daac7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -390,6 +390,15 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) + emit( + "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" + ) + emit( + "aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)" + ) + emit( + "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" + ) emit( "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" ) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp index b144e946ba5..bcfc5c99a42 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.cpp @@ -8,9 +8,13 @@ //===----------------------------------------------------------------------===// #include "class_annotator.h" - +#include "torch_to_mlir_utils.h" #include +#if TORCH_VERSION_LT(1, 8) +#include "ATen/core/function.h" +#endif + using namespace torch_mlir; //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp index 4a538fbcbd5..b9eb261965b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp @@ -69,7 +69,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( /*userAllowsRefinement=*/false)); }; MlirBlock block = importBlock( - context, torch::jit::toGraphFunction(*function).graph()->block(), + context, getGraphFromFunction(function)->block(), createTerminator, inputTypes, importOptions); mlirRegionAppendOwnedBlock(bodyRegion, block); return func; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp index 75013d5ee9a..edf0f55a361 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp @@ -21,7 +21,11 @@ #include "mlir-c/Diagnostics.h" #include "torch-mlir-c/TorchTypes.h" +#if TORCH_VERSION_LT(1, 15) +// do nothing +#else #include "ATen/native/quantized/PackedParams.h" +#endif using namespace torch_mlir; @@ -48,14 +52,50 @@ using namespace torch_mlir; // which is compatible with the semantics we want (for the subset it doesn't // throw an error on). namespace { +#if TORCH_VERSION_LT(1, 7) +#include "torch/csrc/utils/hash.h" +#endif + +#if TORCH_VERSION_LT(1, 8) +inline size_t IValueHash(const c10::IValue &v) { + using namespace torch; + using namespace c10; + if (v.isNone()) { + return 0; + } else if (v.isInt()) { + return get_hash(v.toInt()); + } else if (v.isBool()) { + return get_hash(v.toBool()); + } else if (v.isDouble()) { + return get_hash(v.toDouble()); + } else if (v.isTensor()) { + // Tensor __hash__ is equivalent to `id()`, so take the pointer value of + // the tensor to emulate it + return get_hash(v.toTensor().unsafeGetTensorImpl()); + } else if (v.isString()) { + return get_hash(v.toStringRef()); + } else if (v.isTuple()) { + return get_hash(v.toTuple()); + } else if (v.isDevice()) { + return get_hash(v.toDevice()); + } else { + return std::hash()( + static_cast(v.internalToPointer())); + } +} +#else +inline size_t IValueHash(const c10::IValue &v) { + return c10::IValue::hash(v); +} +#endif + struct IValueHasher { size_t operator()(const c10::IValue &ivalue) const { if (ivalue.isObject() || ivalue.isList() || ivalue.isGenericDict()) { return std::hash()( static_cast(ivalue.internalToPointer())); } - - return c10::IValue::hash(ivalue); + return IValueHash(ivalue); } }; } // namespace @@ -345,6 +385,9 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { torchMlirTorchNoneTypeGet(context)); return mlirOperationGetResult(operation, 0); } +#if TORCH_VERSION_LT(1, 15) + // do nothing +#else if (ivalue.isCustomClass()) { if (ivalue.type().get() == c10::getCustomClassType>() @@ -365,6 +408,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { return mlirOperationGetResult(operation, 0); } } +#endif std::stringstream msg; msg << "Unsupported ivalue: " << ivalue; throw std::invalid_argument(msg.str()); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 15cffedbe83..e9fb3c0541b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -324,7 +324,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, if (kind == c10::prim::CallFunction) { auto functionType = node->input(0)->type()->cast(); torch::jit::Block *calleeEntryBlock = - torch::jit::toGraphFunction(*functionType->function()).graph()->block(); + getGraphFromFunction(functionType->function())->block(); auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { return getMlirTypeFromTorchType(loc, v->type(), importOptions); }); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index e0420022d58..5e3035bfa7e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -22,6 +22,19 @@ #include "torch-mlir-c/TorchOps.h" #include "torch-mlir-c/TorchTypes.h" +#if TORCH_VERSION_LT(1, 8) +#include "torch/custom_class.h" +#endif + +std::shared_ptr +torch_mlir::getGraphFromFunction(torch::jit::Function *function) { +#if TORCH_VERSION_LT(1, 11) + return function->graph(); +#else + return toGraphFunction(*function).graph(); +#endif +} + using namespace torch_mlir; static MlirType getMlirTypeForTorchScalarTypeRaw(MlirContext context, @@ -162,12 +175,24 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, /*optionalDtype=*/ elementType); } + // Ranked with possibly dynamic dims. auto &symbolicShape = tensorType->symbolic_sizes(); +#if TORCH_VERSION_LT(1, 8) + auto getSymbolicShape = [&](size_t d) { + const auto &dims = symbolicShape.sizes(); + if (!dims) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims).at(d); + }; +#else + auto getSymbolicShape = [&](size_t d) { return symbolicShape[d]; }; +#endif std::vector dims; dims.resize(*sizes.rank()); for (size_t i = 0; i < dims.size(); ++i) { - auto shapeSymbol = symbolicShape[i]; + auto shapeSymbol = getSymbolicShape(i); dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; } @@ -212,6 +237,9 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, return torchMlirTorchTupleTypeGet(context, containedTypes.size(), containedTypes.data()); } +#if TORCH_VERSION_LT(1, 10) +// do nothing +#else case TypeKind::UnionType: { std::vector containedTypes; for (const c10::TypePtr &type : @@ -221,6 +249,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, return torchMlirTorchUnionTypeGet(context, containedTypes.size(), containedTypes.data()); } +#endif case TypeKind::ListType: { return torchMlirTorchListTypeGet(getMlirTypeFromTorchType( loc, torchType->cast()->getElementType(), diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h index 82f39499903..b5eace32b86 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h @@ -22,6 +22,13 @@ namespace torch_mlir { +#define TORCH_VERSION_LT(major, minor) \ + (defined(PYTORCH_MAJOR_VERSION) && defined(PYTORCH_MINOR_VERSION) && \ + (PYTORCH_MAJOR_VERSION == major && PYTORCH_MINOR_VERSION < minor)) + +std::shared_ptr +getGraphFromFunction(torch::jit::Function *function); + /// Thrown on failure when details are in MLIR emitted diagnostics. class mlir_diagnostic_emitted : public std::runtime_error { public: diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir new file mode 100644 index 00000000000..b61a61b3bf8 --- /dev/null +++ b/test/Conversion/TorchToMhlo/dropout.mlir @@ -0,0 +1,47 @@ +// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.native_dropout.train( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: f64) -> (tensor, tensor) { +// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[CST_0:.*]] = arith.constant 1 : index +// CHECK: %[[CST_1:.*]] = arith.constant 0 : index +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64 +// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64 +// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64> +// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor) -> tensor +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor +// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64 +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor +// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor -> tensor<2xindex> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor) -> tensor +// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor -> tensor<2xindex> +// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor -> tensor<2xindex> +// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> +// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor) { +// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex> +// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor +// CHECK: shape.assuming_yield %[[T19]] : tensor +// CHECK: } +// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor +// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor -> tensor<2xindex> +// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor +// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<2xindex>) -> tensor +// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor, tensor) -> tensor +// CHECK: return %[[T24]], %[[T26]] : tensor, tensor +func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) { + %bool_true = torch.constant.bool true + %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> + return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1> +} \ No newline at end of file diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index b9bac97ca6c..d28bb0173c5 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -501,3 +501,60 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7 %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,15,15],f32> return %3 : !torch.vtensor<[1,4,15,15],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.linear( +// CHECK-NOT: mhlo.dynamic_reshape +// CHECK: mhlo.transpose +// CHECK: mhlo.dot +// CHECK: chlo.broadcast_add +func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> { + %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32> + return %1 : !torch.vtensor<[4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.linear$nobias( +// CHECK-NOT: mhlo.dynamic_reshape +// CHECK: mhlo.transpose +// CHECK: mhlo.dot +// CHECK-NOT: chlo.broadcast_add +func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> { + %none = torch.constant.none + %1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32> + return %1 : !torch.vtensor<[4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.linear$dynamic( +// CHECK: mhlo.transpose +// CHECK: arith.muli +// CHECK: arith.muli +// CHECK: tensor.from_elements +// CHECK: mhlo.dynamic_reshape +// CHECK: mhlo.dot +// CHECK: mhlo.dynamic_reshape +// CHECK: chlo.broadcast_add +func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> { + %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32> + return %1 : !torch.vtensor<[?,?,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D( +// CHECK: mhlo.transpose +// CHECK: arith.muli +// CHECK: arith.muli +// CHECK: tensor.from_elements +// CHECK: mhlo.dynamic_reshape +// CHECK: mhlo.dot +// CHECK: mhlo.dynamic_reshape +// CHECK: chlo.broadcast_add +func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> { + %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32> + return %1 : !torch.vtensor<[?,?,?,5],f32> +} \ No newline at end of file diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index abfd3ea613a..37b006f2fab 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -104,6 +104,7 @@ cc_library( deps = [ ":MLIRTorchOpsIncGen", ":MLIRTorchTypesIncGen", + "@llvm-project//mlir:CastInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", @@ -269,6 +270,7 @@ gentbl_cc_library( [ "-gen-pass-decls", "-DTORCH_MLIR_ENABLE_STABLEHLO", + "-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32", ], "include/torch-mlir/Conversion/Passes.h.inc", ), @@ -441,6 +443,7 @@ cc_library( "lib/Conversion/TorchToStablehlo/*.cpp", ]), hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]), + copts = ['-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32'], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -479,10 +482,14 @@ cc_library( cc_library( name = "TorchMLIRTorchConversionPasses", - srcs = glob([ - "lib/Dialect/TorchConversion/Transforms/*.cpp", - "lib/Dialect/TorchConversion/Transforms/*.h", - ]), + srcs = glob( + [ + "lib/Dialect/TorchConversion/Transforms/*.cpp", + "lib/Dialect/TorchConversion/Transforms/*.h", + ], + # Exclude the files belong to other targets. + exclude = ["lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp"], + ), hdrs = glob(["include/torch-mlir/Dialect/TorchConversion/Transforms/*.h"]), strip_include_prefix = "include", deps = [