Skip to content

Commit

Permalink
[AutoScheduler] Fix distill record (apache#7439)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Fix distill record

* update comments
  • Loading branch information
comaniac authored and trevor-m committed Mar 2, 2021
1 parent f3f1f2d commit aee3fd3
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
52 changes: 34 additions & 18 deletions python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,26 +286,42 @@ def distill_record_file(in_file, out_file):
if os.path.isfile(out_file):
out_context = load_records(out_file)
context = itertools.chain(context, out_context)
context, context_clone = itertools.tee(context)
best_context = ApplyHistoryBest(context)
best_set = set()

def measure_input_str_key(inp):
return _ffi_api.SerializeMeasureInput(inp)

for v in best_context.best_by_model.values():
best_set.add(measure_input_str_key(v[0]))
# Dict[target key,
# Dict[workload hash,
# Dict[workload args, (cost, (MeasureInput, MeasureResult))]]]
# Full type: Dict[str, Dict[str, Dict[Tuple, Tuple[float, Tuple[MeasureInput, MeasureResult]]]]]
best_records = {}

for v in best_context.best_by_targetkey.values():
best_set.add(measure_input_str_key(v[0]))
for inp, res in context:
if res.error_no != 0:
continue

# Keep the best record for each target and workload.
costs = [x.value for x in res.costs if isinstance(x, tvm.tir.expr.FloatImm)]
cost = np.mean(costs)
for k in inp.task.target.keys:
entry, _, workload_args = ApplyHistoryBest.get_workload_entry(
best_records, k, inp.task.workload_key
)
if workload_args not in entry or cost < entry[workload_args][0]:
entry[workload_args] = (cost, (inp, res))

# Remove duplications by multiple target keys.
out_records = {}
for target_entry in best_records.values():
for workload_entry in target_entry.values():
for _, (inp, res) in workload_entry.values():
out_records[measure_input_str_key(inp)] = (inp, res)

inputs = []
results = []
for inp, res in context_clone:
if measure_input_str_key(inp) in best_set:
inputs.append(inp)
results.append(res)
best_set.remove(measure_input_str_key(inp))
for inp, res in out_records.values():
inputs.append(inp)
results.append(res)

# create a new file and save the best records
open(out_file, "w")
Expand All @@ -316,23 +332,23 @@ def measure_input_str_key(inp):
def main():
"""The main function for CLI."""
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["distill"], required=True)
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help="output file")
parser.add_argument("--mode", choices=["distill"], default="distill")
parser.add_argument("-i", "--input", type=str, help="input file")
parser.add_argument("-o", "--output", type=str, default=None, help="output file")

args = parser.parse_args()
logging.basicConfig()
logger.setLevel(logging.INFO)

if args.mode == "distill":
args.o = args.o or args.i + ".best.json"
distill_record_file(args.i, args.o)
args.output = args.output or args.input + ".best.json"
distill_record_file(args.input, args.output)


"""
Usage:
* Distill the best entries from a large log file
e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
e.g. python -m tvm.auto_scheduler.measure_record --mode distill -i input.json
"""
if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion tutorials/auto_scheduler/tune_network_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def tune_and_evaluate():
# 1. During the tuning, the auto-scheduler needs to compile many programs and
# extract feature from them. This part is CPU-intensive,
# so a high-performance CPU with many cores is recommended for faster search.
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill -i log.json`
# to distill the large log file and only save the best useful records.
# 3. You can resume a search from the previous log file. You just need to
# add a new argument :code:`load_log_file` when creating the task scheduler
Expand Down
2 changes: 1 addition & 1 deletion tutorials/auto_scheduler/tune_network_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def run_tuning():
# 1. During the tuning, the auto-scheduler needs to compile many programs and
# extract feature from them. This part is CPU-intensive,
# so a high-performance CPU with many cores is recommended for faster search.
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill -i log.json`
# to distill the large log file and only save the best useful records.
# 3. You can resume a search from the previous log file. You just need to
# add a new argument :code:`load_log_file` when creating the task scheduler
Expand Down
2 changes: 1 addition & 1 deletion tutorials/auto_scheduler/tune_network_mali.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def tune_and_evaluate():
# 1. During the tuning, the auto-scheduler needs to compile many programs and
# extract feature from them. This part is CPU-intensive,
# so a high-performance CPU with many cores is recommended for faster search.
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill -i log.json`
# to distill the large log file and only save the best useful records.
# 3. You can resume a search from the previous log file. You just need to
# add a new argument :code:`load_log_file` when creating the task scheduler
Expand Down
2 changes: 1 addition & 1 deletion tutorials/auto_scheduler/tune_network_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def run_tuning():
# 1. During the tuning, the auto-scheduler needs to compile many programs and
# extract feature from them. This part is CPU-intensive,
# so a high-performance CPU with many cores is recommended for faster search.
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill --i log.json`
# 2. You can use :code:`python3 -m tvm.auto_scheduler.measure_record --mode distill -i log.json`
# to distill the large log file and only save the best useful records.
# 3. You can resume a search from the previous log file. You just need to
# add a new argument :code:`load_log_file` when creating the task scheduler
Expand Down

0 comments on commit aee3fd3

Please sign in to comment.