[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo #3348

Merged · 2 commits · May 15, 2024
238 changes: 185 additions & 53 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
+  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
+                                              Type outTy,
+                                              ArrayRef<int64_t> dims,
+                                              PatternRewriter &rewriter) {
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy)
+    return nullptr;
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return nullptr;
+
+  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
+      op->getLoc(), outTy, input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = reduce.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value result;
+    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
+      result = rewriter.create<stablehlo::MaxOp>(
+          op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    } else {
+      op->emitError("unimplemented lowering in "
+                    "createReduceOpWithSingleRegionOp");
+      return nullptr;
+    }
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+  }
+  return reduce.getResults()[0];
+}
+
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static std::optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
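The new helper centralizes single-region reductions: it materializes a stablehlo.reduce over `dims`, seeds it with the identity element from `createInitialValueForReduceOp` (negative infinity for a floating-point max), and fills the reduction body with one binary op. A plain-Python sketch of the semantics it emits (illustrative only, not torch-mlir API):

import functools
import math

def reduce_max_dim(rows: list[list[float]]) -> list[float]:
    # Reduce dim 1 of a 2-D "tensor": stablehlo.reduce folds the region's
    # binary max over the reduced dimension, starting from the init value,
    # which is -inf for a float max reduction.
    return [functools.reduce(max, row, -math.inf) for row in rows]

print(reduce_max_dim([[1.0, 5.0, 3.0], [2.0, -4.0, 0.5]]))  # [5.0, 2.0]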
@@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
-                                            dim, options.dimSizeIndexBits)
-                                    .value();
 
-  if (keepDim) {
-    auto outShapeVec = inputShapeVec;
-    outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(),
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
-    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-        op->getLoc(), outShapeVec);
-
-    auto stablehloReduceValueResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), valResultType, stablehloReduceResults[0],
-            outShapeTensor);
-    auto stablehloReduceIndexResult =
-        rewriter.create<stablehlo::DynamicReshapeOp>(
-            op->getLoc(), idxResultType, stablehloReduceResults[1],
-            outShapeTensor);
-    rewriter.replaceOp(
-        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
-    return success();
-  }
-
-  rewriter.replaceOp(op,
-                     {stablehloReduceResults[0], stablehloReduceResults[1]});
-  return success();
+  if (op.getResult(1).use_empty()) {
+    llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
+    outputShape.erase(outputShape.begin() + dim);
+    Value reduceResult = createReduceOpWithSingleRegionOp(
+        op, input, RankedTensorType::get(outputShape, inputElemTy),
+        ArrayRef<int64_t>{dim}, rewriter);
+    if (!reduceResult)
+      return failure();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, reduceResult, outShapeTensor);
+      rewriter.replaceOp(op, {stablehloReduceValueResult, Value()});
+      return success();
+    }
+    rewriter.replaceOp(op, {reduceResult, Value()});
+    return success();
+  } else {
+    auto stablehloReduceResults =
+        getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                    options.dimSizeIndexBits)
+            .value();
+
+    if (keepDim) {
+      auto outShapeVec = inputShapeVec;
+      outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(),
+          rewriter.getIntegerAttr(
+              rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+      auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+          op->getLoc(), outShapeVec);
+
+      auto stablehloReduceValueResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), valResultType, stablehloReduceResults[0],
+              outShapeTensor);
+      auto stablehloReduceIndexResult =
+          rewriter.create<stablehlo::DynamicReshapeOp>(
+              op->getLoc(), idxResultType, stablehloReduceResults[1],
+              outShapeTensor);
+      rewriter.replaceOp(
+          op, {stablehloReduceValueResult, stablehloReduceIndexResult});
+      return success();
+    }
+    rewriter.replaceOp(op,
+                       {stablehloReduceResults[0], stablehloReduceResults[1]});
+    return success();
+  }
 }
 } // namespace
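The branch on `op.getResult(1).use_empty()` is the core of the `aten.max.dim` change: when nothing consumes the index result, the lowering emits a plain single-region stablehlo.reduce via the new helper and replaces the dead index with `Value()`; otherwise it falls back to the variadic value-plus-index reduction from `getMaxInDim`. At the Torch level the two cases look like this (plain PyTorch, runnable as-is):

import torch

x = torch.randn(3, 4)

# Both results consumed: the lowering must keep the variadic
# value+index reduction built by getMaxInDim.
values, indices = torch.max(x, dim=1)

# Index result never used: after conversion, getResult(1) has no uses,
# so the cheaper single-region stablehlo.reduce path applies.
values_only = torch.max(x, dim=1).values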

@@ -692,11 +761,11 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
 }
 } // namespace
 
-// AtenMaxOp
+// AtenAmaxOp
 namespace {
 template <>
-LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
-    AtenMaxOp op, OpAdaptor adaptor,
+LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite(
+    AtenAmaxOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.getSelf();
   auto inputTy = dyn_cast<RankedTensorType>(input.getType());
@@ -717,40 +786,102 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
             "AtenMaxOp to StableHLO");
   }
 
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  SmallVector<int64_t> inputDims;
   SmallVector<int64_t> dims;
+  if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  for (auto d : inputDims) {
+    d = toPositiveDim(d, inputTy.getRank());
+    // Drop invalid dims
+    if (isValidDim(d, inputTy.getRank())) {
+      dims.push_back(d);
+    }
+  }
+  llvm::sort(dims.begin(), dims.end());
+  std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
+  SmallVector<int64_t> reduceResultShape;
   for (int64_t i = 0; i < inputTy.getRank(); i++) {
-    dims.push_back(i);
+    if (dimsSet.find(i) == dimsSet.end()) {
+      reduceResultShape.push_back(inputTy.getDimSize(i));
+    }
   }
 
-  Value initValue =
-      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue)
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims,
+      rewriter);
+  if (!reduceResult)
     return failure();
-  llvm::sort(dims.begin(), dims.end());
-  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
-      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
-      rewriter.getDenseI64ArrayAttr(dims));
 
-  Block &block = stablehloReduceOp.getBody().emplaceBlock();
-  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+  if (keepDim) {
+    const auto &options = getOptions();
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    auto outShapeVec = *outShapeInfo;
+    auto one = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+    for (int64_t i : dims) {
+      outShapeVec[i] = one;
+    }
+    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), outShapeVec);
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()), reduceResult,
+        outShapeTensor);
+    return success();
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
+// AtenMaxOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
+    AtenMaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in convertion from "
+            "AtenMaxOp to StableHLO");
+  }
 
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
+  SmallVector<int64_t> dims =
+      llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
 
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<stablehlo::MaxOp>(
-        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
-  }
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter);
+  if (!reduceResult)
+    return failure();
+
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()),
-      stablehloReduceOp.getResults());
+      op, getTypeConverter()->convertType(op.getType()), reduceResult);
   return success();
 }
 } // namespace
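For reference, the ops now routed through `createReduceOpWithSingleRegionOp` differ mainly in which dimensions they reduce and what they return (plain PyTorch, runnable as-is):

import torch

x = torch.randn(2, 3, 4)

# aten.amax: values only, over an explicit dim list, optional keepdim.
print(torch.amax(x, dim=(0, 2)).shape)                # torch.Size([3])
print(torch.amax(x, dim=(0, 2), keepdim=True).shape)  # torch.Size([1, 3, 1])

# aten.max (no dim): reduces every dimension to a 0-d tensor.
print(torch.max(x).shape)                             # torch.Size([])

# aten.max.dim: values and indices along a single dimension.
values, indices = torch.max(x, dim=1)
print(values.shape, indices.shape)                    # torch.Size([2, 4]) each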
@@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
2 changes: 1 addition & 1 deletion projects/pt1/python/torch_mlir/torchscript.py
@@ -212,7 +212,7 @@ def _get_for_tracing(
         "aten.adaptive_avg_pool2d",
         "aten.unflatten.int",
     ],
-    OutputType.STABLEHLO: [],
+    OutputType.STABLEHLO: ["aten.amax"],
 }


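Listing "aten.amax" under `OutputType.STABLEHLO` in `_get_for_tracing` marks the op backend-legal when compiling traced modules to StableHLO, so the decomposition passes leave `aten.amax` intact and the new conversion pattern above can match it. A minimal sketch of exercising this path (hypothetical usage; assumes a torch-mlir pt1 build where `torch_mlir.torchscript.compile` accepts `output_type="stablehlo"` and `use_tracing=True`):

import torch
from torch_mlir import torchscript

class Amax(torch.nn.Module):
    def forward(self, x):
        return torch.amax(x, dim=1, keepdim=True)

# With aten.amax backend-legal, the emitted module should contain a single
# stablehlo.reduce with a max region instead of a decomposed max.dim chain.
module = torchscript.compile(
    Amax(), torch.randn(3, 4), output_type="stablehlo", use_tracing=True
)
print(module)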