diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 0d2d75b7818..970608e2348 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -153,6 +153,7 @@ enum MemoryFormat {
 //===----------------------------------------------------------------------===//
 enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
 
+ScalarType promoteTypes(ScalarType a, ScalarType b);
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6d3dda97651..9527875e83d 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2071,8 +2071,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFloorDivideOp op,
                                 PatternRewriter &rewriter) const override {
+    // https://pytorch.org/docs/stable/generated/torch.floor_divide.html
+    // PyTorch aten.floor_divide is a misnomer because it actually rounds
+    // the quotient towards zero instead of taking its floor.
     Value cstStrFloor =
-        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
     rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
         op, op.getType(), op.self(), op.other(),
         /*rounding_mode=*/cstStrFloor);
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 37ffffabd8f..6cd6f1e1143 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) {
 // Type promotion related code are copied from
 // aten/src/ATen/native/TypeProperties.*.
 //===----------------------------------------------------------------------===//
-static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
+ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
   constexpr auto i1 = ScalarType::Char;
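For reviewers: the behavioral point of the DecomposeComplexOps.cpp hunk is that truncation and flooring agree for non-negative quotients but differ once the quotient is negative, which is exactly the case the old "floor" rounding mode got wrong relative to aten.floor_divide's actual round-towards-zero behavior described in the comment above. A minimal standalone C++ sketch of the two rounding modes (not part of the patch, illustration only):

    #include <cmath>
    #include <cstdio>

    int main() {
      double q = -7.0 / 2.0; // quotient = -3.5
      // rounding_mode = "trunc": round towards zero, as the patch now emits.
      std::printf("trunc(-3.5) = %g\n", std::trunc(q)); // prints -3
      // rounding_mode = "floor": round towards negative infinity.
      std::printf("floor(-3.5) = %g\n", std::floor(q)); // prints -4
      return 0;
    }

So after this change, decomposing aten.floor_divide into aten.div.Tensor_mode with rounding_mode = "trunc" yields -3 for -7 / 2, matching PyTorch's documented behavior for floor_divide at the time of this patch, rather than the -4 a true floor would produce.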