Skip to content

Commit

Permalink
support scatter op legalize
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Dec 18, 2023
1 parent ef7d949 commit b5e65ea
Showing 1 changed file with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,57 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
}
};

// Legalize mhlo.scatter to a lmhlo.scatter
struct HloToLhloScatterOpConverter : public BaseOpConversion<mhlo::ScatterOp> {
using BaseOpConversion<mhlo::ScatterOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
mhlo::ScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {

auto loc = op->getLoc();
if (!llvm::hasSingleElement(op.getUpdateComputation())) {
return op->emitOpError()
<< "tensor to buffer conversion expects a single block "
"in the region containing the operation";
}

SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
auto newOp = rewriter.create<mhlo::HloToLhloOp<ScatterOp>>(
loc, std::nullopt, bufferArgs, op->getAttrs());

// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.getUpdateComputation(), newOp.getUpdateComputation(),
newOp.getUpdateComputation().end());

// Convert the region signature to memref and add extra result.
auto& entryBlock = newOp.getUpdateComputation().front();
TypeConverter::SignatureConversion sigConversion(
adaptor.getOperands().size());
for (auto arg : entryBlock.getArguments()) {
auto oldType = arg.getType().template cast<TensorType>();
auto newType =
MemRefType::get(oldType.getShape(), oldType.getElementType());
sigConversion.addInputs(arg.getArgNumber(), newType);
}

auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
for (auto result : returnOp.getResults()) {
auto resultType = result.getType().template cast<TensorType>();
sigConversion.addInputs({MemRefType::get(resultType.getShape(),
resultType.getElementType())});

}
rewriter.applySignatureConversion(&newOp.getUpdateComputation(), sigConversion);

rewriter.replaceOp(
op, bufferArgs[adaptor.getOperands().size()]);

return success();
}
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
Expand Down Expand Up @@ -579,7 +630,8 @@ void populateHloToLhloConversionPattern(
HloToLhloOpConverter<mhlo::DynamicUpdateSliceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
HloToLhloReturnOpConverter
HloToLhloReturnOpConverter,
HloToLhloScatterOpConverter
>(*converter, context);
// clang-format on
}
Expand Down

0 comments on commit b5e65ea

Please sign in to comment.