diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index 5f999fb3a1ad..9eaef189e081 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -79,6 +79,8 @@ def check_workload_key(self, inputs): """ for inp in inputs: _, args = decode_workload_key(inp.task.workload_key) + if args is None: + continue if not args: msg = ( "MeasureInput with old format workload key %s should be updated " @@ -164,9 +166,9 @@ def flatten_list(inp): return ret target_key, target_args = decode_workload_key(target_workload_key) - target_args = flatten_list(target_args) + target_args = flatten_list(target_args) if target_args is not None else [] key, args = decode_workload_key(workload_key) - args = flatten_list(args) + args = flatten_list(args) if args is not None else [] # Not even the same func/DAG. if key != target_key or len(target_args) != len(args): diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 689c82c37ca4..fd25fdb783f7 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -44,7 +44,9 @@ def decode_workload_key(workload_key): - """Decode the workload key from a string to a list. + """Decode the workload key from a string to the name and arguments. The wokrload key + is expected to be a list of "[func_name/hash, args ...]" in a JSON string. If not, + then simply return the workload key as the name without arguments. Parameters ---------- @@ -55,12 +57,16 @@ def decode_workload_key(workload_key): ------- name: str The workload function name or the DAG hash. - args: List[Any] - The arguments of the workload. + args: Optional[List[Any]] + The arguments of the workload, or None if the workload key format is not decodeable. """ - key_list = json.loads(workload_key) - assert len(key_list) >= 1 - return key_list[0], key_list[1:] + try: + key_list = json.loads(workload_key) + if isinstance(key_list, list) and len(key_list) >= 1: + return key_list[0], key_list[1:] + except json.decoder.JSONDecodeError: + pass + return workload_key, None def get_func_name(func):