diff --git a/python/tvm/relay/op/contrib/cublas.py b/python/tvm/relay/op/contrib/cublas.py
index a93169c2d84e..47b70efebdab 100644
--- a/python/tvm/relay/op/contrib/cublas.py
+++ b/python/tvm/relay/op/contrib/cublas.py
@@ -26,9 +26,13 @@
 from tvm.contrib import cublas
 
 from ...dataflow_pattern import is_op, wildcard
+from .te_target import lower_composite, relay_to_runtime
 from .register import register_pattern_table
 
 
+tvm._ffi.register_func("relay.ext.cublas", relay_to_runtime(tvm.target.cuda()))
+
+
 def partition_for_cublas(
     mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
 ) -> tvm.IRModule:
@@ -111,51 +115,7 @@ def check_matmul_like(matched: relay.Call) -> bool:
     ]
 
 
-_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
-_LOWER_MAP: Dict[str, _LowerFunc] = {}
-
-
-def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
-    """Register a lowering function for a given composite function name."""
-
-    def _register(f: _LowerFunc) -> _LowerFunc:
-        _LOWER_MAP[comp_name] = f
-        return f
-
-    return _register
-
-
-@tvm._ffi.register_func("relay.ext.cublas")
-def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
-    """Compile cuBLAS Relay functions to a runtime module."""
-    assert isinstance(partition, relay.Function)
-    assert isinstance(partition.body, relay.Call)
-    assert isinstance(partition.body.op, relay.Function)
-
-    global_name = str(partition.attrs.global_symbol)
-    target = tvm.target.cuda()
-    comp_func = partition.body.op
-    comp_name = comp_func.attrs["Composite"]
-    assert comp_name in _LOWER_MAP
-    assert isinstance(comp_func.body, relay.Call)
-
-    op = comp_func.body
-    inputs = []
-    for i, param in enumerate(comp_func.params):
-        inputs.append(
-            te.placeholder(
-                param.checked_type.shape,
-                name=f"input_{i}",
-                dtype=param.checked_type.dtype,
-            )
-        )
-
-    output = _LOWER_MAP[comp_name](op, inputs)
-    prim_func = te.create_prim_func(inputs + [output])
-    return tvm.build(prim_func, target=target, name=global_name)
-
-
-@_lower_composite("cublas.matmul")
+@lower_composite("cublas.matmul")
 def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a matmul using cuBLAS."""
     return cublas.matmul(
@@ -167,7 +127,7 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     )
 
 
-@_lower_composite("cublas.batch_matmul")
+@lower_composite("cublas.batch_matmul")
 def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a batch_matmul using cuBLAS."""
     return cublas.batch_matmul(
@@ -179,7 +139,7 @@ def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     )
 
 
-@_lower_composite("cublas.dense")
+@lower_composite("cublas.dense")
 def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
    """Lower a dense using cuBLAS."""
     return cublas.matmul(
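To illustrate the relocated registration, here is a minimal sketch of the cuBLAS offload flow after this change, assuming a TVM build with CUDA and cuBLAS enabled; the shapes and variable names are illustrative only:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.cublas import partition_for_cublas

# A small module containing a dense op that the "cublas.dense" pattern can match.
x = relay.var("x", shape=(16, 32), dtype="float32")
w = relay.var("w", shape=(8, 32), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.dense(x, w))

# Partitioning extracts the matched region into a "cublas"-annotated function;
# at build time the "relay.ext.cublas" hook registered at import time compiles
# it through the shared TE lowering path.
mod = partition_for_cublas(mod)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=tvm.target.cuda())
```

Registering the hook as `relay_to_runtime(tvm.target.cuda())` at import time is what lets the per-library file shrink to just its patterns and lowering functions.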
diff --git a/python/tvm/relay/op/contrib/cudnn.py b/python/tvm/relay/op/contrib/cudnn.py
new file mode 100644
index 000000000000..591178e6f882
--- /dev/null
+++ b/python/tvm/relay/op/contrib/cudnn.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""cuDNN Relay integration."""
+from typing import Callable, List, Tuple, Dict, Optional
+
+import tvm
+import tvm.ir
+from tvm import relay
+from tvm import te
+from tvm.relay import transform
+from tvm.contrib import cudnn
+
+from ...dataflow_pattern import is_op, wildcard
+from .te_target import lower_composite, relay_to_runtime
+from .register import register_pattern_table
+
+
+tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))
+
+
+def partition_for_cudnn(
+    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
+) -> tvm.IRModule:
+    """Partition the graph to offload for cuDNN.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module to partition.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        Constant input parameters.
+
+    Returns
+    -------
+    tvm.IRModule
+        The partitioned module.
+    """
+
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("cudnn"),
+            transform.PartitionGraph(),
+            transform.InferType(),
+        ]
+    )
+    return seq(mod)
+
+
+@register_pattern_table("cudnn")
+def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
+    """Get the cuDNN pattern table."""
+
+    def softmax_pattern() -> relay.Pattern:
+        """Create pattern for softmax."""
+        return is_op("nn.softmax")(wildcard())
+
+    def check_softmax(matched: relay.Call) -> bool:
+        """Check if softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        return True
+
+    return [
+        ("cudnn.softmax", softmax_pattern(), check_softmax),
+    ]
+
+
+@lower_composite("cudnn.softmax")
+def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a softmax using cuDNN."""
+    return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
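The new cuDNN integration follows the same shape as the cuBLAS one. A quick way to see the effect of `partition_for_cudnn`, assuming the same CUDA-enabled build; the shape and axis are illustrative:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.cudnn import partition_for_cudnn

x = relay.var("x", shape=(1, 16, 16, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.softmax(x, axis=2))
mod = partition_for_cudnn(mod)

# "main" now calls a single extracted global function annotated with
# Compiler="cudnn"; printing the module makes the split visible.
print(mod)
```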
+"""Support a Relay partitioning target using Tensor Expressions.""" +from typing import Callable, List, Dict + +import tvm +import tvm.ir +from tvm import relay +from tvm import te + + +_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor] +_LOWER_MAP: Dict[str, _LowerFunc] = {} + + +def lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]: + """Register a lowering function for a given composite function name.""" + + def _register(f: _LowerFunc) -> _LowerFunc: + _LOWER_MAP[comp_name] = f + return f + + return _register + + +def relay_to_runtime(target: tvm.target.Target) -> Callable[[relay.Function], tvm.runtime.Module]: + """Create a Relay to runtime module lowering function using Tensor Expressions for lowering.""" + + def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module: + """Compile Relay functions to a runtime module using Tensor Expressions.""" + assert isinstance(partition, relay.Function) + assert isinstance(partition.body, relay.Call) + assert isinstance(partition.body.op, relay.Function) + + global_name = str(partition.attrs.global_symbol) + comp_func = partition.body.op + comp_name = comp_func.attrs["Composite"] + assert comp_name in _LOWER_MAP + assert isinstance(comp_func.body, relay.Call) + + op = comp_func.body + inputs = [] + for i, param in enumerate(comp_func.params): + inputs.append( + te.placeholder( + param.checked_type.shape, + name=f"input_{i}", + dtype=param.checked_type.dtype, + ) + ) + + output = _LOWER_MAP[comp_name](op, inputs) + prim_func = te.create_prim_func(inputs + [output]) + return tvm.build(prim_func, target=target, name=global_name) + + return _relay_to_runtime diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 6bf0fdffcc53..45ca7c91717d 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -21,11 +21,14 @@ import tvm from tvm import te +from tvm import relay from tvm.contrib import cudnn from tvm.contrib.nvcc import have_fp16 +from tvm.contrib import graph_executor import numpy as np import tvm.topi.testing import tvm.testing +from tvm.relay.op.contrib.cudnn import partition_for_cudnn requires_cudnn = pytest.mark.skipif( @@ -445,5 +448,70 @@ def conv_output_shape_kwargs(request): return request.param +def _verify_cudnn_relay(expr): + np.random.seed(42) + + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + func = mod["main"] + cudnn_mod = partition_for_cudnn(mod) + assert len(cudnn_mod.get_global_vars()) == 2 + + input_data = [] + for param in func.params: + shape = [int(x) for x in param.checked_type.shape] + input_data.append( + (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype)) + ) + + # Test against CPU reference + cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod) + cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod) + outputs = [] + for target, dev, test_mod in [cuda_config, cpu_config]: + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(test_mod, target=target, target_host=cpu_config[0]) + module = graph_executor.GraphModule(lib["default"](dev)) + for name, data in input_data: + module.set_input(name, tvm.nd.array(data, dev)) + + module.run() + out_type = func.body.checked_type + outputs.append( + module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy() + ) + + tvm.testing.assert_allclose( + outputs[0], + outputs[1], + rtol=1e-3, + ) + + +@tvm.testing.requires_cuda +@pytest.mark.parametrize( + "shape,axis", + 
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 6bf0fdffcc53..45ca7c91717d 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -21,11 +21,14 @@
 
 import tvm
 from tvm import te
+from tvm import relay
 from tvm.contrib import cudnn
 from tvm.contrib.nvcc import have_fp16
+from tvm.contrib import graph_executor
 import numpy as np
 import tvm.topi.testing
 import tvm.testing
+from tvm.relay.op.contrib.cudnn import partition_for_cudnn
 
 
 requires_cudnn = pytest.mark.skipif(
@@ -445,5 +448,70 @@ def conv_output_shape_kwargs(request):
     return request.param
 
 
+def _verify_cudnn_relay(expr):
+    np.random.seed(42)
+
+    mod = tvm.IRModule.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    func = mod["main"]
+    cudnn_mod = partition_for_cudnn(mod)
+    assert len(cudnn_mod.get_global_vars()) == 2
+
+    input_data = []
+    for param in func.params:
+        shape = [int(x) for x in param.checked_type.shape]
+        input_data.append(
+            (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
+        )
+
+    # Test against CPU reference
+    cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
+    cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
+    outputs = []
+    for target, dev, test_mod in [cuda_config, cpu_config]:
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(test_mod, target=target, target_host=cpu_config[0])
+            module = graph_executor.GraphModule(lib["default"](dev))
+            for name, data in input_data:
+                module.set_input(name, tvm.nd.array(data, dev))
+
+            module.run()
+            out_type = func.body.checked_type
+            outputs.append(
+                module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy()
+            )
+
+    tvm.testing.assert_allclose(
+        outputs[0],
+        outputs[1],
+        rtol=1e-3,
+    )
+
+
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "shape,axis",
+    [
+        ((200,), 0),
+        ((13, 27), 0),
+        ((44, 12, 67), 1),
+        ((1, 16, 16, 8), 2),
+        ((2, 4, 6, 8, 10), 3),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float32",
+        "float16",
+        "float64",
+    ],
+)
+def test_relay_cudnn_softmax(shape, axis, dtype):
+    x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
+    softmax = relay.op.nn.softmax(x, axis=axis)
+    _verify_cudnn_relay(softmax)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh
index b7589d1d30e8..aaba996dbe96 100755
--- a/tests/scripts/task_mypy.sh
+++ b/tests/scripts/task_mypy.sh
@@ -36,8 +36,10 @@ mypy --check-untyped-defs python/tvm/tir/transform/
 echo "Checking MyPy Type defs in the TIR package with unittest"
 MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py
 
-echo "Checking MyPy Type defs in tvm.relay.op.contrib.cublas"
+echo "Checking MyPy Type defs in tvm.relay.op.contrib"
 mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py
+mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cudnn.py
+mypy --disallow-untyped-defs python/tvm/relay/op/contrib/te_target.py
 
 #TODO(@mikepapadim): This is failing atm
 # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
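As a side note on what the new test checks numerically: the cuDNN-partitioned build and the plain CPU build must agree within `rtol=1e-3`. A numpy-only sketch of the softmax reference involved; `ref_softmax` is an illustrative helper, not code from the test, written fully annotated in the style the tightened mypy check above enforces:

```python
import numpy as np


def ref_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable softmax: shift by the max before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# Sanity property both backends must satisfy: values along the reduced
# axis sum to one.
x = np.random.uniform(0, 32, size=(13, 27)).astype("float32")
np.testing.assert_allclose(ref_softmax(x, axis=0).sum(axis=0), 1.0, rtol=1e-5)
```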