diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc index 8693ff5fead..7145935ca9d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc @@ -350,6 +350,57 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion { } }; +// Legalize mhlo.scatter to a lmhlo.scatter +struct HloToLhloScatterOpConverter : public BaseOpConversion { + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + + auto loc = op->getLoc(); + if (!llvm::hasSingleElement(op.getUpdateComputation())) { + return op->emitOpError() + << "tensor to buffer conversion expects a single block " + "in the region containing the operation"; + } + + SmallVector bufferArgs(adaptor.getOperands()); + if (failed(convertResults(op, bufferArgs, rewriter))) return failure(); + auto newOp = rewriter.create>( + loc, std::nullopt, bufferArgs, op->getAttrs()); + + // Copy over the operations inside the region. + rewriter.inlineRegionBefore(op.getUpdateComputation(), newOp.getUpdateComputation(), + newOp.getUpdateComputation().end()); + + // Convert the region signature to memref and add extra result. + auto& entryBlock = newOp.getUpdateComputation().front(); + TypeConverter::SignatureConversion sigConversion( + adaptor.getOperands().size()); + for (auto arg : entryBlock.getArguments()) { + auto oldType = arg.getType().template cast(); + auto newType = + MemRefType::get(oldType.getShape(), oldType.getElementType()); + sigConversion.addInputs(arg.getArgNumber(), newType); + } + + auto returnOp = cast(entryBlock.getTerminator()); + for (auto result : returnOp.getResults()) { + auto resultType = result.getType().template cast(); + sigConversion.addInputs({MemRefType::get(resultType.getShape(), + resultType.getElementType())}); + + } + rewriter.applySignatureConversion(&newOp.getUpdateComputation(), sigConversion); + + rewriter.replaceOp( + op, bufferArgs[adaptor.getOperands().size()]); + + return success(); + } +}; + // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary // buffers if necessary. // @@ -579,7 +630,8 @@ void populateHloToLhloConversionPattern( HloToLhloOpConverter, HloToLhloReduceLikeOpConverter, HloToLhloReduceLikeOpConverter, - HloToLhloReturnOpConverter + HloToLhloReturnOpConverter, + HloToLhloScatterOpConverter >(*converter, context); // clang-format on }