From c8350f010e33b1dc482aa42c4c6693dc10594044 Mon Sep 17 00:00:00 2001 From: weitao <1136862851@qq.com> Date: Fri, 27 Oct 2023 15:55:42 +0800 Subject: [PATCH 1/2] [Fix][TIR]fix symbolic strides lower --- src/tir/transforms/ir_utils.cc | 3 +- .../test_tir_transform_lower_opaque_block.py | 31 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 99ed4376590e..25c10dd6828d 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -417,7 +417,8 @@ Array GetBufferAllocationShape(const Buffer& buffer) { if (buffer->strides.size()) { ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { - ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))); + ICHECK( + arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0)); alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } } diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index 444e36bfbb7a..efab334032a9 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -250,6 +250,33 @@ def transformed_strided_buffer_func( C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2) +@T.prim_func +def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: + n = T.int32() + A = T.match_buffer(a, (1, n, 10240)) + padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96)) + # with T.block("root"): + for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): + with T.block(""): + A_pad_shared_dyn = T.alloc_buffer((1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn") + for ax0, ax1 in T.grid(96, 64): + with T.block("A_pad_shared.dyn"): + T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) + A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float32(0)) + + +@T.prim_func +def transformed_symbolic_strided_buffer_func(a: T.handle): + n = T.int32() + A = T.match_buffer(a, (1, n, 10240)) + for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): + A_pad_shared_dyn = T.allocate([1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn") + A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), data=A_pad_shared_dyn, strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn") + for ax0, ax1 in T.grid(96, 64): + if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64: + A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float32(0)) + + @T.prim_func def annotated_loops(a: T.handle) -> None: A = T.match_buffer(a, (16,), "float32") @@ -301,6 +328,10 @@ def test_strided_buffer(): _check(compacted_strided_buffer_func, transformed_strided_buffer_func) +def test_symbolic_strided_buffer(): + _check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func) + + def test_lower_te(): x = te.placeholder((1,)) y = te.compute((1,), lambda i: x[i] + 2) From 9595c68c9ea5be5075499b47cb8f29fe2fc9f37b Mon Sep 17 00:00:00 2001 From: weitao <1136862851@qq.com> Date: Sat, 28 Oct 2023 15:14:31 +0800 Subject: [PATCH 2/2] [Fix][TIR] run the black formatter --- .../test_tir_transform_lower_opaque_block.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index efab334032a9..ae44d2127595 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -258,11 +258,17 @@ def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: # with T.block("root"): for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): with T.block(""): - A_pad_shared_dyn = T.alloc_buffer((1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn") + A_pad_shared_dyn = T.alloc_buffer( + (1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn" + ) for ax0, ax1 in T.grid(96, 64): with T.block("A_pad_shared.dyn"): T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) - A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float32(0)) + A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else( + i * 128 + j * 32 + ax0 < n, + A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], + T.float32(0), + ) @T.prim_func @@ -270,11 +276,22 @@ def transformed_symbolic_strided_buffer_func(a: T.handle): n = T.int32() A = T.match_buffer(a, (1, n, 10240)) for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): - A_pad_shared_dyn = T.allocate([1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn") - A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), data=A_pad_shared_dyn, strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn") + A_pad_shared_dyn = T.allocate( + [1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn" + ) + A_pad_shared_dyn_1 = T.decl_buffer( + (1, T.min((n + 63) // 64 * 64, 96), 64), + data=A_pad_shared_dyn, + strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), + scope="shared.dyn", + ) for ax0, ax1 in T.grid(96, 64): if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64: - A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float32(0)) + A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else( + i * 128 + j * 32 + ax0 < n, + A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], + T.float32(0), + ) @T.prim_func