diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 0d2d75b7818f..970608e23486 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -153,6 +153,7 @@ enum MemoryFormat {
 //===----------------------------------------------------------------------===//
 enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
 
+ScalarType promoteTypes(ScalarType a, ScalarType b);
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e3e5236cfc1b..bc427bf347ac 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2068,8 +2068,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFloorDivideOp op,
                                 PatternRewriter &rewriter) const override {
+    // https://pytorch.org/docs/stable/generated/torch.floor_divide.html
+    // PyTorch aten.floor_divide is a misnomer because it actually rounds
+    // the quotient towards zero instead of taking its floor.
     Value cstStrFloor =
-        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
     rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
         op, op.getType(), op.self(), op.other(),
         /*rounding_mode=*/cstStrFloor);
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 37ffffabd8fd..6cd6f1e1143d 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) {
 // Type promotion related code are copied from
 // aten/src/ATen/native/TypeProperties.*.
 //===----------------------------------------------------------------------===//
-static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
+ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
   constexpr auto i1 = ScalarType::Char;
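
Note on the TorchUpstream change: the `static inline` qualifier has to go once promoteTypes is declared in the header, because a function with internal linkage cannot be referenced from other translation units through that declaration. A minimal standalone sketch of the failure mode (hypothetical names, not from the patch):

    // util.h
    int promote(int a, int b);          // external declaration for other TUs

    // util.cpp
    static int promote(int a, int b) {  // internal linkage: callers in other
      return a > b ? a : b;             // TUs that include util.h get an
    }                                   // undefined-reference link error;
                                        // dropping 'static' fixes it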
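Note on the rounding-mode change: the difference between the two modes only shows up when the quotient is negative. Below is a minimal standalone C++ sketch (not part of the patch) contrasting "floor" rounding (toward negative infinity) with the "trunc" rounding (toward zero) that aten.floor_divide actually performs per the comment above:

    #include <cmath>
    #include <cstdio>

    int main() {
      double q = -7.0 / 2.0;  // quotient = -3.5
      // "floor" rounding mode: round toward negative infinity -> -4
      std::printf("floor: %g\n", std::floor(q));
      // "trunc" rounding mode: round toward zero -> -3; this is the
      // observed behavior of aten.floor_divide, hence the fix
      std::printf("trunc: %g\n", std::trunc(q));
      return 0;
    }

For a nonnegative quotient the two modes agree, which is why the mismatch is only visible on operands of opposite sign.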