From ee6fe122c6a51fc6ae7a82cc535c474d060b3189 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 23 Sep 2021 20:49:44 +0300 Subject: [PATCH] [Auto-Schedule][Fix] Fix hang while tune model through rpc (#9032) * [Auto-Schedule][Fix] Fix hang while tune model through rpc * Fix problem with hang by using deep copy * Fix with local args * Update python/tvm/auto_scheduler/measure.py Co-authored-by: Wuwei Lin --- python/tvm/auto_scheduler/measure.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index c58aeea57d14..8c6fd5f1a949 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -909,6 +909,7 @@ def _timed_eval_func( random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" assert len(args) == len(build_res.args) + loc_args = [] # pylint: disable=consider-using-enumerate for idx in range(len(args)): if args[idx] is None: @@ -917,11 +918,11 @@ def _timed_eval_func( get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev ) random_fill(empty_array) - args[idx] = empty_array + loc_args.append(empty_array) else: - args[idx] = ndarray.array(args[idx], dev) + loc_args.append(ndarray.array(args[idx], dev)) dev.sync() - costs = time_f(*args).results + costs = time_f(*loc_args).results # pylint: disable=broad-except except Exception: costs = (MAX_FLOAT,) @@ -1112,6 +1113,7 @@ def _rpc_run( ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" assert len(args) == len(build_res.args) + loc_args = [] # pylint: disable=consider-using-enumerate for idx in range(len(args)): if args[idx] is None: @@ -1120,16 +1122,16 @@ def _rpc_run( get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev ) random_fill(empty_array) - args[idx] = empty_array + loc_args.append(empty_array) else: - args[idx] = ndarray.array(args[idx], dev) + loc_args.append(ndarray.array(args[idx], dev)) dev.sync() # First run for check that the kernel is correct - func.entry_func(*args) + func.entry_func(*loc_args) dev.sync() - costs = time_f(*args).results + costs = time_f(*loc_args).results # clean up remote files remote.remove(build_res.filename)