From 2967d3216170f2ddd8e8aa4925464ca88047c931 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 22 Jun 2022 11:36:58 +0800 Subject: [PATCH] fix divide_floor & export promoteTypes api (#9) --- include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h | 1 + lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 5 ++++- lib/Dialect/Torch/Utils/TorchUpstream.cpp | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index 0d2d75b7818f..970608e23486 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -153,6 +153,7 @@ enum MemoryFormat { //===----------------------------------------------------------------------===// enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions }; +ScalarType promoteTypes(ScalarType a, ScalarType b); } // namespace torch_upstream } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6d3dda97651e..9527875e83d3 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2071,8 +2071,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFloorDivideOp op, PatternRewriter &rewriter) const override { + // https://pytorch.org/docs/stable/generated/torch.floor_divide.html + // PyTorch aten.floor_divide is a misnomer because it actually rounds + // the quotient towards zero instead of taking its floor. Value cstStrFloor = - rewriter.create(op.getLoc(), "floor"); + rewriter.create(op.getLoc(), "trunc"); rewriter.replaceOpWithNewOp( op, op.getType(), op.self(), op.other(), /*rounding_mode=*/cstStrFloor); diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 37ffffabd8fd..6cd6f1e1143d 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) { // Type promotion related code are copied from // aten/src/ATen/native/TypeProperties.*. //===----------------------------------------------------------------------===// -static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { +ScalarType promoteTypes(ScalarType a, ScalarType b) { // This is generated according to NumPy's promote_types constexpr auto u1 = ScalarType::Byte; constexpr auto i1 = ScalarType::Char;