[CUTLASS][Ansor] Combine CUTLASS and Ansor (#13879)
* feat: combine cutlass and ansor

* use sm80 and disable run_benchmark

* fix lint error

* use tempfile; fix dangerous default value

* merge cutlass_ansor test into test_cutlass.py

* fix lint

---------

Co-authored-by: hanqingchang <hanqingchang@kuaishou.com>
qingchanghan and hanqingchang authored Feb 1, 2023
1 parent 2877c5a commit ba936e9
Showing 2 changed files with 145 additions and 3 deletions.
8 changes: 7 additions & 1 deletion python/tvm/auto_scheduler/relay_integration.py
@@ -84,6 +84,7 @@ def extract_tasks(
include_simple_tasks=False,
dump_workload_to_dag_log=None,
opt_level=3,
other_targets=None,
):
"""Extract tuning tasks from a relay program.
@@ -105,6 +106,8 @@ def extract_tasks(
A file to dump an association between the workload keys and the actual DAG
opt_level : Optional[int]
The optimization level of the task extractions.
other_targets: Optional[List[tvm.target.Target]]
Other targets used by call_all_topi_funcs, e.g., the CUTLASS target.
Returns
-------
@@ -125,12 +128,15 @@
old_verbose = dispatch_ctx.verbose
dispatch_ctx.verbose = 0

targets = [target]
if other_targets is not None:
targets += other_targets
errors = []
with env:
# Wrap build call in a new thread to avoid the conflict
# between python's multiprocessing and tvm's thread pool
build_thread = threading.Thread(
target=call_all_topi_funcs, args=(mod, params, target, errors, opt_level)
target=call_all_topi_funcs, args=(mod, params, targets, errors, opt_level)
)
build_thread.start()
build_thread.join()
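
For reference, a minimal usage sketch of the extended extract_tasks API (the tiny fp16 dense workload, the "sm": 80 CUTLASS attribute, and the variable names here are illustrative placeholders, mirroring the invocation added to test_cutlass.py below):

import numpy as np
import tvm
from tvm import auto_scheduler, relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# Placeholder workload: a single fp16 dense, partitioned for CUTLASS.
data = relay.var("data", shape=(16, 32), dtype="float16")
weight = relay.var("weight", shape=(64, 32), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="float16"))
params = {"weight": np.random.uniform(-1, 1, (64, 32)).astype("float16")}
mod = partition_for_cutlass(mod, params)

# Two targets: CUDA for the Ansor-tuned portions, CUTLASS for the offloaded regions.
cuda = tvm.target.Target("cuda", host="llvm")
cutlass = tvm.target.Target({"kind": "cutlass", "sm": 80}, host="llvm")

# The new other_targets argument lets call_all_topi_funcs compile the CUTLASS
# partitions with their own codegen while Ansor tasks are extracted for the rest.
with tvm.transform.PassContext(
    opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod, params, cuda, include_simple_tasks=True, opt_level=3, other_targets=[cutlass]
    )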
140 changes: 138 additions & 2 deletions tests/python/contrib/test_cutlass.py
@@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.
import logging
import tempfile
import math
import tvm
from tvm import relay
from tvm.contrib.cudnn import conv_output_shape
import numpy as np
from tvm.relay import op as _op
from tvm.runtime.vm import VirtualMachine
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm import auto_scheduler
from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType
from tvm.contrib.cutlass import (
has_cutlass,
@@ -235,6 +238,32 @@ def get_conv2d_backward_weight(
)


def get_dense_transpose_dense(M, N, K, dtype="float16"):
"""
output = nn.dense(_op.transpose(nn.dense(input, weight0), axes=(1, 0)), weight1)
dense0: [M, K] * [N, K] -> [M, N]
transpose: [M, N] -> [N, M]
dense1: [N, M] * [K, M] -> [N, K]
input: [M, K]
weight0: [N, K]
weight1: [K, M]
"""
input_shape = (M, K)
weight0_shape = (N, K)
weight1_shape = (K, M)

input = relay.var("input", shape=input_shape, dtype=dtype)
weight0 = relay.var("weight0", shape=weight0_shape, dtype=dtype)
weight1 = relay.var("weight1", shape=weight1_shape, dtype=dtype)

output0 = relay.nn.dense(input, weight0, out_dtype=dtype)
input1 = _op.transpose(output0, axes=(1, 0))
output = relay.nn.dense(input1, weight1, out_dtype=dtype)
return output


def convert_conv2d_layout(mod, desired_layouts):
with tvm.transform.PassContext(opt_level=3):
seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
@@ -257,6 +286,8 @@ def profile_and_build(
tmp_dir="./tmp",
use_fast_math=False,
use_3xtf32=True,
use_ansor=False,
ansor_tuning=False,
):
logging.info("before partitioning:\n%s", mod)
mod = partition_for_cutlass(mod)
@@ -279,8 +310,53 @@
},
host=host,
)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=[cuda, cutlass], params=params)

if use_ansor:
with tvm.transform.PassContext(
opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
tasks, task_weights = auto_scheduler.extract_tasks(
mod, params, cuda, include_simple_tasks=True, opt_level=3, other_targets=[cutlass]
)
for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
logging.info(
f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) ====="
)
logging.info(task.compute_dag)

with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name

# auto-tuning is disabled by default
if ansor_tuning:
measure_ctx = auto_scheduler.LocalRPCMeasureContext(
repeat=3, min_repeat_ms=200, timeout=10
)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
auto_scheduler.TuningOptions(
num_measure_trials=100,
runner=measure_ctx.runner,
measure_callbacks=[
auto_scheduler.RecordToFile(log_file),
],
)
)

with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
lib = relay.build(
mod,
target=cuda,
target_host=host,
params=params,
)
else:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=[cuda, cutlass], params=params)
lib = finalize_modules(lib, "compile.so", tmp_dir)
dev = tvm.device("cuda", 0)
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
@@ -959,5 +1035,65 @@ def test_conv2d_bwd():
)


def verify_dense_transpose_dense(
func,
M,
N,
K,
ref_target="cuda",
sm=80,
atol=1e-5,
rtol=1e-5,
run_benchmark=False,
dtype="float16",
use_3xtf32=True,
):
assert has_cutlass()
if sm < 80 and dtype == "float32":
return

mod = tvm.IRModule.from_expr(func)
typ = relay.transform.InferType()(mod)["main"].body.checked_type
np_data = get_random_ndarray((M, K), dtype)
np_weight0 = get_random_ndarray((N, K), dtype)
np_weight1 = get_random_ndarray((K, M), dtype)

params = {"weight0": np_weight0, "weight1": np_weight1}

rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target)
cutlass_rt_mod, dev, num_partition = profile_and_build(
mod,
params,
sm,
use_3xtf32=use_3xtf32,
use_ansor=False,
)
cutlass_ansor_rt_mod, dev, num_partition = profile_and_build(
mod,
params,
sm,
use_3xtf32=use_3xtf32,
use_ansor=True,
)
x = tvm.nd.array(np_data, device=dev)
cutlass_out = get_output(cutlass_rt_mod, ["input"], [x])
cutlass_ansor_out = get_output(cutlass_ansor_rt_mod, ["input"], [x])
ref_out = get_output(rt_mod_ref, ["input"], [x])

assert num_partition > 0
np.testing.assert_allclose(cutlass_out, ref_out, atol=atol, rtol=rtol)
np.testing.assert_allclose(cutlass_ansor_out, ref_out, atol=atol, rtol=rtol)

if run_benchmark:
print("CUTLASS:", cutlass_rt_mod.benchmark(dev, number=1, repeat=600))
print("CUTLASS with Ansor:", cutlass_ansor_rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600))


@tvm.testing.requires_cutlass
def test_dense_transpose_dense():
verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K)


if __name__ == "__main__":
tvm.testing.main()
