From 82eae0c60f5277f6392c1b8da173a675c8bff3ef Mon Sep 17 00:00:00 2001 From: Vraj Prajapati Date: Thu, 12 Sep 2024 11:34:39 -0500 Subject: [PATCH] Added --artifact-dir flag to TTRT CLI (#629) * Added --artifact-dir flag to TTRT CLI * Small bug fix when calling ttrt run upstream * Hotfix for Upstream TTRT Run call --- runtime/tools/python/ttrt/common/api.py | 65 ++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 54f488649..73ca9ab7f 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -59,6 +59,13 @@ def initialize_apis(): choices=None, help="log file to dump ttrt output to", ) + API.Query.register_arg( + name="--artifact-dir", + type=str, + default="", + choices=None, + help="--save-artifacts flag must be set, provides a directory path to save artifacts to", + ) # register all read arguments API.Read.register_arg( @@ -96,6 +103,13 @@ def initialize_apis(): choices=None, help="log file to dump ttrt output to", ) + API.Read.register_arg( + name="--artifact-dir", + type=str, + default="", + choices=None, + help="--save-artifacts flag must be set, provides a directory path to save artifacts to", + ) # register all run arguments API.Run.register_arg( @@ -182,6 +196,14 @@ def initialize_apis(): choices=None, help="log file to dump ttrt output to", ) + API.Run.register_arg( + name="--artifact-dir", + type=str, + default="", + choices=None, + help="--save-artifacts flag must be set, provides a directory path to save artifacts to", + api_only=False, + ) # register all perf arguments API.Perf.register_arg( @@ -265,6 +287,13 @@ def initialize_apis(): choices=None, help="system desc to check against", ) + API.Check.register_arg( + name="--artifact-dir", + type=str, + default="", + choices=None, + help="--save-artifacts flag must be set, provides a directory path to save artifacts to", + ) # register apis API.register_api(API.Query) @@ -310,7 +339,11 @@ def __init__(self, args={}, logging=None, artifacts=None): self.artifacts = ( artifacts if artifacts != None - else Artifacts(self.logger, self.file_manager) + else Artifacts( + self.logger, + self.file_manager, + artifacts_folder_path=self["artifact_dir"], + ) ) self.system_desc = None self.device_ids = None @@ -465,7 +498,11 @@ def __init__(self, args={}, logging=None, artifacts=None): self.artifacts = ( artifacts if artifacts != None - else Artifacts(self.logger, self.file_manager) + else Artifacts( + self.logger, + self.file_manager, + artifacts_folder_path=self["artifact_dir"], + ) ) self.read_action_functions = {} self.ttnn_binaries = [] @@ -714,7 +751,11 @@ def __init__(self, args={}, logging=None, artifacts=None): self.artifacts = ( artifacts if artifacts != None - else Artifacts(self.logger, self.file_manager) + else Artifacts( + self.logger, + self.file_manager, + artifacts_folder_path=self["artifact_dir"], + ) ) self.query = API.Query({}, self.logger, self.artifacts) self.ttnn_binaries = [] @@ -1063,7 +1104,11 @@ def __init__(self, args={}, logging=None, artifacts=None): self.artifacts = ( artifacts if artifacts != None - else Artifacts(self.logger, self.file_manager) + else Artifacts( + self.logger, + self.file_manager, + artifacts_folder_path=self["artifact_dir"], + ) ) self.query = API.Query({}, self.logger, self.artifacts) self.ttnn_binaries = [] @@ -1311,7 +1356,11 @@ def _execute(binaries): if self[name]: command_options += f" {api['name']} " else: - command_options += f" {api['name']} {self[name]} " + command_options += f" {api['name']} " + if isinstance(self[name], str) and not self[name]: + command_options += f'"{self[name]}" ' + else: + command_options += f"{self[name]} " library_link_path = self.globals.get_ld_path( f"{self.globals.get_ttmetal_home_path()}" @@ -1466,7 +1515,11 @@ def __init__(self, args={}, logging=None, artifacts=None): self.artifacts = ( artifacts if artifacts != None - else Artifacts(self.logger, self.file_manager) + else Artifacts( + self.logger, + self.file_manager, + artifacts_folder_dir=self["artifact_dir"], + ) ) self.query = API.Query({}, self.logger, self.artifacts) self.ttnn_binaries = []