diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index d91020e4fc..7f8e01ca92 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -16,6 +16,7 @@ from pkg_resources import get_distribution import sys import shutil +import atexit from ttrt.common.util import * @@ -171,11 +172,11 @@ def run(args): system_desc, device_ids = ttrt.runtime.get_current_system_desc() device = ttrt.runtime.open_device(device_ids) + atexit.register(lambda: ttrt.runtime.close_device(device)) for loop in range(arg_loops): ttrt.runtime.submit(device, fbb, 0, total_inputs[loop], total_outputs[loop]) print(f"finished loop={loop}") print("outputs:\n", torch_outputs) - ttrt.runtime.close_device(device) # save artifacts for binary in binaries: