diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index e69f1d1486..0517f4b7b7 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -38,7 +38,7 @@ def read(args): arg_binary = args.binary arg_clean_artifacts = args.clean_artifacts arg_save_artifacts = args.save_artifacts - arg_section = arg.section + arg_section = args.section # preprocessing if os.path.isdir(arg_binary): @@ -129,11 +129,18 @@ def run(args): # execution print("executing action for all provided flatbuffers") + system_desc, device_ids = ttrt.runtime.get_current_system_desc() + device = ttrt.runtime.open_device(device_ids) + atexit.register(lambda: ttrt.runtime.close_device(device)) + for (binary_name, fbb, fbb_dict) in fbb_list: torch_inputs[binary_name] = [] torch_outputs[binary_name] = [] program = fbb_dict["programs"][program_index] - print(f"running program[{program_index}]:", program["name"]) + print( + f"running binary={binary_name} with program[{program_index}]:", + program["name"], + ) for i in program["inputs"]: torch_tensor = torch.randn( @@ -180,9 +187,6 @@ def run(args): total_inputs.append(inputs) total_outputs.append(outputs) - system_desc, device_ids = ttrt.runtime.get_current_system_desc() - device = ttrt.runtime.open_device(device_ids) - atexit.register(lambda: ttrt.runtime.close_device(device)) for loop in range(arg_loops): ttrt.runtime.submit(device, fbb, 0, total_inputs[loop], total_outputs[loop]) print(f"finished loop={loop}")