diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 32ea58fe9f3a..9c3612048098 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -511,7 +511,7 @@ def check(t0, t1, factor): # schedule s = tvm.te.create_schedule(C.op) - ob, ib = s[C].split(s[C].op.axis[0], nparts=128//factor) + ob, ib = s[C].split(s[C].op.axis[0], nparts=n // factor) _, iib = s[C].split(ib, factor=factor) s[C].vectorize(iib) s[C].bind(ob, tx) @@ -538,14 +538,26 @@ def skip(t0, t1): return True return False - types_4 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32", "float64", "int64", "uint64"] + types_4 = [ + "float16", + "float32", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "float64", + "int64", + "uint64", + ] types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"] for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]: check(t0, t1, 4) for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]: check(t0, t1, 8) - check('int8', 'uint8', 16) - check('uint8', 'int8', 16) + check("int8", "uint8", 16) + check("uint8", "int8", 16) def sched(B):