Skip to content

Commit

Permalink
[CUDNN] Add cuDNN as a Relay partitioning target (BYOC) (apache#10871)
Browse files Browse the repository at this point in the history
* [CUDNN] Add cuDNN as a Relay partitioning target (BYOC)

This adds infrastructure to support offloading of Relay
patterns to cuDNN. In this initial commit, only softmax
is supported.

* Refactor common TE BYOC code into separate file

* Add test guard
  • Loading branch information
mbaret authored and altanh committed Apr 28, 2022
1 parent 9920642 commit 3aba0a8
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 48 deletions.
54 changes: 7 additions & 47 deletions python/tvm/relay/op/contrib/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
from tvm.contrib import cublas

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
from .register import register_pattern_table


tvm._ffi.register_func("relay.ext.cublas", relay_to_runtime(tvm.target.cuda()))


def partition_for_cublas(
mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
) -> tvm.IRModule:
Expand Down Expand Up @@ -111,51 +115,7 @@ def check_matmul_like(matched: relay.Call) -> bool:
]


_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
_LOWER_MAP: Dict[str, _LowerFunc] = {}


def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
"""Register a lowering function for a given composite function name."""

def _register(f: _LowerFunc) -> _LowerFunc:
_LOWER_MAP[comp_name] = f
return f

return _register


@tvm._ffi.register_func("relay.ext.cublas")
def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
"""Compile cuBLAS Relay functions to a runtime module."""
assert isinstance(partition, relay.Function)
assert isinstance(partition.body, relay.Call)
assert isinstance(partition.body.op, relay.Function)

global_name = str(partition.attrs.global_symbol)
target = tvm.target.cuda()
comp_func = partition.body.op
comp_name = comp_func.attrs["Composite"]
assert comp_name in _LOWER_MAP
assert isinstance(comp_func.body, relay.Call)

op = comp_func.body
inputs = []
for i, param in enumerate(comp_func.params):
inputs.append(
te.placeholder(
param.checked_type.shape,
name=f"input_{i}",
dtype=param.checked_type.dtype,
)
)

output = _LOWER_MAP[comp_name](op, inputs)
prim_func = te.create_prim_func(inputs + [output])
return tvm.build(prim_func, target=target, name=global_name)


@_lower_composite("cublas.matmul")
@lower_composite("cublas.matmul")
def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a matmul using cuBLAS."""
return cublas.matmul(
Expand All @@ -167,7 +127,7 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
)


@_lower_composite("cublas.batch_matmul")
@lower_composite("cublas.batch_matmul")
def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a batch_matmul using cuBLAS."""
return cublas.batch_matmul(
Expand All @@ -179,7 +139,7 @@ def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
)


@_lower_composite("cublas.dense")
@lower_composite("cublas.dense")
def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a dense using cuBLAS."""
return cublas.matmul(
Expand Down
89 changes: 89 additions & 0 deletions python/tvm/relay/op/contrib/cudnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""cuDNN Relay integration."""
from typing import Callable, List, Tuple, Dict, Optional

import tvm
import tvm.ir
from tvm import relay
from tvm import te
from tvm.relay import transform
from tvm.contrib import cudnn

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
from .register import register_pattern_table


tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))


def partition_for_cudnn(
mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
) -> tvm.IRModule:
"""Partition the graph to offload for cuDNN.
Parameters
----------
mod : tvm.IRModule
The module to partition.
params : Optional[Dict[str, tvm.runtime.NDArray]]
Constant input parameters.
Returns
-------
tvm.IRModule
The partitioned module.
"""

seq = tvm.transform.Sequential(
[
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cudnn"),
transform.PartitionGraph(),
transform.InferType(),
]
)
return seq(mod)


@register_pattern_table("cudnn")
def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
"""Get the cuDNN pattern table."""

def softmax_pattern() -> relay.Pattern:
"""Create pattern for softmax."""
return is_op("nn.softmax")(wildcard())

def check_softmax(matched: relay.Call) -> bool:
"""Check if softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

return True

return [
("cudnn.softmax", softmax_pattern(), check_softmax),
]


@lower_composite("cudnn.softmax")
def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a softmax using cuDNN."""
return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
70 changes: 70 additions & 0 deletions python/tvm/relay/op/contrib/te_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Support a Relay partitioning target using Tensor Expressions."""
from typing import Callable, List, Dict

import tvm
import tvm.ir
from tvm import relay
from tvm import te


_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
_LOWER_MAP: Dict[str, _LowerFunc] = {}


def lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
"""Register a lowering function for a given composite function name."""

def _register(f: _LowerFunc) -> _LowerFunc:
_LOWER_MAP[comp_name] = f
return f

return _register


def relay_to_runtime(target: tvm.target.Target) -> Callable[[relay.Function], tvm.runtime.Module]:
"""Create a Relay to runtime module lowering function using Tensor Expressions for lowering."""

def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
"""Compile Relay functions to a runtime module using Tensor Expressions."""
assert isinstance(partition, relay.Function)
assert isinstance(partition.body, relay.Call)
assert isinstance(partition.body.op, relay.Function)

global_name = str(partition.attrs.global_symbol)
comp_func = partition.body.op
comp_name = comp_func.attrs["Composite"]
assert comp_name in _LOWER_MAP
assert isinstance(comp_func.body, relay.Call)

op = comp_func.body
inputs = []
for i, param in enumerate(comp_func.params):
inputs.append(
te.placeholder(
param.checked_type.shape,
name=f"input_{i}",
dtype=param.checked_type.dtype,
)
)

output = _LOWER_MAP[comp_name](op, inputs)
prim_func = te.create_prim_func(inputs + [output])
return tvm.build(prim_func, target=target, name=global_name)

return _relay_to_runtime
68 changes: 68 additions & 0 deletions tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@

import tvm
from tvm import te
from tvm import relay
from tvm.contrib import cudnn
from tvm.contrib.nvcc import have_fp16
from tvm.contrib import graph_executor
import numpy as np
import tvm.topi.testing
import tvm.testing
from tvm.relay.op.contrib.cudnn import partition_for_cudnn


requires_cudnn = pytest.mark.skipif(
Expand Down Expand Up @@ -445,5 +448,70 @@ def conv_output_shape_kwargs(request):
return request.param


def _verify_cudnn_relay(expr):
np.random.seed(42)

mod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(mod)
func = mod["main"]
cudnn_mod = partition_for_cudnn(mod)
assert len(cudnn_mod.get_global_vars()) == 2

input_data = []
for param in func.params:
shape = [int(x) for x in param.checked_type.shape]
input_data.append(
(param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
)

# Test against CPU reference
cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
outputs = []
for target, dev, test_mod in [cuda_config, cpu_config]:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(test_mod, target=target, target_host=cpu_config[0])
module = graph_executor.GraphModule(lib["default"](dev))
for name, data in input_data:
module.set_input(name, tvm.nd.array(data, dev))

module.run()
out_type = func.body.checked_type
outputs.append(
module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy()
)

tvm.testing.assert_allclose(
outputs[0],
outputs[1],
rtol=1e-3,
)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"shape,axis",
[
((200,), 0),
((13, 27), 0),
((44, 12, 67), 1),
((1, 16, 16, 8), 2),
((2, 4, 6, 8, 10), 3),
],
)
@pytest.mark.parametrize(
"dtype",
[
"float32",
"float16",
"float64",
],
)
def test_relay_cudnn_softmax(shape, axis, dtype):
x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
softmax = relay.op.nn.softmax(x, axis=axis)
_verify_cudnn_relay(softmax)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
4 changes: 3 additions & 1 deletion tests/scripts/task_mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ mypy --check-untyped-defs python/tvm/tir/transform/
echo "Checking MyPy Type defs in the TIR package with unittest"
MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py

echo "Checking MyPy Type defs in tvm.relay.op.contrib.cublas"
echo "Checking MyPy Type defs in tvm.relay.op.contrib"
mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py
mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cudnn.py
mypy --disallow-untyped-defs python/tvm/relay/op/contrib/te_target.py

#TODO(@mikepapadim): This is failing atm
# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
Expand Down

0 comments on commit 3aba0a8

Please sign in to comment.