diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e86da57fb91578..8f9b21b7ee1e5b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2352,7 +2352,7 @@ struct OperationConverter { LogicalResult legalizeUnresolvedMaterializations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, - std::optional>> &inverseMapping); + DenseMap> &inverseMapping); /// Legalize an operation result that was marked as "erased". LogicalResult @@ -2454,10 +2454,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { - std::optional>> inverseMapping; ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) || - failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, + if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) + return failure(); + DenseMap> inverseMapping = + rewriterImpl.mapping.getInverse(); + if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, inverseMapping))) return failure(); @@ -2483,15 +2485,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { if (result.getType() == newValue.getType()) continue; - // Compute the inverse mapping only if it is really needed. - if (!inverseMapping) - inverseMapping = rewriterImpl.mapping.getInverse(); - // Legalize this result. rewriter.setInsertionPoint(op); if (failed(legalizeChangedResultType( op, result, newValue, opReplacement->getConverter(), rewriter, - rewriterImpl, *inverseMapping))) + rewriterImpl, inverseMapping))) return failure(); } } @@ -2503,6 +2501,8 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( ConversionPatternRewriterImpl &rewriterImpl) { // Functor used to check if all users of a value will be dead after // conversion. + // TODO: This should probably query the inverse mapping, same as in + // `legalizeChangedResultType`. auto findLiveUser = [&](Value val) { auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { return rewriterImpl.isOpIgnored(user); @@ -2796,20 +2796,18 @@ static LogicalResult legalizeUnresolvedMaterialization( LogicalResult OperationConverter::legalizeUnresolvedMaterializations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, - std::optional>> &inverseMapping) { - inverseMapping = rewriterImpl.mapping.getInverse(); - + DenseMap> &inverseMapping) { // As an initial step, compute all of the inserted materializations that we // expect to persist beyond the conversion process. DenseMap materializationOps; SetVector necessaryMaterializations; computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, - *inverseMapping, necessaryMaterializations); + inverseMapping, necessaryMaterializations); // Once computed, legalize any necessary materializations. for (auto *mat : necessaryMaterializations) { if (failed(legalizeUnresolvedMaterialization( - *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping))) + *mat, materializationOps, rewriter, rewriterImpl, inverseMapping))) return failure(); } return success();