[UnitTest] Parametrized test_conv2d_int8_intrinsics (#9143)
Parametrized the test to get more detailed information while debugging
failures in #9091; this change is not semantically part of that PR.
Lunderberg authored Sep 29, 2021
1 parent 285dbd8 commit 0467539
Showing 1 changed file with 106 additions and 119 deletions.
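The change below converts a single test function that looped over targets, channel counts, and dtypes into a test class whose combinations are expanded by TVM's testing helpers, so each combination is collected and reported as a separate test case. As a rough illustration of that pattern, here is a minimal sketch: the class, fixture, and test names are invented for the example, and it assumes `tvm.testing.parameters` and `tvm.testing.fixture` behave as they are used in the diff that follows.

import tvm
import tvm.testing


class TestChannelSweepSketch:
    # Each tuple becomes its own test case; the unpacked names are made
    # available to fixtures and test methods as arguments.
    input_channels, output_channels = tvm.testing.parameters(
        (1, 16),
        (4, 16),
        (8, 20),
    )

    @tvm.testing.fixture
    def total_channels(self, input_channels, output_channels):
        # Computed once per parameter combination and passed to any test
        # method that requests it by name.
        return input_channels + output_channels

    def test_total_is_positive(self, total_channels):
        # Runs once for every (input_channels, output_channels) tuple above.
        assert total_channels > 0

Compared with a single looping test, each failing combination is reported individually instead of stopping the rest of the loop, which appears to be the "more detailed information" the commit message refers to.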
tests/python/relay/test_op_level2.py: 225 changes (106 additions & 119 deletions)
@@ -1587,156 +1587,143 @@ def test_upsampling3d():
_test_upsampling3d("NDHWC", "trilinear", "align_corners")


@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8")
class TestConv2DInt8Intrinsics:
    supported_targets = [
        "llvm -mcpu=nehalem",
        "llvm -mcpu=core-avx2",
        "llvm -mcpu=skylake-avx512",
        "llvm -mcpu=cascadelake",
    ]

    unsupported_targets = [
        "llvm -mcpu=x86-64",
    ]

    data_layout, kernel_layout = tvm.testing.parameters(
        ("NCHW", "OIHW"),
        # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout.
        # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC.
        # ("NHWC", "HWIO"),
    )

    input_channels, output_channels = tvm.testing.parameters(
        # Sweep the input channels to check int8 robustness
        # Input channels should be a multiple of 4 internally.
        (1, 16),
        (4, 16),
        (6, 16),
        # Sweep the output channels to check int8 robustness
        # Output channels should be a multiple of 16 internally.
        (8, 4),
        (8, 16),
        (8, 20),
        # Check that both non-divisible oc and ic work
        (17, 29),
    )

    @tvm.testing.fixture
    def fast_int8_intrinsic(self, target):
        if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target:
            return "pmaddubs"
        elif "cascadelake" in target:
            return "vpdpbusd"
        else:
            assert False, "Target should be Skylake or Cascadelake"

    @tvm.testing.fixture
    def assembly(
        self,
        target,
        dtypes,
        input_channels,
        output_channels,
        data_layout,
        kernel_layout,
    ):
        input_dtype, weight_dtype, output_dtype = dtypes

        image_size = (64, 64)
        kernel_size = (3, 3)
        batch_size = 1

        h, w = image_size

        if data_layout == "NCHW":
            data_shape = (batch_size, input_channels, *image_size)
        elif data_layout == "NHWC":
            data_shape = (batch_size, *image_size, input_channels)
        else:
            raise ValueError(f"Unsupported data layout: {data_layout}")
        x = relay.var("x", relay.TensorType(data_shape, input_dtype))

        if kernel_layout == "OIHW":
            kernel_shape = (output_channels, input_channels, *kernel_size)
        elif kernel_layout == "HWIO":
            kernel_shape = (*kernel_size, input_channels, output_channels)
        else:
            raise ValueError("Not supported")

        weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))

        y = relay.nn.conv2d(
            x,
            weight,
            kernel_size=kernel_size,
            channels=output_channels,
            padding=(0, 0, 0, 1),
            dilation=(1, 1),
            data_layout=data_layout,
            kernel_layout=kernel_layout,
            out_dtype=output_dtype,
        )

        func = relay.Function([x, weight], y)

        wdata = np.random.rand(*kernel_shape) * 10
        parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}

        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(func, target, params=parameters)

        return lib.get_source("asm")

    # Ensure that code uses the fast int8 instructions when available.
    @tvm.testing.parametrize_targets(*supported_targets)
    @pytest.mark.parametrize(
        "dtypes",
        [
            # compile conv2d for x86 (skylake, cascadelake) and test
            # assembly contains *pmadd* instructions
            ("uint8", "int8", "int32"),
            # Check that int8 x int8 goes through legalization so that
            # fast instructions can be picked up.
            ("int8", "int8", "int32"),
        ],
    )
    def test_uses_intrinsic(
        self,
        fast_int8_intrinsic,
        assembly,
    ):
        assert fast_int8_intrinsic in assembly

    # For datatypes that don't have HW support, ensure that code is
    # generated without the fast int8 intrinsic.
    @tvm.testing.parametrize_targets(*supported_targets)
    @pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")])
    def test_no_intrinsic(
        self,
        fast_int8_intrinsic,
        assembly,
    ):
        assert fast_int8_intrinsic not in assembly

    # Check that a vectorized instruction is generated for older Intel
    # generations, because we default to NCHWc layout.
    @tvm.testing.parametrize_targets(*unsupported_targets)
    @pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")])
    def test_uses_vectorized_instruction(self, assembly):
        assert "pmulhw" in assembly and "paddd" in assembly


@tvm.testing.uses_gpu
