diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index d6bc0a699172..1a87cdabd458 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
 //===-----------------------------------------------------------------------===//
 enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
+ScalarType promoteTypes(ScalarType a, ScalarType b);
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td
index dd8f54bc5503..f9a8850b4f05 100644
--- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td
+++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index 1c1a2009dbe4..db7155fdd7f0 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -364,13 +364,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase
     target.addIllegalOp();
     patterns.add>(typeConverter, context);
-    target.addIllegalOp();
+    target.addIllegalOp();
     patterns.add>(
         typeConverter, context);
     patterns.add>(
         typeConverter, context);
     patterns.add>(
         typeConverter, context);
+    patterns.add>(
+        typeConverter, context);
+    target.addIllegalOp();
     patterns.add>(
         typeConverter, context);
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 0fafbf1336cd..07f041b20e47 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -29,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
 
 bool skipMultiplyAlpha(Value alphaValue) {
   double doubleValue;
@@ -413,10 +414,8 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern {
       Value dValue = shape[i];
       Value newD;
       int64_t dInt;
-      if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
-        return op->emitError("element of desired shape must be a scalar");
-      }
-      if (i >= leadingRank && dInt == -1) {
+      if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) &&
+          dInt == -1) {
        newD = rewriter.create(op->getLoc(), self,
                               i - leadingRank);
       } else {
@@ -637,7 +636,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       APFloat::getZero(lhsElemTy.cast().getFloatSemantics(),
                        false),
       lhs);
-  rewriter.replaceOpWithNewOp(op, lhs, zeroTensor);
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast();
+
+  rewriter.replaceOpWithNewOp(op, outType, lhs, zeroTensor);
   return success();
 }
 
@@ -665,7 +668,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto erf = rewriter.create(loc, erfElement);
   auto erfAdd = rewriter.create(loc, erf, one);
   auto halfMul = rewriter.create(loc, erfAdd, half);
-  rewriter.replaceOpWithNewOp(op, input, halfMul);
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast();
+
+  rewriter.replaceOpWithNewOp(op, outType, input, halfMul);
   return success();
 }
 } // namespace
@@ -1039,7 +1046,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
+  INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                          \
@@ -1091,14 +1098,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
-  INSERT_ATENOP_PATTERN(AtenContiguousOp);
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);
   INSERT_ATENOP_PATTERN(AtenErfOp);
-  INSERT_ATENOP_PATTERN(AtenCatOp);
-  INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
   INSERT_ATENOP_PATTERN(AtenNumelOp);
diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp
index 2f99e79f73c2..0e9cd655732c 100644
--- a/lib/Conversion/TorchToMhlo/Reduction.cpp
+++ b/lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -369,6 +369,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
       createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
 
   auto mhloReduceOp = rewriter.create(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -426,6 +427,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   Value initValue =
       createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
 
   auto mhloReduceOp = rewriter.create(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -446,7 +448,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
     rewriter.create(op->getLoc(), maxResult);
   }
 
-  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  rewriter.replaceOpWithNewOp(
+      op, getTypeConverter()->convertType(op.getType()),
+      mhloReduceOp.getResults());
   return success();
 }
 } // namespace
@@ -510,6 +514,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
 
   auto mhloReduceOp = rewriter.create(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -551,7 +556,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
         mhloReduceOp.getResult(0), outShapeTensor);
     return success();
   }
-  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  rewriter.replaceOpWithNewOp(
+      op, getTypeConverter()->convertType(op.getType()),
+      mhloReduceOp.getResults());
   return success();
 }
 } // namespace
diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp
index b6cf840be699..1b3ee5d5d74f 100644
--- a/lib/Conversion/TorchToMhlo/ViewLike.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp
@@ -221,12 +221,18 @@ class ConvertAtenViewOp : public OpConversionPattern {
     auto loc = op.getLoc();
     auto newRank = dimSizes.size();
+    auto outTy = OpConversionPattern::getTypeConverter()->convertType(
+        op.getType());
+
     if (newRank == 0 || rankType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp(
-          op,
-          OpConversionPattern::getTypeConverter()->convertType(
-              op.getType()),
+      SmallVector newShape(newRank, 1);
+      Value output = rewriter.create(
+          loc,
+          RankedTensorType::get(
+              newShape,
+              outTy.template dyn_cast().getElementType()),
           adaptor.self());
+      rewriter.replaceOpWithNewOp(op, outTy, output);
       return success();
     }
@@ -250,28 +256,19 @@ class ConvertAtenViewOp : public OpConversionPattern {
     Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
     Value numel = rewriter.create(
         loc, rewriter.getIntegerAttr(intType, 1));
-    for (auto d : dimSizes) {
-      numel = rewriter.create(loc, numel, d);
+    auto rank = rankType.getRank();
+    for (size_t d = 0; d < rank; ++d) {
+      Value dimSize = rewriter.create(
+          loc, intType, rewriter.create(loc, adaptor.self(), d));
+      numel = rewriter.create(loc, numel, dimSize);
     }
     numel = rewriter.create(loc, rewriter.getIndexType(), numel);
-
-    if (dimSizes.size() == 0) {
-      rewriter.replaceOpWithNewOp(
-          op,
-          OpConversionPattern::getTypeConverter()->convertType(
-              op.getType()),
-          adaptor.self());
-      return success();
-    }
 
     Value mhloShape = rewriter.create(loc, dimSizes);
     Value computedShape = rewriter.create(
         loc, mhloShape.getType(), numel, mhloShape);
     rewriter.replaceOpWithNewOp(
-        op,
-        OpConversionPattern::getTypeConverter()->convertType(
-            op.getType()),
-        adaptor.self(), computedShape);
+        op, outTy, adaptor.self(), computedShape);
     return success();
   }
@@ -360,7 +357,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "the size of the dimension being squeezed is can't be unknown");
 
-  rewriter.replaceOp(op, adaptor.self());
+  rewriter.replaceOpWithNewOp(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.self());
   return success();
 }
@@ -402,8 +400,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   if (failed(unsqzTensorInfo))
     return rewriter.notifyMatchFailure(op,
                                        "failed to create unsqueezed tensor");
-
-  rewriter.replaceOp(op, *unsqzTensorInfo);
+  rewriter.replaceOpWithNewOp(
+      op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo);
   return success();
 }
 } // namespace
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index bff1a2e8910c..3c790b58f2b0 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -356,6 +356,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
   if (auto floatType = dtype.dyn_cast()) {
     return dtype;
   } else if (auto integerType = dtype.dyn_cast()) {
+    if (integerType.isUnsignedInteger()) {
+      return IntegerType::get(context, integerType.getWidth(),
+                              IntegerType::Unsigned);
+    }
     return IntegerType::get(context, integerType.getWidth(),
                             IntegerType::Signless);
   }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0f9e5d149b0b..8446eeb87115 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -738,11 +738,14 @@ class DecomposeAtenRollOp : public OpRewritePattern {
     auto self = op.self();
     auto selfTy = self.getType().cast();
     // roll(input, shift, dim) = cat({
-    //   slice(input, dim, -shift, none),
-    //   slice(input, dim, 0, -shift)}, dim)
+    //   slice(input, dim, (dimSize-shift)%dimSize, none),
+    //   slice(input, dim, 0, (dimSize-shift)%dimSize}, dim)
     auto imitateRoll = [&](Value input, Value shift, Value dim,
                            int64_t cstDim) {
-      Value negShift = rewriter.create(loc, shift);
+      Value dimSize = rewriter.create(loc, input, dim);
+      Value shiftPlus = rewriter.create(loc, dimSize, shift);
+      Value splitPos =
+          rewriter.create(loc, shiftPlus, dimSize);
       ArrayRef inputShape = selfTy.getSizes();
       SmallVector sizes;
       sizes.append(inputShape.begin(), inputShape.end());
@@ -750,14 +753,14 @@ class DecomposeAtenRollOp : public OpRewritePattern {
       Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
                                                  selfTy.getDtype());
       Value slice0 = rewriter.create(
-          loc, sliceTy, input, dim, negShift, constNone, constOne);
+          loc, sliceTy, input, dim, splitPos, constNone, constOne);
       Value slice1 = rewriter.create(
-          loc, sliceTy, input, dim, constZero, negShift, constOne);
+          loc, sliceTy, input, dim, constZero, splitPos, constOne);
       Type listType = Torch::ListType::get(sliceTy);
       Value slices = rewriter.create(
          loc, listType, llvm::ArrayRef{slice0, slice1});
-      return rewriter.create(loc, self.getType(), slices, dim);
+      return rewriter.create(loc, op.getType(), slices, dim);
     };
     int rank = getTensorRank(self);
     if (rank < 0)
@@ -1525,6 +1528,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
                                               Value input, Value prob,
                                               Value &output) {
   auto inputType = input.getType().cast();
+  auto inputDtype = inputType.getDtype();
   auto probType = prob.getType().cast();
   // Both the `input` and `prob` must be ranked tensors.
   if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
@@ -1540,8 +1544,14 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
 
   // Since the `aten.rand_like` op expects float-type operand, create a
   // float-type tensor with the same shape as that of the `input`.
+  Type floatDtype = rewriter.getF64Type();
+  if (inputDtype.isa() &&
+      inputDtype.cast().getWidth() < 64) {
+    floatDtype = rewriter.getF32Type();
+  }
+
   Value floatTensor =
-      convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
+      convertTensorToDtype(rewriter, loc, input, floatDtype);
   Value none = rewriter.create(loc);
   Value randomVal = rewriter.create(
       loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
@@ -1555,7 +1565,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
 
   // As the `output` is expected to be of the `input` type, convert the boolean
   // tensor `lessThanP` to a `input` type tensor.
-  output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
+  output = convertTensorToDtype(rewriter, loc, lessThanP, inputDtype);
   return success();
 }
diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
index 29e435b3d10f..a166068a3f3c 100644
--- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
+++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
@@ -294,6 +294,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue,
 namespace {
 class InlineGlobalSlotsPass : public InlineGlobalSlotsBase {
+
   void runOnOperation() override {
     ModuleOp module = getOperation();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index a67f65ec3444..c8049c567a01 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -377,6 +377,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
   using BaseT =
       dataflow::SparseDataFlowAnalysis>;
   using BaseT::SparseDataFlowAnalysis;
+  void setToEntryState(dataflow::Lattice *lattice) override {}
 
   // Compute the knowledge for the results of an op, based on the knowledge of
   // the operands and any information intrinsic to `op`.
@@ -1108,7 +1109,7 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Otherwise, this is an unknown operation. Just mark all results as
   // having reached a pessimistic fixpoint.
-  markAllPessimisticFixpoint(results);
+  setAllToEntryStates(results);
   return;
 }
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 37ffffabd8fd..6cd6f1e1143d 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) {
 // Type promotion related code are copied from
 // aten/src/ATen/native/TypeProperties.*.
 //===----------------------------------------------------------------------===//
-static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
+ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
   constexpr auto i1 = ScalarType::Char;
diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
index 6abdd3c13d5d..61dade9408ac 100644
--- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
+++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
@@ -56,6 +56,18 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
   return success();
 }
 
+void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    MLIRContext *context) {
+  patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) {
+    auto fromBuiltinTensorOp =
+        op.getOperand().getDefiningOp();
+    if (!fromBuiltinTensorOp)
+      return rewriter.notifyMatchFailure(op, "operand not FromBuiltinTensorOp");
+    rewriter.replaceOp(op, fromBuiltinTensorOp.getOperand());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // FromBuiltinTensorOp
 //===----------------------------------------------------------------------===//
@@ -71,6 +83,17 @@ LogicalResult FromBuiltinTensorOp::verify() {
   return success();
 }
 
+void FromBuiltinTensorOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](FromBuiltinTensorOp op, PatternRewriter &rewriter) {
+    auto toBuiltinTensorOp = op.getOperand().getDefiningOp();
+    if (!toBuiltinTensorOp)
+      return rewriter.notifyMatchFailure(op, "operand not ToBuiltinTensorOp");
+    rewriter.replaceOp(op, toBuiltinTensorOp.getOperand());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // FromI64Op
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel
index 6704c54bf35e..d103b25d113b 100644
--- a/utils/bazel/torch-mlir-overlay/BUILD.bazel
+++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -286,6 +286,7 @@ gentbl_cc_library(
         [
             "-gen-pass-decls",
            "-DTORCH_MLIR_ENABLE_MHLO",
+            "-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32",
         ],
         "include/torch-mlir/Conversion/Passes.h.inc",
     ),
@@ -464,6 +465,7 @@ cc_library(
     hdrs = [
         "include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h",
     ],
+    copts = ['-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32'],
     strip_include_prefix = "include",
     deps = [
         ":TorchMLIRConversionPassesIncGen",