[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo (#3348)
* Do not decompose `aten.amax` on the `stablehlo` backend, because it can
  be lowered directly to `stablehlo.reduce` (sketched below).
* Lower `aten.max.dim` to a `stablehlo.reduce` that applies `max` when
  `AtenMaxDimOp.getIndices()` has no users; this is simpler.
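As a sketch of the intended effect (hand-written MLIR; SSA names, shapes, and the exact printed form are illustrative, not taken from this commit's tests), the direct lowering of `aten.amax` produces a single `stablehlo.reduce` whose region applies `stablehlo.maximum`:

```mlir
// Hypothetical lowering of `torch.aten.amax(%x, dim=[1], keepdim=False)`
// on a 2x3 f32 input. The init value is -inf, as produced by
// createInitialValueForReduceOp for max-like ops.
%init = stablehlo.constant dense<0xFF800000> : tensor<f32>  // -inf
%amax = stablehlo.reduce(%x init: %init) applies stablehlo.maximum
    across dimensions = [1]
    : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
```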
qingyunqu authored May 15, 2024
1 parent 6b95dd4 commit 5928f68
Showing 2 changed files with 186 additions and 54 deletions.
238 changes: 185 additions & 53 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
+  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
+                                              Type outTy,
+                                              ArrayRef<int64_t> dims,
+                                              PatternRewriter &rewriter) {
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy)
+    return nullptr;
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return nullptr;
+
+  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
+      op->getLoc(), outTy, input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = reduce.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result;
+    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
+      result = rewriter.create<stablehlo::MaxOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else {
+      op->emitError("unimplemented lowering in "
+                    "createReduceOpWithSingleRegionOp");
+      return nullptr;
+    }
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+  }
+  return reduce.getResults()[0];
+}
+
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static std::optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
@@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
-
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-
-    auto stablehloReduceValueResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), valResultType, stablehloReduceResults[0],
-            outShapeTensor);
-    auto stablehloReduceIndexResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), idxResultType, stablehloReduceResults[1],
-            outShapeTensor);
-    rewriter.replaceOp(
-        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
-    return success();
-  }
-
-  rewriter.replaceOp(op,
-                     {stablehloReduceResults[0], stablehloReduceResults[1]});
-  return success();
+  if (op.getResult(1).use_empty()) {
+    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+    outputShape.erase(outputShape.begin() + dim);
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(outputShape, inputElemTy),
+        ArrayRef<int64_t>{dim}, rewriter);
+    if (!reduceResult)
+      return failure();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, reduceResult, outShapeTensor);
+      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
+      return success();
+    }
+    rewriter.replaceOp(op, {reduceResult, Value()});
+    return success();
+  } else {
+    auto stablehloReduceResults =
+        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                    options.dimSizeIndexBits)
+            .value();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, stablehloReduceResults[0],
+              outShapeTensor);
+      auto stablehloReduceIndexResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), idxResultType, stablehloReduceResults[1],
+              outShapeTensor);
+      rewriter.replaceOp(
+          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+      return success();
+    }
+    rewriter.replaceOp(op,
+                       {stablehloReduceResults[0], stablehloReduceResults[1]});
+    return success();
+  }
 }
 } // namespace

Expand Down Expand Up @@ -692,11 +761,11 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
}
} // namespace

// AtenMaxOp
// AtenAmaxOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
AtenMaxOp op, OpAdaptor adaptor,
LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
AtenAmaxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
@@ -717,40 +786,102 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
             "AtenMaxOp to StableHLO");
   }
 
-  SmallVector<int64_t> dims;
-  for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
-  }
-
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
-    return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
-
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
-
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<stablehlo::MaxOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  SmallVector<int64_t> inputDims;
+  SmallVector<int64_t> dims;
+  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    // Drop invalid dims
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputTy.getDimSize(i));
+    }
+  }
+
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+      rewriter);
+  if (!reduceResult)
+    return failure();
+
+  if (keepDim) {
+    const auto &options = getOptions();
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto outShapeVec = *outShapeInfo;
+    auto one = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    for (int64_t i : dims) {
+      outShapeVec[i] = one;
+    }
+    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), reduceResult,
+        outShapeTensor);
+    return success();
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
+// AtenMaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
+    AtenMaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in convertion from "
+            "AtenMaxOp to StableHLO");
+  }
+
+  SmallVector<int64_t> dims =
+      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
+
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
+  if (!reduceResult)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()), reduceResult);
   return success();
 }
 } // namespace
@@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
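To make the `AtenMaxDimOp` fast path above concrete, here is a hedged, hand-written MLIR sketch (shapes and SSA names are illustrative, not from this commit's tests): when `AtenMaxDimOp.getIndices()` has no users, `createReduceOpWithSingleRegionOp` emits a single value-only reduction, skipping the paired value/index reduction that `getMaxInDim` builds.

```mlir
// Values-only `torch.aten.max.dim` (dim = 1, keepdim = false) on a 4x8 f32
// tensor; no index computation is materialized because the indices are dead.
%init = stablehlo.constant dense<0xFF800000> : tensor<f32>  // -inf
%values = stablehlo.reduce(%input init: %init) applies stablehlo.maximum
    across dimensions = [1]
    : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
```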
2 changes: 1 addition & 1 deletion projects/pt1/python/torch_mlir/torchscript.py
@@ -212,7 +212,7 @@ def _get_for_tracing(
         "aten.adaptive_avg_pool2d",
         "aten.unflatten.int",
     ],
-    OutputType.STABLEHLO: [],
+    OutputType.STABLEHLO: ["aten.amax"],
 }
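The Python-side change marks `aten.amax` as backend-legal for the STABLEHLO output type, so the decompose pass leaves the op intact for the new conversion pattern to consume. A hand-written sketch of the torch-dialect IR that now reaches TorchToStablehlo (types and names are illustrative assumptions):

```mlir
// `torch.aten.amax` survives decomposition instead of being expanded
// into other ops before conversion.
%int1 = torch.constant.int 1
%false = torch.constant.bool false
%dims = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%amax = torch.aten.amax %x, %dims, %false
    : !torch.vtensor<[2,3],f32>, !torch.list<int>, !torch.bool
    -> !torch.vtensor<[2],f32>
```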


