Skip to content

Commit

Permalink
support scatter op legalize
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Dec 18, 2023
1 parent ef7d949 commit f1d36a8
Showing 1 changed file with 54 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ limitations under the License.
// This file implements logic for lowering HLO dialect to LHLO dialect.

#include <algorithm>
#include <array>
#include <memory>
#include <optional>
#include <utility>

Expand All @@ -38,10 +36,10 @@ limitations under the License.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -350,6 +348,57 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
}
};

// Legalize mhlo.scatter to a lmhlo.scatter
struct HloToLhloScatterOpConverter : public BaseOpConversion<mhlo::ScatterOp> {
using BaseOpConversion<mhlo::ScatterOp>::BaseOpConversion;

LogicalResult matchAndRewrite(
mhlo::ScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {

auto loc = op->getLoc();
if (!llvm::hasSingleElement(op.getUpdateComputation())) {
return op->emitOpError()
<< "tensor to buffer conversion expects a single block "
"in the region containing the operation";
}

SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
auto newOp = rewriter.create<mhlo::HloToLhloOp<ScatterOp>>(
loc, std::nullopt, bufferArgs, op->getAttrs());

// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.getUpdateComputation(), newOp.getUpdateComputation(),
newOp.getUpdateComputation().end());

// Convert the region signature to memref and add extra result.
auto& entryBlock = newOp.getUpdateComputation().front();
TypeConverter::SignatureConversion sigConversion(
adaptor.getOperands().size());
for (auto arg : entryBlock.getArguments()) {
auto oldType = arg.getType().template cast<TensorType>();
auto newType =
MemRefType::get(oldType.getShape(), oldType.getElementType());
sigConversion.addInputs(arg.getArgNumber(), newType);
}

auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
for (auto result : returnOp.getResults()) {
auto resultType = result.getType().template cast<TensorType>();
sigConversion.addInputs({MemRefType::get(resultType.getShape(),
resultType.getElementType())});

}
rewriter.applySignatureConversion(&newOp.getUpdateComputation(), sigConversion);

rewriter.replaceOp(
op, bufferArgs[adaptor.getOperands().size()]);

return success();
}
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
Expand Down Expand Up @@ -579,7 +628,8 @@ void populateHloToLhloConversionPattern(
HloToLhloOpConverter<mhlo::DynamicUpdateSliceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
HloToLhloReturnOpConverter
HloToLhloReturnOpConverter,
HloToLhloScatterOpConverter
>(*converter, context);
// clang-format on
}
Expand Down

0 comments on commit f1d36a8

Please sign in to comment.