From 1b595c0c1e190a3176b1b2df8ae4e276408b6993 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 24 Sep 2021 00:18:59 +0900 Subject: [PATCH] [CUDA] Swap block x and z dimension for conv2d NHWC schedule (#9087) --- python/tvm/topi/cuda/conv2d_nhwc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/cuda/conv2d_nhwc.py index e4361e30b5c3..f8115830ce50 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc.py +++ b/python/tvm/topi/cuda/conv2d_nhwc.py @@ -86,14 +86,14 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): # Schedule for output ni, hi, wi, fi = s[output].op.axis - bz = s[output].fuse(hi, wi) + bx = s[output].fuse(hi, wi) tx, fi = s[output].split(fi, factor=tile_c) txz, tx = s[output].split(tx, factor=num_thread_c) - bx, txz = s[output].split(txz, factor=vthread_c) + bz, txz = s[output].split(txz, factor=vthread_c) ty, ni = s[output].split(ni, factor=tile_n) tyz, ty = s[output].split(ty, factor=num_thread_n) by, tyz = s[output].split(tyz, factor=vthread_n) - s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi) + s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi) s[output].bind(bz, block_z) s[output].bind(by, block_y) s[output].bind(bx, block_x)