Skip to content

Commit

Permalink
[MHLO] fix tensor mode aten.div op pattern
Browse files Browse the repository at this point in the history
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
  • Loading branch information
5 people committed Aug 5, 2022
1 parent 31727f8 commit 2333621
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
35 changes: 31 additions & 4 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
"only floating-point or integer datatype legalization supported");
}

Value lhsTensor = lhs;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
Expand All @@ -217,8 +216,37 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);

if (!isa<AtenDivTensorModeOp>(op)) {
rewriter.replaceOp(op, result);
return success();
}

AtenDivTensorModeOp divTensorModeOp =
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
std::string roundingMode;
if (!matchPattern(divTensorModeOp.rounding_mode(),
m_TorchConstantStr(roundingMode)))
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");

if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
}
rewriter.replaceOp(op, result);
return success();
}
};
Expand Down Expand Up @@ -554,7 +582,6 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
RankedTensorType outputType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto outputShape = outputType.getShape();
auto outputElemType = outputType.getElementType();
Value mhloTensor =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2167,8 +2167,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
PatternRewriter &rewriter) const override {
// https://pytorch.org/docs/stable/generated/torch.floor_divide.html
// PyTorch aten.floor_divide is a misnomer because it actually rounds
// the quotient towards zero instead of taking its floor.
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
op, op.getType(), op.self(), op.other(),
/*rounding_mode=*/cstStrFloor);
Expand Down
4 changes: 2 additions & 2 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1113,8 +1113,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
// CHECK-LABEL: func @torch.aten.floor_divide(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[CSTFLOOR:.*]] = torch.constant.str "floor"
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
// CHECK: %[[CSTTRUNC:.*]] = torch.constant.str "trunc"
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
Expand Down

0 comments on commit 2333621

Please sign in to comment.