Skip to content

Commit

Permalink
[CUDNN] Add partitioning support for conv2d and log_softmax (#10961)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaret authored Apr 12, 2022
1 parent 11d22bd commit 98fc649
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
62 changes: 62 additions & 0 deletions python/tvm/relay/op/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import te
from tvm.relay import transform
from tvm.contrib import cudnn
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
Expand All @@ -50,6 +51,8 @@ def partition_for_cudnn(
tvm.IRModule
The partitioned module.
"""
if params:
mod["main"] = bind_params_by_name(mod["main"], params)

seq = tvm.transform.Sequential(
[
Expand All @@ -71,19 +74,78 @@ def softmax_pattern() -> relay.Pattern:
"""Create pattern for softmax."""
return is_op("nn.softmax")(wildcard())

def log_softmax_pattern() -> relay.Pattern:
"""Create pattern for log_softmax."""
return is_op("nn.log_softmax")(wildcard())

def conv2d_pattern() -> relay.Pattern:
"""Create pattern for conv2d."""
return is_op("nn.conv2d")(wildcard(), wildcard())

def check_softmax(matched: relay.Call) -> bool:
"""Check if softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

return True

def check_log_softmax(matched: relay.Call) -> bool:
"""Check if log_softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

if len(matched.args[0].checked_type.shape) != 2:
return False

if matched.attrs["axis"] not in (1, -1):
return False

return True

def check_conv2d(matched: relay.Call) -> bool:
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW":
return False

padding = matched.attrs["padding"]
if padding[0] != padding[2] or padding[1] != padding[3]:
return False

return True

return [
("cudnn.softmax", softmax_pattern(), check_softmax),
("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
("cudnn.conv2d", conv2d_pattern(), check_conv2d),
]


@lower_composite("cudnn.softmax")
def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a softmax using cuDNN."""
return cudnn.softmax(inputs[0], axis=op.attrs["axis"])


@lower_composite("cudnn.log_softmax")
def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a log_softmax using cuDNN."""
return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])


@lower_composite("cudnn.conv2d")
def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a conv2d using cuDNN."""
return cudnn.conv_forward(
inputs[0],
inputs[1],
pad=op.attrs["padding"],
stride=op.attrs["strides"],
dilation=op.attrs["dilation"],
conv_mode=1,
tensor_format=0,
algo=1,
conv_dtype=op.checked_type.dtype,
groups=op.attrs["groups"],
)
66 changes: 65 additions & 1 deletion tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _verify_cudnn_relay(expr):
tvm.testing.assert_allclose(
outputs[0],
outputs[1],
rtol=1e-3,
rtol=1e-2,
)


Expand Down Expand Up @@ -513,5 +513,69 @@ def test_relay_cudnn_softmax(shape, axis, dtype):
_verify_cudnn_relay(softmax)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"shape,axis",
[
((32, 16), -1),
((13, 27), 1),
],
)
@pytest.mark.parametrize(
"dtype",
[
"float32",
"float16",
"float64",
],
)
def test_relay_cudnn_log_softmax(shape, axis, dtype):
x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
log_softmax = relay.op.nn.log_softmax(x, axis=axis)
_verify_cudnn_relay(log_softmax)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,h,w,ci,co,groups",
[
(1, 16, 20, 8, 16, 1),
(10, 17, 19, 16, 8, 4),
],
)
@pytest.mark.parametrize(
"kh,kw,padding",
[
(1, 1, (3, 1, 3, 1)),
(3, 3, (1, 2)),
(7, 2, (0, 0)),
],
)
@pytest.mark.parametrize(
"strides,dilation,dtype",
[
((1, 1), (1, 1), "float32"),
((2, 1), (2, 2), "float16"),
((3, 3), (1, 2), "float64"),
],
)
def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype):
data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
conv2d = relay.op.nn.conv2d(
data,
weight,
groups=groups,
channels=co,
kernel_size=(kh, kw),
strides=strides,
dilation=dilation,
padding=padding,
data_layout="NCHW",
kernel_layout="OIHW",
)
_verify_cudnn_relay(conv2d)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 98fc649

Please sign in to comment.