Fix AutoScheduler for anaconda python (apache#7387)
In the case of a non-CPython flavour of Python, the task passed to the measure
process should be serialized using the pickle approach. The task includes a
workload, which is a list of Tensors; the list should be serialized and
deserialized as an atomic object.
dlexplorer authored and Lokiiiiii committed Mar 1, 2021
1 parent cf7fd44 commit b11ef38
Showing 2 changed files with 51 additions and 6 deletions.
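
A rough sketch of the idea behind the fix (not part of the patch): pickling a Python list of te.Tensor objects serializes each tensor separately, so ops shared across the DAG would be reconstructed as distinct objects; converting the list to one TVM Array and saving it as a single JSON string keeps the workload atomic. The sketch assumes tvm.ir.save_json / tvm.ir.load_json, the public wrappers around the SaveJSON / LoadJSON FFI calls used in the diff below.

import pickle

import tvm
from tvm import te

A = te.placeholder((512, 512), name="A")
k = te.reduce_axis((0, 512), name="k")
B = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * A[k][j], axis=[k]), name="B")

# Serialize the whole workload as one atomic object: convert the Python
# list to a TVM Array node and save it as a single JSON string.
workload = tvm.runtime.convert([A, B])
json_str = tvm.ir.save_json(workload)

# The JSON string is a plain str, so it survives pickle unchanged.
restored = tvm.ir.load_json(pickle.loads(pickle.dumps(json_str)))

# Shared structure survives the round trip: the restored compute op
# still reads from the restored placeholder.
assert restored[1].op.input_tensors[0].op.same_as(restored[0].op)

The new test below performs the same round trip through serialize_workload_registry_entry and deserialize_workload_registry_entry, exercising the pickle path used when a task is sent to a measure process.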
18 changes: 12 additions & 6 deletions python/tvm/auto_scheduler/workload_registry.py
@@ -35,6 +35,7 @@
 import json
 
 import tvm._ffi
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
 from .utils import serialize_args, deserialize_args, get_func_name
 
 logger = logging.getLogger("auto_scheduler")
@@ -216,13 +217,17 @@ def serialize_workload_registry_entry(workload_key):
     global WORKLOAD_FUNC_REGISTRY
 
     if workload_key in WORKLOAD_FUNC_REGISTRY:
-        return (workload_key, WORKLOAD_FUNC_REGISTRY[workload_key])
+        sname = workload_key
+    else:
+        workload = json.loads(workload_key)
+        sname = workload[0]
 
-    workload = json.loads(workload_key)
-    name = workload[0]
-    value = WORKLOAD_FUNC_REGISTRY[name]
+    svalue = WORKLOAD_FUNC_REGISTRY[sname]
+    if not callable(svalue):
+        # pylint: disable=assignment-from-no-return
+        svalue = SaveJSON(svalue)
 
-    return name, value
+    return sname, svalue
 
 
 def deserialize_workload_registry_entry(data):
@@ -239,7 +244,8 @@ def deserialize_workload_registry_entry(data):
 
     name, value = data
     if name not in WORKLOAD_FUNC_REGISTRY:
-        WORKLOAD_FUNC_REGISTRY[name] = value
+        # pylint: disable=assignment-from-no-return
+        WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value)
 
 
 def save_workload_func_registry(filename):
39 changes: 39 additions & 0 deletions tests/python/unittest/test_auto_scheduler_measure.py
@@ -24,8 +24,10 @@
 from tvm import te, auto_scheduler
 import tempfile
 import tvm.testing
+import pickle
 
 from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
+from tvm.auto_scheduler import workload_registry
 
 
 def record_common(dag, s):
@@ -255,6 +257,42 @@ def test_measure_local_builder_runner():
     assert mress[0].error_no == 0
 
 
+def test_dag_measure_local_builder_runner():
+    if not tvm.testing.device_enabled("llvm"):
+        return
+
+    A = te.placeholder((512, 512), name="A")
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")
+    D = topi.nn.relu(C)
+    E = topi.nn.relu(D)
+
+    tensors = [A, B, E]
+    dag = auto_scheduler.ComputeDAG(tensors)
+    key = workload_registry.register_workload_tensors(dag.workload_key(), tensors)
+    transfer_data = workload_registry.serialize_workload_registry_entry(key)
+    f_data = pickle.dumps(transfer_data)
+    f_new = pickle.loads(f_data)
+    del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
+    workload_registry.deserialize_workload_registry_entry(f_new)
+
+    target = tvm.target.Target("llvm")
+    task = auto_scheduler.SearchTask(compute_dag=dag, workload_key=key, target=target)
+
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        local_runner = auto_scheduler.LocalRunner(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
+
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = local_runner.run([minp], bress)
+        assert mress[0].error_no == 0
+
+
 def test_measure_local_builder_rpc_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
@@ -325,5 +363,6 @@ def test_measure_target_host():
     test_recover_measure_input()
     test_workload_dis_factor()
     test_measure_local_builder_runner()
+    test_dag_measure_local_builder_runner()
     test_measure_local_builder_rpc_runner()
     test_measure_target_host()
