diff --git a/python/tvm/relay/op/contrib/cudnn.py b/python/tvm/relay/op/contrib/cudnn.py index 591178e6f882..9714a0b87dcf 100644 --- a/python/tvm/relay/op/contrib/cudnn.py +++ b/python/tvm/relay/op/contrib/cudnn.py @@ -24,6 +24,7 @@ from tvm import te from tvm.relay import transform from tvm.contrib import cudnn +from tvm.relay.build_module import bind_params_by_name from ...dataflow_pattern import is_op, wildcard from .te_target import lower_composite, relay_to_runtime @@ -50,6 +51,8 @@ def partition_for_cudnn( tvm.IRModule The partitioned module. """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) seq = tvm.transform.Sequential( [ @@ -71,6 +74,14 @@ def softmax_pattern() -> relay.Pattern: """Create pattern for softmax.""" return is_op("nn.softmax")(wildcard()) + def log_softmax_pattern() -> relay.Pattern: + """Create pattern for log_softmax.""" + return is_op("nn.log_softmax")(wildcard()) + + def conv2d_pattern() -> relay.Pattern: + """Create pattern for conv2d.""" + return is_op("nn.conv2d")(wildcard(), wildcard()) + def check_softmax(matched: relay.Call) -> bool: """Check if softmax is supported by cuDNN.""" if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]: @@ -78,8 +89,36 @@ def check_softmax(matched: relay.Call) -> bool: return True + def check_log_softmax(matched: relay.Call) -> bool: + """Check if log_softmax is supported by cuDNN.""" + if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]: + return False + + if len(matched.args[0].checked_type.shape) != 2: + return False + + if matched.attrs["axis"] not in (1, -1): + return False + + return True + + def check_conv2d(matched: relay.Call) -> bool: + if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]: + return False + + if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW": + return False + + padding = matched.attrs["padding"] + if padding[0] != padding[2] or padding[1] != padding[3]: + return False + + return True + return [ ("cudnn.softmax", softmax_pattern(), check_softmax), + ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax), + ("cudnn.conv2d", conv2d_pattern(), check_conv2d), ] @@ -87,3 +126,26 @@ def check_softmax(matched: relay.Call) -> bool: def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor: """Lower a softmax using cuDNN.""" return cudnn.softmax(inputs[0], axis=op.attrs["axis"]) + + +@lower_composite("cudnn.log_softmax") +def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor: + """Lower a log_softmax using cuDNN.""" + return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"]) + + +@lower_composite("cudnn.conv2d") +def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor: + """Lower a conv2d using cuDNN.""" + return cudnn.conv_forward( + inputs[0], + inputs[1], + pad=op.attrs["padding"], + stride=op.attrs["strides"], + dilation=op.attrs["dilation"], + conv_mode=1, + tensor_format=0, + algo=1, + conv_dtype=op.checked_type.dtype, + groups=op.attrs["groups"], + ) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 45ca7c91717d..8ca3df343dad 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -484,7 +484,7 @@ def _verify_cudnn_relay(expr): tvm.testing.assert_allclose( outputs[0], outputs[1], - rtol=1e-3, + rtol=1e-2, ) @@ -513,5 +513,69 @@ def test_relay_cudnn_softmax(shape, axis, dtype): _verify_cudnn_relay(softmax) +@tvm.testing.requires_cuda +@pytest.mark.parametrize( + "shape,axis", + [ + ((32, 16), -1), + ((13, 27), 1), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + "float32", + "float16", + "float64", + ], +) +def test_relay_cudnn_log_softmax(shape, axis, dtype): + x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype)) + log_softmax = relay.op.nn.log_softmax(x, axis=axis) + _verify_cudnn_relay(log_softmax) + + +@tvm.testing.requires_cuda +@pytest.mark.parametrize( + "n,h,w,ci,co,groups", + [ + (1, 16, 20, 8, 16, 1), + (10, 17, 19, 16, 8, 4), + ], +) +@pytest.mark.parametrize( + "kh,kw,padding", + [ + (1, 1, (3, 1, 3, 1)), + (3, 3, (1, 2)), + (7, 2, (0, 0)), + ], +) +@pytest.mark.parametrize( + "strides,dilation,dtype", + [ + ((1, 1), (1, 1), "float32"), + ((2, 1), (2, 2), "float16"), + ((3, 3), (1, 2), "float64"), + ], +) +def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype): + data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype)) + weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype)) + conv2d = relay.op.nn.conv2d( + data, + weight, + groups=groups, + channels=co, + kernel_size=(kh, kw), + strides=strides, + dilation=dilation, + padding=padding, + data_layout="NCHW", + kernel_layout="OIHW", + ) + _verify_cudnn_relay(conv2d) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))