While running `test_integration_cuda_tensorcore.py`, I got the following error:
@tvm.script.tir class Module: def main(var_A: ty.handle, var_B: ty.handle, var_C: ty.handle) -> None: A = tir.match_buffer(var_A, [512, 512], dtype="float16", elem_offset=0, align=128, offset_factor=1) B = tir.match_buffer(var_B, [512, 512], dtype="float16", elem_offset=0, align=128, offset_factor=1) C = tir.match_buffer(var_C, [512, 512], elem_offset=0, align=128, offset_factor=1) # body with tir.block([], "root"): tir.reads([]) tir.writes([]) C_local = tir.alloc_buffer([512, 512], elem_offset=0, scope="local", align=128, offset_factor=1) A_shared = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1) B_shared = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1) A_shared_wmma_matrix_a = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="wmma.matrix_a", align=128, offset_factor=1) B_shared_wmma_matrix_b = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="wmma.matrix_b", align=128, offset_factor=1) C_local_wmma_accumulator = tir.alloc_buffer([512, 512], elem_offset=0, scope="wmma.accumulator", align=128, offset_factor=1) for i0_0_0_i1_0_0_fused in tir.thread_binding(0, 2, thread = "blockIdx.x"): for i0_0_1_i1_0_1_fused in tir.thread_binding(0, 4, thread = "vthread"): for i0_0_2_i1_0_2_fused in tir.thread_binding(0, 16, thread = "threadIdx.x"): for i0_0_3_init, i1_0_3_init in tir.grid(2, 4): with tir.block([32, 32], "blockized_C_init") as [io_init, jo_init]: tir.bind(io_init, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3_init)) tir.bind(jo_init, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3_init)) tir.reads([]) tir.writes([C_local_wmma_accumulator[(io_init*16):((io_init*16) + 16), (jo_init*16):((jo_init*16) + 16)]]) with tir.block([1, 1], "blockized_C_init") as [i_inito, j_inito]: tir.bind(i_inito, 0) tir.bind(j_inito, 0) tir.reads([]) tir.writes([C_local_wmma_accumulator[(io_init*16):((io_init*16) + 16), (jo_init*16):((jo_init*16) + 16)]]) tir.evaluate(tir.tvm_fill_fragment(C_local_wmma_accumulator.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), tir.float32(0), dtype="handle")) for i2_0_0 in tir.serial(0, 2): for ax0_ax1_fused_0 in tir.serial(0, 16384, annotation = {"loop_type":"lazy_cooperative_fetch"}): for ax0_ax1_fused_1 in tir.vectorized(0, 4): with tir.block([512, 512], "B_shared") as [v0, v1]: tir.bind(v0, ((i2_0_0*256) + tir.floordiv(((ax0_ax1_fused_0*4) + ax0_ax1_fused_1), 256))) tir.bind(v1, ((i0_0_0_i1_0_0_fused*256) + tir.floormod(((ax0_ax1_fused_0*4) + ax0_ax1_fused_1), 256))) tir.reads([B[v0, v1]]) tir.writes([B_shared[v0, v1]]) B_shared[v0, v1] = B[v0, v1] for ax0_ax1_fused_0_1 in tir.serial(0, 32768, annotation = {"loop_type":"lazy_cooperative_fetch"}): for ax0_ax1_fused_1_1 in tir.vectorized(0, 4): with tir.block([512, 512], "A_shared") as [v0_1, v1_1]: tir.bind(v0_1, tir.floordiv(((ax0_ax1_fused_0_1*4) + ax0_ax1_fused_1_1), 256)) tir.bind(v1_1, ((i2_0_0*256) + tir.floormod(((ax0_ax1_fused_0_1*4) + ax0_ax1_fused_1_1), 256))) tir.reads([A[v0_1, v1_1]]) tir.writes([A_shared[v0_1, v1_1]]) A_shared[v0_1, v1_1] = A[v0_1, v1_1] for i2_0_1, i0_0_3, i1_0_3, i2_0_2, i0_0_4, i1_0_4 in tir.grid(8, 2, 4, 2, 1, 1): with tir.block([32, 32], "blockized_B_shared_wmma.matrix_b") as [v0o, v1o]: tir.bind(v0o, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2)) tir.bind(v1o, 
((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3)) tir.reads([B_shared[(v0o*16):((v0o*16) + 16), (v1o*16):((v1o*16) + 16)]]) tir.writes([B_shared_wmma_matrix_b[(v0o*16):((v0o*16) + 16), (v1o*16):((v1o*16) + 16)]]) tir.evaluate(tir.tvm_load_matrix_sync(B_shared_wmma_matrix_b.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(B_shared_wmma_matrix_b[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), B_shared.data, tir.get_elem_offset(B_shared[0, 0], dtype="int32"), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) with tir.block([32, 32], "blockized_A_shared_wmma.matrix_a") as [v0o_1, v1o_1]: tir.bind(v0o_1, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3)) tir.bind(v1o_1, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2)) tir.reads([A_shared[(v0o_1*16):((v0o_1*16) + 16), (v1o_1*16):((v1o_1*16) + 16)]]) tir.writes([A_shared_wmma_matrix_a[(v0o_1*16):((v0o_1*16) + 16), (v1o_1*16):((v1o_1*16) + 16)]]) tir.evaluate(tir.tvm_load_matrix_sync(A_shared_wmma_matrix_a.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(A_shared_wmma_matrix_a[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), A_shared.data, tir.get_elem_offset(A_shared[0, 0], dtype="int32"), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) with tir.block([32, 32, tir.reduce_axis(0, 32)], "blockized_C_update") as [io, jo, ko]: tir.bind(io, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3)) tir.bind(jo, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3)) tir.bind(ko, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2)) tir.reads([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)], A_shared_wmma_matrix_a[(io*16):((io*16) + 16), (ko*16):((ko*16) + 16)], B_shared_wmma_matrix_b[(ko*16):((ko*16) + 16), (jo*16):((jo*16) + 16)]]) tir.writes([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)]]) with tir.block([1, 1, tir.reduce_axis(0, 1)], "blockized_C") as [io_1, jo_1, ko_1]: tir.bind(io_1, 0) tir.bind(jo_1, 0) tir.bind(ko_1, 0) tir.reads([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)], A_shared_wmma_matrix_a[(io*16):((io*16) + 16), (ko*16):((ko*16) + 16)], B_shared_wmma_matrix_b[(ko*16):((ko*16) + 16), (jo*16):((jo*16) + 16)]]) tir.writes([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)]]) tir.evaluate(tir.tvm_mma_sync(C_local_wmma_accumulator.data, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), A_shared_wmma_matrix_a.data, tir.floordiv(tir.get_elem_offset(A_shared_wmma_matrix_a[0, 0], dtype="int32"), 256), B_shared_wmma_matrix_b.data, tir.floordiv(tir.get_elem_offset(B_shared_wmma_matrix_b[0, 0], dtype="int32"), 256), C_local_wmma_accumulator.data, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), dtype="handle")) with tir.block([32, 32], "blockized_C_local_wmma.accumulator") as [v0o_2, v1o_2]: tir.bind(v0o_2, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3)) tir.bind(v1o_2, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3)) tir.reads([C_local_wmma_accumulator[(v0o_2*16):((v0o_2*16) + 16), (v1o_2*16):((v1o_2*16) + 16)]]) 
tir.writes([C_local[(v0o_2*16):((v0o_2*16) + 16), (v1o_2*16):((v1o_2*16) + 16)]]) tir.evaluate(tir.tvm_store_matrix_sync(C_local_wmma_accumulator.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), C_local.data, tir.get_elem_offset(C_local[0, 0], dtype="int32"), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) for ax0, ax1 in tir.grid(32, 64): with tir.block([512, 512], "C_local") as [v0_2, v1_2]: tir.bind(v0_2, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*256) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*32)) + ax0)) tir.bind(v1_2, ((((i0_0_0_i1_0_0_fused*256) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*128)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*64)) + ax1)) tir.reads([C_local[v0_2, v1_2]]) tir.writes([C[v0_2, v1_2]]) C[v0_2, v1_2] = C_local[v0_2, v1_2] Traceback (most recent call last): File "test_integration_cuda_tensorcore.py", line 229, in <module> test_integration_conv2d_nchwc() File "test_integration_cuda_tensorcore.py", line 224, in test_integration_conv2d_nchwc schedule(sch) File "test_integration_cuda_tensorcore.py", line 207, in schedule fused = sch.fuse(*sch.get_loops(w_read)[-6:]) File "/home/zxybazh/tvm-tensorir/python/tvm/tir/schedule/schedule.py", line 412, in fuse return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member File "/home/zxybazh/tvm-tensorir/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__ raise get_last_ffi_error() tvm.tir.schedule.schedule.ScheduleError: ScheduleError: An error occurred in the schedule primitive 'fuse'. The IR is: @tvm.script.tir class Module: def main(var_X: ty.handle, var_W: ty.handle, var_conv2d_nchwc: ty.handle) -> None: X = tir.match_buffer(var_X, [1, 6, 98, 98, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) W = tir.match_buffer(var_W, [12, 6, 3, 3, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) conv2d_nchwc = tir.match_buffer(var_conv2d_nchwc, [1, 12, 96, 96, 16], elem_offset=0, align=128, offset_factor=1) # body with tir.block([], "root"): tir.reads([]) tir.writes([]) conv2d_nchwc_local = tir.alloc_buffer([1, 12, 96, 96, 16], elem_offset=0, scope="local", align=128, offset_factor=1) W_shared = tir.alloc_buffer([12, 6, 3, 3, 16, 16], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1) for i0 in tir.serial(0, 1): for i1_0_i2_0_i3_0_0_i4_0_0_fused in tir.thread_binding(0, 4, thread = "blockIdx.x"): for i1_1_i2_1_i3_0_1_i4_0_1_fused in tir.thread_binding(0, 2, thread = "vthread"): for i1_2_i2_2_i3_0_2_i4_0_2_fused in tir.thread_binding(0, 12, thread = "threadIdx.x"): for i5_0_0, i6_0, i7_0 in tir.grid(2, 1, 3): for ax0, ax1, ax2, ax4, ax5 in tir.grid(12, 3, 3, 16, 16): with tir.block([12, 6, 3, 3, 16, 16], "W_shared") as [v0, v1, v2, v3, v4, v5]: tir.bind(v0, ax0) tir.bind(v1, ((i5_0_0*3) + ax1)) tir.bind(v2, ax2) tir.bind(v3, i7_0) tir.bind(v4, ax4) tir.bind(v5, ax5) tir.reads([W[v0, v1, v2, v3, v4, v5]]) tir.writes([W_shared[v0, v1, v2, v3, v4, v5]]) W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5] for i5_0_1, i6_1, i7_1, i1_3, i2_3, i3_0_3, i4_0_3, i5_0_2, i6_2, i7_2, i1_4, i2_4, i3_0_4, i4_0_4, i3_1, i4_1, i5_1 in tir.grid(1, 3, 1, 1, 2, 1, 1, 3, 1, 1, 6, 1, 6, 1, 16, 16, 16): with tir.block([1, 12, 96, 96, 16, tir.reduce_axis(0, 96), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "conv2d_nchwc") as [n, c0, h, w, c1, rc, rh, rw]: tir.bind(n, 0) tir.bind(c0, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + i1_4)) 
tir.bind(h, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + i2_3)) tir.bind(w, ((i3_0_4*16) + i3_1)) tir.bind(c1, i4_1) tir.bind(rc, (((i5_0_0*48) + (i5_0_2*16)) + i5_1)) tir.bind(rh, i6_1) tir.bind(rw, i7_0) tir.reads([conv2d_nchwc_local[n, c0, h, w, c1], X[n, tir.floordiv(rc, 16), (h + rh), (w + rw), tir.floormod(rc, 16)], W_shared[c0, tir.floordiv(rc, 16), rh, rw, tir.floormod(rc, 16), c1]]) tir.writes([conv2d_nchwc_local[n, c0, h, w, c1]]) with tir.init(): conv2d_nchwc_local[n, c0, h, w, c1] = tir.float32(0) conv2d_nchwc_local[n, c0, h, w, c1] = (conv2d_nchwc_local[n, c0, h, w, c1] + (tir.cast(X[n, tir.floordiv(rc, 16), (h + rh), (w + rw), tir.floormod(rc, 16)], "float32")*tir.cast(W_shared[c0, tir.floordiv(rc, 16), rh, rw, tir.floormod(rc, 16), c1], "float32"))) for ax1_1, ax2_1, ax3, ax4_1 in tir.grid(6, 2, 96, 16): with tir.block([1, 12, 96, 96, 16], "conv2d_nchwc_local") as [v0_1, v1_1, v2_1, v3_1, v4_1]: tir.bind(v0_1, 0) tir.bind(v1_1, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + ax1_1)) tir.bind(v2_1, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + ax2_1)) tir.bind(v3_1, ax3) tir.bind(v4_1, ax4_1) tir.reads([conv2d_nchwc_local[v0_1, v1_1, v2_1, v3_1, v4_1]]) tir.writes([conv2d_nchwc[v0_1, v1_1, v2_1, v3_1, v4_1]]) conv2d_nchwc[v0_1, v1_1, v2_1, v3_1, v4_1] = conv2d_nchwc_local[v0_1, v1_1, v2_1, v3_1, v4_1] Regions of interest: tir.For#0 for (i7_0, 0, 3) { for (ax0, 0, 12) { for (ax1, 0, 3) { for (ax2, 0, 3) { for (ax4, 0, 16) { for (ax5, 0, 16) { block W_shared(iter_var(v0, range(min=0, ext=12)), iter_var(v1, range(min=0, ext=6)), iter_var(v2, range(min=0, ext=3)), iter_var(v3, range(min=0, ext=3)), iter_var(v4, range(min=0, ext=16)), iter_var(v5, range(min=0, ext=16))) { bind(v0, ax0) bind(v1, ((i5_0_0*3) + ax1)) bind(v2, ax2) bind(v3, i7_0) bind(v4, ax4) bind(v5, ax5) reads([W[v0, v1, v2, v3, v4, v5]]) writes([W_shared[v0, v1, v2, v3, v4, v5]]) W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5] } } } } } } for (i5_0_1, 0, 1) { for (i6_1, 0, 3) { for (i7_1, 0, 1) { for (i1_3, 0, 1) { for (i2_3, 0, 2) { for (i3_0_3, 0, 1) { for (i4_0_3, 0, 1) { for (i5_0_2, 0, 3) { for (i6_2, 0, 1) { for (i7_2, 0, 1) { for (i1_4, 0, 6) { for (i2_4, 0, 1) { for (i3_0_4, 0, 6) { for (i4_0_4, 0, 1) { for (i3_1, 0, 16) { for (i4_1, 0, 16) { for (i5_1, 0, 16) { block conv2d_nchwc(iter_var(n, range(min=0, ext=1)), iter_var(c0, range(min=0, ext=12)), iter_var(h, range(min=0, ext=96)), iter_var(w, range(min=0, ext=96)), iter_var(c1, range(min=0, ext=16)), iter_var(rc, range(min=0, ext=96)), iter_var(rh, range(min=0, ext=3)), iter_var(rw, range(min=0, ext=3))) { bind(n, 0) bind(c0, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + i1_4)) bind(h, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + i2_3)) bind(w, ((i3_0_4*16) + i3_1)) bind(c1, i4_1) bind(rc, (((i5_0_0*48) + (i5_0_2*16)) + i5_1)) bind(rh, i6_1) bind(rw, i7_0) reads([conv2d_nchwc_local[n, c0, h, w, c1], X[n, floordiv(rc, 16), (h + rh), (w + rw), floormod(rc, 16)], W_shared[c0, floordiv(rc, 16), rh, rw, floormod(rc, 16), c1]]) writes([conv2d_nchwc_local[n, c0, h, w, c1]]) with init() { conv2d_nchwc_local[n, c0, h, w, c1] = 0f } conv2d_nchwc_local[n, c0, h, w, c1] = (conv2d_nchwc_local[n, c0, h, w, c1] + (float32(X[n, floordiv(rc, 16), (h + rh), (w + rw), floormod(rc, 16)])*float32(W_shared[c0, floordiv(rc, 16), rh, rw, floormod(rc, 16), c1]))) } } } } } } } } } } } } } } } } } } } tir.For#1 for (ax0, 0, 12) { for (ax1, 0, 3) { for (ax2, 0, 3) { for (ax4, 0, 16) { for (ax5, 
0, 16) { block W_shared(iter_var(v0, range(min=0, ext=12)), iter_var(v1, range(min=0, ext=6)), iter_var(v2, range(min=0, ext=3)), iter_var(v3, range(min=0, ext=3)), iter_var(v4, range(min=0, ext=16)), iter_var(v5, range(min=0, ext=16))) { bind(v0, ax0) bind(v1, ((i5_0_0*3) + ax1)) bind(v2, ax2) bind(v3, i7_0) bind(v4, ax4) bind(v5, ax5) reads([W[v0, v1, v2, v3, v4, v5]]) writes([W_shared[v0, v1, v2, v3, v4, v5]]) W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5] } } } } } } Error message: The loops can't be fused because the inner loop tir.For#1 is not the only child of outer loop tir.For#0.
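For context, `fuse` requires the loops being fused to form a single chain: each inner loop has to be the sole statement in the body of the loop above it. In the dump above, `tir.For#0` is the `i7_0` loop, whose body contains both the `W_shared` cooperative-fetch nest and the main reduction nest, so the first fused loop (`tir.For#1`, the `ax0` loop) is not its only child. The sketch below is my own minimal reproduction of that constraint, assuming a recent mainline TVM with TVMScript and the `tvm.tir.Schedule` API; none of the names come from the failing test.

```python
# Minimal sketch (assumption: recent mainline TVM; not code from the test above).
# Loop `i` has two children (the `j` loop and the `k` loop), so fusing `i` with
# `j` is rejected for the same reason as in this issue: `j` is not the only
# child of `i`.
import tvm
from tvm.script import tir as T


@T.prim_func
def two_children(A: T.Buffer((8, 8), "float32"), B: T.Buffer((8, 8), "float32")) -> None:
    for i in range(8):
        for j in range(8):  # first child of loop i
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]
        for k in range(8):  # second child of loop i
            with T.block("scale"):
                vi, vk = T.axis.remap("SS", [i, k])
                B[vi, vk] = B[vi, vk] * T.float32(2)


sch = tvm.tir.Schedule(two_children)
i, j = sch.get_loops(sch.get_block("copy"))
try:
    sch.fuse(i, j)  # raises: j is not the only child of i
except tvm.tir.schedule.ScheduleError as err:
    print(err)
```

If that reading of the IR is right, fusing only the `ax*` loops (i.e. leaving `i7_0` out of the `fuse` call) would avoid the check, but I haven't verified that on this schedule.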
Won't fix