Skip to content

Commit

Permalink
Use PopenpoolExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and shingjan committed Aug 11, 2021
1 parent a3274cd commit 193709b
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.contrib.popen_pool import PopenWorker
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.target import Target


Expand Down Expand Up @@ -601,7 +601,7 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error


def _timed_func(inp_serialized, build_func, verbose):
def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
Expand Down Expand Up @@ -666,16 +666,12 @@ def local_build_worker(args):
)
build_func = BuildFunc.build_func

worker = PopenWorker()
res = call_func_with_timeout(worker, timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
try:
res = _local_build_worker(inp, build_func, verbose)
except Exception:
if verbose >= 1:
print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), timeout

return res

Expand Down Expand Up @@ -704,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
res : List[BuildResult]
The build results of these MeasureInputs.
"""
# This pool is not doing computationally intensive work, so we can use threads
pool = ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor()
tuple_res = executor.map_with_error_catching(
local_build_worker,
[
(
Expand All @@ -718,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
for i in inputs
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(BuildResult(*res))
if res.status == StatusKind.COMPLETE:
results.append(BuildResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout))

return results

Expand Down Expand Up @@ -1058,7 +1056,7 @@ def local_run(
return measure_results


def _timed_rpc_run(
def _rpc_run(
inp_serialized,
build_res,
args,
Expand Down Expand Up @@ -1181,25 +1179,15 @@ def _rpc_run_worker(args):
time.time(),
)

worker = PopenWorker()
res = call_func_with_timeout(worker, timeout, _timed_rpc_run, args=args)
if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
try:
res = _rpc_run(*args)
except Exception:
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
make_traceback_info(),
build_res.time_cost + timeout,
time.time(),
)
Expand Down Expand Up @@ -1279,8 +1267,8 @@ def rpc_runner_run(
"""
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
# This pool is not doing computationally intensive work, so we can use threads
pool = ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel)
tuple_res = executor.map_with_error_catching(
_rpc_run_worker,
[
(
Expand All @@ -1302,13 +1290,25 @@ def rpc_runner_run(
for inp, build_res in zip(inputs, build_results)
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(MeasureResult(*res))
for i, res in enumerate(tuple_res):
if res.status == StatusKind.COMPLETE:
results.append(MeasureResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print("*T", end="") # Run timeout
build_res = build_results[i]
results.append(
MeasureResult(
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
)

if verbose >= 1:
print("")
Expand Down

0 comments on commit 193709b

Please sign in to comment.