Skip to content

Commit

Permalink
[MLIR][normalize-memrefs] Non-normalizable operations with identity m…
Browse files Browse the repository at this point in the history
…ap layouts do not block normalization of the entire function

The current approach is convervative in which whenever there is a
non-normalizable operations in a function will the function be labelled
as non-normalizable. It means it requires that all operations must have
MemRefsNormalizable trait.

This patch relaxes the requirement that if the memref map layouts of a
non-normalizable operation are identity, this operation does not block
the normalization of the other operations in the same function.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D125854
  • Loading branch information
tungld authored and bondhugula committed Aug 19, 2022
1 parent e941b03 commit 183c4a3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
23 changes: 16 additions & 7 deletions mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
/// Check whether all the uses of AllocOps, CallOps and function arguments of a
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
/// or ReturnOp. Only if these constraints are satisfied will the function
/// become a candidate for normalization. We follow a conservative approach here
/// wherein even if the non-normalizable memref is not a part of the function's
/// argument or return type, we still label the entire function as
/// non-normalizable. We assume external functions to be normalizable.
/// become a candidate for normalization. When the uses of a memref are
/// non-normalizable and the memref map layout is trivial (identity), we can
/// still label the entire function as normalizable. We assume external
/// functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
// We assume external functions to be normalizable.
if (funcOp.isExternal())
Expand All @@ -157,7 +157,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
if (funcOp
.walk([&](memref::AllocOp allocOp) -> WalkResult {
Value oldMemRef = allocOp.getResult();
if (!isMemRefNormalizable(oldMemRef.getUsers()))
if (!oldMemRef.getType()
.cast<MemRefType>()
.getLayout()
.isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
return WalkResult::advance();
})
Expand All @@ -170,7 +174,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
llvm::seq<unsigned>(0, callOp.getNumResults())) {
Value oldMemRef = callOp.getResult(resIndex);
if (oldMemRef.getType().isa<MemRefType>())
if (!isMemRefNormalizable(oldMemRef.getUsers()))
if (!oldMemRef.getType()
.cast<MemRefType>()
.getLayout()
.isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
}
return WalkResult::advance();
Expand All @@ -181,7 +189,8 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
if (oldMemRef.getType().isa<MemRefType>())
if (!isMemRefNormalizable(oldMemRef.getUsers()))
if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return false;
}

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Transforms/normalize-memrefs-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
return
}

// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
// does not block the normalization of other operations.

// CHECK-LABEL: test_nonnorm_identity_layout
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
%0 = memref.alloc() : memref<1x16x14x14xf32>
"test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
"test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
memref.dealloc %0 : memref<1x16x14x14xf32>

// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
// CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
// CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
return
}

// Test with op_norm, with maps in the operations in the function.

// CHECK-LABEL: test_norm_mix
Expand Down

0 comments on commit 183c4a3

Please sign in to comment.