From f367d1533a10c2d476b7a12e54c5261f71b08cfb Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 8 Jun 2020 14:36:42 +0800 Subject: [PATCH] Add RPCRunner & OpenCL/CUDA test (#12) * Add RPCRunner & OpenCL search test * Add CUDA search test * Add RPCRunner test --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/measure.py | 22 +++++++ python/tvm/rpc/server.py | 3 +- src/ansor/measure.cc | 8 +++ .../search_policy/meta_tile_rewrite_policy.h | 1 - tests/python/unittest/test_ansor_measure.py | 29 +++++++++ .../unittest/test_ansor_search_policy.py | 61 +++++++++++++++++-- 7 files changed, 117 insertions(+), 9 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 1be7ed404c17..7552878a3c50 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,6 +28,6 @@ from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner from .cost_model import RandomModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 5438edfaa6b2..b80de7c01633 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -168,6 +168,28 @@ def __init__(self, _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.RPCRunner") +class RPCRunner(Runner): + def __init__(self, key, host, port, priority=1, + n_parallel=1, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval) + + if check_remote(key, host, port, priority, timeout): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. " + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + MAX_ERROR_MSG_LEN = 512 diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 15a3c7de789d..42bcb00a9117 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -348,7 +348,8 @@ def __init__(self, cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, - "--port=%s" % port] + "--port=%s" % port, + "--port-end=%s" % port_end] if tracker_addr: assert key cmd += ["--tracker=%s:%d" % tracker_addr, diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 43be530f2a35..e3593753d3ff 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -368,5 +368,13 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") cooldown_interval); }); +TVM_REGISTER_GLOBAL("ansor.RPCRunner") +.set_body_typed([](const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { + return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 0c8c44b9c5ea..823ef6df4983 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -76,7 +76,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SearchTask cur_task_; // The current task - friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index baf8a0c4efa2..0385568894fe 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -19,6 +19,8 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server import tempfile from test_ansor_common import get_tiled_matmul @@ -62,6 +64,33 @@ def test_measure_local_builder_runner(): assert mress[0].error_no == 0 +def test_measure_local_builder_rpc_runner(): + dag, s0 = get_tiled_matmul() + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = rpc_runner.run([minp], bress) + assert mress[0].error_no == 0 + + tracker.terminate() + server.terminate() + + if __name__ == "__main__": test_serialization() test_measure_local_builder_runner() + test_measure_local_builder_rpc_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index eea3f5cfbda3..9a57691aba22 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -24,19 +24,20 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server from test_ansor_common import matmul_nkkm -def test_search_basic(): - print("Test schedule search with the default search policy") +def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'): + print("Test %s schedule search with the default search policy" % (target)) N = 128 A, B, C = matmul_nkkm(N, N, N) dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") + tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) - seed = 944563397 random.seed(seed) with tempfile.NamedTemporaryFile() as fp: @@ -44,7 +45,7 @@ def test_search_basic(): cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, + tune_option = ansor.TuneOption(n_trials=2, runner=runner, callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, tune_option=tune_option) @@ -60,7 +61,7 @@ def test_search_basic(): print(tvm.lower(sch, args, simple_mode=True)) mod = tvm.build(sch, args, tgt) - ctx = tvm.context("llvm", 0) + ctx = tvm.context(target, 0) a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) @@ -75,7 +76,55 @@ def test_search_basic(): s0 = dag.infer_bound_from_state(state) s1 = dag.infer_bound_from_state(inp.state) assert s0 == s1 + print() + + +def test_search_basic(): + search_common(seed=944563397) + + +def test_search_opencl(): + if tvm.context("opencl", 0).exist: + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("opencl", 380344973, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("OpenCL device not found, skip this test.") + + +def test_search_cuda(): + ctx = tvm.context("cuda", 0) + if ctx.exist: + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("cuda", 903667810, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("CUDA device not found, skip this test.") if __name__ == "__main__": test_search_basic() + test_search_opencl() + test_search_cuda()