diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 907eec381447..6529d70746d2 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -208,7 +208,6 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
           "only floating-point or integer datatype legalization supported");
     }
 
-    Value lhsTensor = lhs;
     if (std::is_same<AtenOpT, AtenSquareOp>()) {
       rhs = lhs;
     } else if (!rhsType) {
@@ -217,8 +216,37 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
-    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
-                                         bcastDimensions);
+    auto loc = op.getLoc();
+    Value result =
+        rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
+
+    if (!isa<AtenDivTensorModeOp>(op)) {
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    AtenDivTensorModeOp divTensorModeOp =
+        llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
+    std::string roundingMode;
+    if (!matchPattern(divTensorModeOp.rounding_mode(),
+                      m_TorchConstantStr(roundingMode)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant str rounding mode");
+
+    if (roundingMode == "trunc") {
+      // "trunc" - rounds the results of the division towards zero. Equivalent
+      // to C-style integer division.
+      auto sign = rewriter.create<mhlo::SignOp>(loc, result);
+      auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
+      auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
+      result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
+    }
+    if (roundingMode == "floor") {
+      // "floor" - rounds the results of the division down. Equivalent to
+      // floor division in Python (the // operator)
+      result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -554,7 +582,6 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
   RankedTensorType outputType = getTypeConverter()
                                     ->convertType(op->getResult(0).getType())
                                     .cast<RankedTensorType>();
-  auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
   Value mhloTensor =
       mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
@@ -968,6 +995,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
+  INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
 #undef INSERT_BINARY_MULDIV_PATTERN
 
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index d7abc82136cc..d020c57baca6 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2167,8 +2167,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFloorDivideOp op,
                                 PatternRewriter &rewriter) const override {
+    // https://pytorch.org/docs/stable/generated/torch.floor_divide.html
+    // PyTorch aten.floor_divide is a misnomer because it actually rounds
+    // the quotient towards zero instead of taking its floor.
     Value cstStrFloor =
-        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
     rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
         op, op.getType(), op.self(), op.other(),
         /*rounding_mode=*/cstStrFloor);
diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir
index 65e462d441bb..77aaea093ad5 100644
--- a/test/Conversion/TorchToMhlo/elementwise.mlir
+++ b/test/Conversion/TorchToMhlo/elementwise.mlir
@@ -540,3 +540,37 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
   return %0 : !torch.vtensor<[?,?],i1>
 }
 
+// CHECK-LABEL:  func.func @torch.aten.div.Tensor_mode$trunc(
+// CHECK-SAME:       %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:        %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:        %[[STR:.*]] = torch.constant.str "trunc"
+// CHECK:        %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:        %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK:        %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK:        %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
+// CHECK:        %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
+// CHECK:        %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:        return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "trunc"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @torch.aten.div.Tensor_mode$floor(
+// CHECK-SAME:       %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:        %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:        %[[STR:.*]] = torch.constant.str "floor"
+// CHECK:        %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:        %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK:        %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:        return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "floor"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index d6f4813fad8a..8c37005db913 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1113,8 +1113,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
 // CHECK-LABEL:   func @torch.aten.floor_divide(
 // CHECK-SAME:        %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
 // CHECK-SAME:        %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:           %[[CSTFLOOR:.*]] = torch.constant.str "floor"
-// CHECK:           %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+// CHECK:           %[[CSTTRUNC:.*]] = torch.constant.str "trunc"
+// CHECK:           %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
 // CHECK:           return %[[OUT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>