-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[memref] Handle edge case in subview of full static size fold #105635
Conversation
It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` To: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` Which drops the dynamic offset from the `subview` op.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Benjamin Maxwell (MacDue) ChangesIt is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
: memref<16x4xf32, strided<[4, 1], offset: ?>>
to memref<16x4xf32, strided<[4, 1], offset: ?>>
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} To: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} Which drops the dynamic offset from the Full diff: https://github.com/llvm/llvm-project/pull/105635.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97bd..f0d41754001400 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;
+
+ /// Returns true if this layout is static, i.e. the strides and offset all
+ /// have a known value > 0.
+ bool hasStaticLayout() const;
}];
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 150049e5c5effe..9c021d3613f1c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
- auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
- auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
-
- if (resultShapedType.hasStaticShape() &&
- resultShapedType == sourceShapedType) {
+ MemRefType sourceMemrefType = getSource().getType();
+ MemRefType resultMemrefType = getResult().getType();
+ auto resultLayout =
+ dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
+
+ if (resultMemrefType == sourceMemrefType &&
+ resultMemrefType.hasStaticShape() &&
+ (!resultLayout || resultLayout.hasStaticLayout())) {
return getViewSource();
}
@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
- resultShapedType == sourceShapedType)
+ resultMemrefType == sourceMemrefType)
return getViewSource();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
os << ">";
}
+/// Returns true if this layout is static, i.e. the strides and offset all have
+/// a known value > 0.
+bool StridedLayoutAttr::hasStaticLayout() const {
+ return !ShapedType::isDynamic(getOffset()) &&
+ !ShapedType::isDynamicShape(getStrides());
+}
+
/// Returns the strided layout as an affine map.
AffineMap StridedLayoutAttr::getAffineMap() const {
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b15af9baca7dc7..02110bc2892d05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @negative_subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
+// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
+func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
+ %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
+ return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
+}
+
+// -----
+
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
|
@llvm/pr-subscribers-mlir-core Author: Benjamin Maxwell (MacDue) ChangesIt is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
: memref<16x4xf32, strided<[4, 1], offset: ?>>
to memref<16x4xf32, strided<[4, 1], offset: ?>>
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} To: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} Which drops the dynamic offset from the Full diff: https://github.com/llvm/llvm-project/pull/105635.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97bd..f0d41754001400 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;
+
+ /// Returns true if this layout is static, i.e. the strides and offset all
+ /// have a known value > 0.
+ bool hasStaticLayout() const;
}];
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 150049e5c5effe..9c021d3613f1c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
- auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
- auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
-
- if (resultShapedType.hasStaticShape() &&
- resultShapedType == sourceShapedType) {
+ MemRefType sourceMemrefType = getSource().getType();
+ MemRefType resultMemrefType = getResult().getType();
+ auto resultLayout =
+ dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
+
+ if (resultMemrefType == sourceMemrefType &&
+ resultMemrefType.hasStaticShape() &&
+ (!resultLayout || resultLayout.hasStaticLayout())) {
return getViewSource();
}
@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
- resultShapedType == sourceShapedType)
+ resultMemrefType == sourceMemrefType)
return getViewSource();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
os << ">";
}
+/// Returns true if this layout is static, i.e. the strides and offset all have
+/// a known value > 0.
+bool StridedLayoutAttr::hasStaticLayout() const {
+ return !ShapedType::isDynamic(getOffset()) &&
+ !ShapedType::isDynamicShape(getStrides());
+}
+
/// Returns the strided layout as an affine map.
AffineMap StridedLayoutAttr::getAffineMap() const {
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b15af9baca7dc7..02110bc2892d05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @negative_subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
+// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
+func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
+ %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
+ return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
+}
+
+// -----
+
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
|
@llvm/pr-subscribers-mlir-ods Author: Benjamin Maxwell (MacDue) ChangesIt is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
: memref<16x4xf32, strided<[4, 1], offset: ?>>
to memref<16x4xf32, strided<[4, 1], offset: ?>>
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} To: func.func @<!-- -->subview_of_static_full_size(
%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
-> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
} Which drops the dynamic offset from the Full diff: https://github.com/llvm/llvm-project/pull/105635.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97bd..f0d41754001400 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;
+
+ /// Returns true if this layout is static, i.e. the strides and offset all
+ /// have a known value > 0.
+ bool hasStaticLayout() const;
}];
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 150049e5c5effe..9c021d3613f1c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
- auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
- auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
-
- if (resultShapedType.hasStaticShape() &&
- resultShapedType == sourceShapedType) {
+ MemRefType sourceMemrefType = getSource().getType();
+ MemRefType resultMemrefType = getResult().getType();
+ auto resultLayout =
+ dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
+
+ if (resultMemrefType == sourceMemrefType &&
+ resultMemrefType.hasStaticShape() &&
+ (!resultLayout || resultLayout.hasStaticLayout())) {
return getViewSource();
}
@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
- resultShapedType == sourceShapedType)
+ resultMemrefType == sourceMemrefType)
return getViewSource();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
os << ">";
}
+/// Returns true if this layout is static, i.e. the strides and offset all have
+/// a known value > 0.
+bool StridedLayoutAttr::hasStaticLayout() const {
+ return !ShapedType::isDynamic(getOffset()) &&
+ !ShapedType::isDynamicShape(getStrides());
+}
+
/// Returns the strided layout as an affine map.
AffineMap StridedLayoutAttr::getAffineMap() const {
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b15af9baca7dc7..02110bc2892d05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @negative_subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
+// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
+// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
+func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
+ %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
+ return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
+}
+
+// -----
+
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch, LGTM cheers
…05635) It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` To: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` Which drops the dynamic offset from the `subview` op.
…05635) It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` To: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` Which drops the dynamic offset from the `subview` op.
It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds:
To:
Which drops the dynamic offset from the
subview
op.