Skip to content

Commit

Permalink
support other formats
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jan 24, 2021
1 parent 04ceaaf commit a95c810
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
6 changes: 4 additions & 2 deletions python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def check_workload_key(self, inputs):
"""
for inp in inputs:
_, args = decode_workload_key(inp.task.workload_key)
if args is None:
continue
if not args:
msg = (
"MeasureInput with old format workload key %s should be updated "
Expand Down Expand Up @@ -164,9 +166,9 @@ def flatten_list(inp):
return ret

target_key, target_args = decode_workload_key(target_workload_key)
target_args = flatten_list(target_args)
target_args = flatten_list(target_args) if target_args is not None else []
key, args = decode_workload_key(workload_key)
args = flatten_list(args)
args = flatten_list(args) if args is not None else []

# Not even the same func/DAG.
if key != target_key or len(target_args) != len(args):
Expand Down
18 changes: 12 additions & 6 deletions python/tvm/auto_scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@


def decode_workload_key(workload_key):
"""Decode the workload key from a string to a list.
"""Decode the workload key from a string to the name and arguments. The wokrload key
is expected to be a list of "[func_name/hash, args ...]" in a JSON string. If not,
then simply return the workload key as the name without arguments.
Parameters
----------
Expand All @@ -55,12 +57,16 @@ def decode_workload_key(workload_key):
-------
name: str
The workload function name or the DAG hash.
args: List[Any]
The arguments of the workload.
args: Optional[List[Any]]
The arguments of the workload, or None if the workload key format is not decodeable.
"""
key_list = json.loads(workload_key)
assert len(key_list) >= 1
return key_list[0], key_list[1:]
try:
key_list = json.loads(workload_key)
if isinstance(key_list, list) and len(key_list) >= 1:
return key_list[0], key_list[1:]
except json.decoder.JSONDecodeError:
pass
return workload_key, None


def get_func_name(func):
Expand Down

0 comments on commit a95c810

Please sign in to comment.