From 05ac6ba89a218398cae281029835aac284075a8c Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 10:01:32 +0000 Subject: [PATCH 1/6] #0: Tie name of each log to a slot index --- models/demos/llama3/lt | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index c69b4113a84..c8e286e74d0 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -527,6 +527,21 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): env["FAKE_DEVICE"] = entry["device"] env["LLAMA_DIR"] = get_llama_dir(entry["model"]) + # Calculate slot index (1-based) based on position in output_entries + slot_index = output_entries.index(entry) + 1 + + # Prepare log directory + os.makedirs("logs", exist_ok=True) + + # Remove any existing logs with the same slot index + for existing_file in os.listdir("logs"): + if existing_file.startswith(f"{slot_index:03d}-"): + os.remove(os.path.join("logs", existing_file)) + + # Create new log file with slot index + log_filename = get_log_filename(entry["device"], entry["model"], entry["command_input"], slot_index) + entry["log_file"] = open(os.path.join("logs", log_filename), "w") + # Define command shortcuts command_shortcuts = { "demo": "pytest models/demos/llama3/demo/demo.py -k instruct_weights-1", @@ -575,11 +590,6 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): # Prepare the command cmd_list = shlex.split(command_input) - # Open log file - log_filename = get_log_filename(entry["device"], entry["model"], command_input) - os.makedirs("logs", exist_ok=True) - entry["log_file"] = open(os.path.join("logs", log_filename), "w") - # If the command is invalid, write the output to the log file and return before trying to run the bad command if entry["status"] == "Error": entry["log_file"].write(entry["output"] + "\n") @@ -746,9 +756,9 @@ def get_command_name(command_input): return command_name -def get_log_filename(device, model, command_input): +def get_log_filename(device, model, command_input, slot_index): command_name = get_command_name(command_input) - filename = f"{device}-{model}-{command_name}.log" + filename = f"{slot_index:03d}-{device}-{model}-{command_name}.log" filename = filename.replace("/", "_") return filename From 3396547dc5e09836d15e09bae4b415c90cdbf543 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 10:36:49 +0000 Subject: [PATCH 2/6] #0: Detach log id's from visible index as this changes on deletion --- models/demos/llama3/lt | 43 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index c8e286e74d0..9dfd6037c1b 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -52,6 +52,37 @@ def ensure_ttsmi_installed(): sys.exit(1) +class OutputEntryList: + def __init__(self): + self._entries = [] + + def __len__(self): + return len(self._entries) + + def __getitem__(self, index): + return self._entries[index] + + def __iter__(self): + return iter(self._entries) + + def append(self, entry): + entry["log_id"] = self.next_log_id() + self._entries.append(entry) + + def pop(self, index): + result = self._entries.pop(index) + return result + + def index(self, entry): + return self._entries.index(entry) + + def get_entries(self): + return self._entries + + def next_log_id(self): + return max([entry["log_id"] for entry in self._entries], default=0) + 1 + + def main(stdscr): curses.curs_set(0) # Hide cursor curses.start_color() @@ -69,7 +100,7 @@ def main(stdscr): {"label": "Device (n150, n300, t3k) [all]", "value": "", "x": 0, "y": 2}, ] - output_entries = [] + output_entries = OutputEntryList() current_line = 0 # Index of the current line (input fields + output entries) total_lines = len(input_fields) @@ -209,7 +240,6 @@ def main(stdscr): "output": "", "process": None, "log_file": None, - "index": len(output_entries), "stop_event": threading.Event(), "lock": threading.Lock(), "command_input": command, # Save the individual command @@ -527,19 +557,16 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): env["FAKE_DEVICE"] = entry["device"] env["LLAMA_DIR"] = get_llama_dir(entry["model"]) - # Calculate slot index (1-based) based on position in output_entries - slot_index = output_entries.index(entry) + 1 - # Prepare log directory os.makedirs("logs", exist_ok=True) # Remove any existing logs with the same slot index for existing_file in os.listdir("logs"): - if existing_file.startswith(f"{slot_index:03d}-"): + if existing_file.startswith(f"{entry['log_id']:04d}-"): os.remove(os.path.join("logs", existing_file)) # Create new log file with slot index - log_filename = get_log_filename(entry["device"], entry["model"], entry["command_input"], slot_index) + log_filename = get_log_filename(entry["device"], entry["model"], entry["command_input"], entry["log_id"]) entry["log_file"] = open(os.path.join("logs", log_filename), "w") # Define command shortcuts @@ -758,7 +785,7 @@ def get_command_name(command_input): def get_log_filename(device, model, command_input, slot_index): command_name = get_command_name(command_input) - filename = f"{slot_index:03d}-{device}-{model}-{command_name}.log" + filename = f"{slot_index:04d}-{device}-{model}-{command_name}.log" filename = filename.replace("/", "_") return filename From 5f1b397763289985735ac5f65d5896654301eddd Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 11:13:28 +0000 Subject: [PATCH 3/6] #0: remove logs when deleted from list --- models/demos/llama3/lt | 199 ++++++++++++++++++++++++++--------------- 1 file changed, 128 insertions(+), 71 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 9dfd6037c1b..3639b6ffcd1 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -55,6 +55,8 @@ def ensure_ttsmi_installed(): class OutputEntryList: def __init__(self): self._entries = [] + # Create logs directory + os.makedirs("logs", exist_ok=True) def __len__(self): return len(self._entries) @@ -66,11 +68,27 @@ class OutputEntryList: return iter(self._entries) def append(self, entry): - entry["log_id"] = self.next_log_id() + log_id = self.next_log_id() + + # Remove any existing logs with same ID + for existing_file in os.listdir("logs"): + if existing_file.startswith(entry.log_prefix): + os.remove(os.path.join("logs", existing_file)) + + entry["log_id"] = log_id self._entries.append(entry) def pop(self, index): result = self._entries.pop(index) + + try: + os.remove(result.get_log_filename()) + except (OSError, FileNotFoundError): + pass + + # Mark all subsequent entries as changed + for entry in self._entries[index:]: + entry.mark_changed() return result def index(self, entry): @@ -83,6 +101,76 @@ class OutputEntryList: return max([entry["log_id"] for entry in self._entries], default=0) + 1 +class Entry: + def __init__(self, command_name, model, device, command_input): + self.command_name = command_name + self.model = model + self.device = device.upper() + self.command_input = command_input + self.status = "Waiting" + self.output = "" + self.process = None + self.log_file = None + self.stop_event = threading.Event() + self.lock = threading.Lock() + self.log_id = None # Will be set by OutputEntryList + self.speed = None + self.pcc = None + self.thread = None + self.changed = True # Initialize as changed to ensure first draw + + @property + def log_prefix(self): + """Generate the log file prefix based on the entry's log ID""" + return f"{self.log_id:04d}-" + + def mark_changed(self): + self.changed = True + + def mark_drawn(self): + self.changed = False + + def __setattr__(self, name, value): + # Mark as changed whenever any attribute is modified + # (except for 'changed' itself to avoid recursion) + if name != "changed" and hasattr(self, "changed"): + self.changed = True + super().__setattr__(name, value) + + def __getitem__(self, key): + # Support dictionary-style access for backward compatibility + return getattr(self, key) + + def __setitem__(self, key, value): + # Support dictionary-style assignment for backward compatibility + setattr(self, key, value) + + def get(self, key, default=None): + # Support dictionary-style get() for backward compatibility + return getattr(self, key, default) + + def get_log_filename(self): + """Generate log filename based on entry properties""" + command_name = self._get_command_name() + filename = f"{self.log_prefix}{self.device}-{self.model}-{command_name}.log" + return os.path.join("logs", filename.replace("/", "_")) + + def _get_command_name(self): + """Extract command name from command input""" + if "pytest" in self.command_input: + match = re.search(r"pytest\s+([\S]+)", self.command_input) + if match: + test_file = match.group(1) + return os.path.basename(test_file).split(".")[0] + return "pytest" + return os.path.basename(shlex.split(self.command_input)[0]) + + def open_log_file(self): + """Open and return log file for writing""" + self.log_file = open(self.get_log_filename(), "w") + return self.log_file + + def main(stdscr): curses.curs_set(0) # Hide cursor curses.start_color() @@ -232,18 +320,7 @@ def main(stdscr): # Create output entries for command, model, device in combinations: command_name = get_command_name(command) - entry = { - "command_name": command_name, - "model": model, - "device": device.upper(), - "status": "Waiting", - "output": "", - "process": None, - "log_file": None, - "stop_event": threading.Event(), - "lock": threading.Lock(), - "command_input": command, # Save the individual command - } + entry = Entry(command_name, model, device, command) output_entries.append(entry) # Update total_lines total_lines = len(input_fields) + len(output_entries) @@ -415,28 +492,24 @@ def draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_ output_start_y = header_y + 1 for idx, entry in enumerate(output_entries): y = output_start_y + idx - if y >= max_y - 3: # Change this line from max_y - 2 to max_y - 3 + if y >= max_y - 3: break - if ( - idx >= len(last_drawn_state["output_entries"]) - or entry != last_drawn_state["output_entries"][idx] - or current_line != last_drawn_state["current_line"] - ): + + # Only draw if entry has changed or selection state changed + if entry.changed or current_line != last_drawn_state["current_line"]: draw_output_entry(stdscr, entry, y, current_line == len(input_fields) + idx, max_x) + entry.mark_drawn() # Mark as drawn after updating # Clear any extra lines if output entries were removed for y in range( output_start_y + len(output_entries), - min( - output_start_y + len(last_drawn_state["output_entries"]), max_y - 3 - ), # Change this line from max_y - 2 to max_y - 3 + min(output_start_y + len(last_drawn_state["output_entries"]), max_y - 3), ): stdscr.move(y, 0) stdscr.clrtoeol() - # Update last_drawn_state - last_drawn_state["output_entries"] = [entry.copy() for entry in output_entries] last_drawn_state["current_line"] = current_line + last_drawn_state["output_entries"] = [{"log_id": entry.log_id} for entry in output_entries] def draw_input_field(stdscr, field, is_selected, max_x): @@ -453,13 +526,13 @@ def draw_input_field(stdscr, field, is_selected, max_x): def draw_output_entry(stdscr, entry, y, is_selected, max_x): cols = [ - entry["command_name"], - entry["model"], - entry["device"], - entry["status"], - entry.get("speed", ""), - entry.get("pcc", ""), - entry["output"], + entry.command_name, + entry.model, + entry.device, + entry.status, + getattr(entry, "speed", ""), + getattr(entry, "pcc", ""), + entry.output, ] col_widths = [20, 10, 10, 20, 10, 10, max_x - 85] # Adjusted widths to accommodate the PCC column @@ -471,7 +544,7 @@ def draw_output_entry(stdscr, entry, y, is_selected, max_x): else: color = curses.color_pair(0) if i == 3: # Status column - status = entry["status"] + status = entry.status if status == "Waiting" or status == "Cancelled": color = COLOR_PAIR_WAITING elif status in ["Running", "Initializing device", "Prefill", "Decode", "Starting"] or status.startswith( @@ -557,17 +630,8 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): env["FAKE_DEVICE"] = entry["device"] env["LLAMA_DIR"] = get_llama_dir(entry["model"]) - # Prepare log directory - os.makedirs("logs", exist_ok=True) - - # Remove any existing logs with the same slot index - for existing_file in os.listdir("logs"): - if existing_file.startswith(f"{entry['log_id']:04d}-"): - os.remove(os.path.join("logs", existing_file)) - - # Create new log file with slot index - log_filename = get_log_filename(entry["device"], entry["model"], entry["command_input"], entry["log_id"]) - entry["log_file"] = open(os.path.join("logs", log_filename), "w") + # Open log file + entry.open_log_file() # Define command shortcuts command_shortcuts = { @@ -637,31 +701,31 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): def process_output(entry, screen_lock, output_entries, screen_needs_update): - process = entry["process"] - log_file = entry["log_file"] + process = entry.process + log_file = entry.log_file previous_line = "" try: for line in iter(process.stdout.readline, ""): - if entry["stop_event"].is_set(): + if entry.stop_event.is_set(): break # Write to log file log_file.write(line) log_file.flush() # Update status and output based on output - status, output, speed, pcc = parse_output_line(line, previous_line, entry["status"]) + status, output, speed, pcc = parse_output_line(line, previous_line, entry.status) previous_line = line.strip() - with entry["lock"]: - if status != entry["status"] or output or speed is not None or pcc is not None: - entry["status"] = status + with entry.lock: + if status != entry.status or output or speed is not None or pcc is not None: + entry.status = status # This will mark entry as changed via __setattr__ if output: - entry["output"] = output + entry.output = output if speed is not None: - entry["speed"] = f"{speed:.1f}" + entry.speed = f"{speed:.1f}" if pcc is not None: - current_pcc = entry.get("pcc") + current_pcc = getattr(entry, "pcc", None) if current_pcc is None or float(pcc) < float(current_pcc): - entry["pcc"] = pcc + entry.pcc = pcc screen_needs_update.set() with screen_lock: @@ -674,17 +738,17 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): # Wait for the process to fully terminate process.wait() - with entry["lock"]: + with entry.lock: if process.returncode != 0: - exception_name = find_exception_in_log(entry["log_file"].name) - entry["status"] = "Error" + exception_name = find_exception_in_log(entry.log_file.name) + entry.status = "Error" if exception_name: - entry["output"] = exception_name + entry.output = exception_name reset_device_async(entry, screen_lock, screen_needs_update) screen_needs_update.set() else: - entry["status"] = "Finished" - entry["process"] = None + entry.status = "Finished" + entry.process = None log_file.close() screen_needs_update.set() # Ensure screen is updated after process termination @@ -783,13 +847,6 @@ def get_command_name(command_input): return command_name -def get_log_filename(device, model, command_input, slot_index): - command_name = get_command_name(command_input) - filename = f"{slot_index:04d}-{device}-{model}-{command_name}.log" - filename = filename.replace("/", "_") - return filename - - def find_exception_in_log(log_filename): exception_name = None with open(log_filename, "r") as f: @@ -844,12 +901,12 @@ def reset_device_async(entry, screen_lock, screen_needs_update): except subprocess.CalledProcessError as e: pass finally: - with entry["lock"]: - entry["status"] = previous_status + with entry.lock: + entry.status = previous_status screen_needs_update.set() - previous_status = entry["status"] - entry["status"] = "Resetting" + previous_status = entry.status + entry.status = "Resetting" reset_thread = threading.Thread(target=reset_thread) reset_thread.daemon = True reset_thread.start() From 850eaef97a86bb7f1d2e955eb867f8559f35506d Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 11:52:36 +0000 Subject: [PATCH 4/6] #0: Persist lt status in logs/status.json --- models/demos/llama3/lt | 118 ++++++++++++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 18 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 3639b6ffcd1..09e886247d2 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -14,6 +14,7 @@ import re import time import signal import psutil +import json def ensure_less_installed(): @@ -57,6 +58,66 @@ class OutputEntryList: self._entries = [] # Create logs directory os.makedirs("logs", exist_ok=True) + # Load existing state + self._load_state() + + def _load_state(self): + try: + with open("logs/state.json", "r") as f: + state = json.load(f) + for entry_data in state: + entry = Entry( + entry_data["command_name"], + entry_data["model"], + entry_data["device"], + entry_data["command_input"], + ) + # Restore saved attributes + entry.status = ( + "Cancelled" + if entry_data["status"] + in [ + "Running", + "Resetting", + "Initializing device", + "Starting", + "Prefill", + "Decode", + "Terminating", + "Exiting", + ] + else entry_data["status"] + ) + entry.output = entry_data["output"] + entry.log_id = entry_data["log_id"] + if "speed" in entry_data: + entry.speed = entry_data["speed"] + if "pcc" in entry_data: + entry.pcc = entry_data["pcc"] + self._entries.append(entry) + except (FileNotFoundError, json.JSONDecodeError): + pass + + def save_state(self): + state = [] + for entry in self._entries: + entry_data = { + "command_name": entry.command_name, + "model": entry.model, + "device": entry.device, + "command_input": entry.command_input, + "status": entry.status, + "output": entry.output, + "log_id": entry.log_id, + } + if hasattr(entry, "speed"): + entry_data["speed"] = entry.speed + if hasattr(entry, "pcc"): + entry_data["pcc"] = entry.pcc + state.append(entry_data) + + with open("logs/state.json", "w") as f: + json.dump(state, f, indent=2) def __len__(self): return len(self._entries) @@ -68,15 +129,18 @@ class OutputEntryList: return iter(self._entries) def append(self, entry): - log_id = self.next_log_id() + if entry.log_id is None: + entry.log_id = self.next_log_id() # Remove any existing logs with same ID for existing_file in os.listdir("logs"): if existing_file.startswith(entry.log_prefix): os.remove(os.path.join("logs", existing_file)) - entry["log_id"] = log_id + # Set parent list reference before appending + entry.set_parent_list(self) self._entries.append(entry) + self.save_state() def pop(self, index): result = self._entries.pop(index) @@ -89,6 +153,8 @@ class OutputEntryList: # Mark all subsequent entries as changed for entry in self._entries[index:]: entry.mark_changed() + + self.save_state() return result def index(self, entry): @@ -98,7 +164,12 @@ class OutputEntryList: return self._entries def next_log_id(self): - return max([entry["log_id"] for entry in self._entries], default=0) + 1 + # Fix potential issue with list comprehension on dictionary access + max_id = 0 + for entry in self._entries: + if entry.log_id > max_id: + max_id = entry.log_id + return max_id + 1 class Entry: @@ -118,6 +189,7 @@ class Entry: self.pcc = None self.thread = None self.changed = True # Initialize as changed to ensure first draw + self._parent_list = None # Reference to parent OutputEntryList @property def log_prefix(self): @@ -131,11 +203,18 @@ class Entry: self.changed = False def __setattr__(self, name, value): + super().__setattr__(name, value) # Mark as changed whenever any attribute is modified # (except for 'changed' itself to avoid recursion) if name != "changed" and hasattr(self, "changed"): self.changed = True - super().__setattr__(name, value) + # Save state if we have a parent list and this isn't an unpersisted attribute + if ( + hasattr(self, "_parent_list") + and self._parent_list + and name not in ["process", "log_file", "stop_event", "lock", "thread"] + ): + self._parent_list.save_state() def __getitem__(self, key): # Support dictionary-style access for backward compatibility @@ -170,6 +249,9 @@ class Entry: self.log_file = open(self.get_log_filename(), "w") return self.log_file + def set_parent_list(self, parent_list): + self._parent_list = parent_list + def main(stdscr): curses.curs_set(0) # Hide cursor @@ -190,7 +272,6 @@ def main(stdscr): output_entries = OutputEntryList() current_line = 0 # Index of the current line (input fields + output entries) - total_lines = len(input_fields) screen_lock = threading.Lock() screen_needs_update = threading.Event() # New event to signal screen updates @@ -322,12 +403,12 @@ def main(stdscr): command_name = get_command_name(command) entry = Entry(command_name, model, device, command) output_entries.append(entry) - # Update total_lines - total_lines = len(input_fields) + len(output_entries) + current_line = 0 screen_needs_update.set() else: # Otherwise if not the last field, move to next field + total_lines = len(input_fields) + len(output_entries) current_line = (current_line + 1) % total_lines screen_needs_update.set() else: @@ -335,13 +416,13 @@ def main(stdscr): entry_index = current_line - len(input_fields) if entry_index < len(output_entries): entry = output_entries[entry_index] - if entry["log_file"]: + if os.path.exists(entry.get_log_filename()): # Save current terminal state curses.def_prog_mode() # Exit curses temporarily curses.endwin() # Run less command - os.system(f"less -R {entry['log_file'].name}") + os.system(f"less -R {entry.get_log_filename()}") # Resume curses curses.reset_prog_mode() stdscr.refresh() @@ -369,11 +450,12 @@ def main(stdscr): elif entry["status"] != "Resetting": # Remove the entry if it's already cancelled output_entries.pop(entry_index) - total_lines -= 1 + total_lines = len(input_fields) + len(output_entries) if current_line >= total_lines: current_line = total_lines - 1 screen_needs_update.set() elif c == 9: # Tab key + total_lines = len(input_fields) + len(output_entries) current_line = (current_line + 1) % total_lines screen_needs_update.set() elif c == ord("r") and current_line >= len(input_fields): @@ -706,8 +788,6 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): previous_line = "" try: for line in iter(process.stdout.readline, ""): - if entry.stop_event.is_set(): - break # Write to log file log_file.write(line) log_file.flush() @@ -740,12 +820,14 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): with entry.lock: if process.returncode != 0: - exception_name = find_exception_in_log(entry.log_file.name) - entry.status = "Error" - if exception_name: - entry.output = exception_name - reset_device_async(entry, screen_lock, screen_needs_update) - screen_needs_update.set() + if entry.stop_event.is_set(): + entry.status = "Cancelled" + else: + exception_name = find_exception_in_log(entry.log_file.name) + entry.status = "Error" + if exception_name: + entry.output = exception_name + reset_device_async(entry, screen_lock, screen_needs_update) else: entry.status = "Finished" entry.process = None From e47daaf5be94bea7630da5676779deb65f937e82 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 12:03:05 +0000 Subject: [PATCH 5/6] #0: always reset on error code unless cancelled --- models/demos/llama3/lt | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 09e886247d2..83e6f9b4145 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -90,10 +90,8 @@ class OutputEntryList: ) entry.output = entry_data["output"] entry.log_id = entry_data["log_id"] - if "speed" in entry_data: - entry.speed = entry_data["speed"] - if "pcc" in entry_data: - entry.pcc = entry_data["pcc"] + entry.speed = entry_data["speed"] + entry.pcc = entry_data["pcc"] self._entries.append(entry) except (FileNotFoundError, json.JSONDecodeError): pass @@ -467,10 +465,8 @@ def main(stdscr): # Reset the entry to "Waiting" status entry["status"] = "Waiting" entry["output"] = "" - if "speed" in entry: - del entry["speed"] - if "pcc" in entry: - del entry["pcc"] + entry["speed"] = None + entry["pcc"] = None entry["process"] = None entry["log_file"] = None entry["stop_event"].clear() @@ -612,8 +608,8 @@ def draw_output_entry(stdscr, entry, y, is_selected, max_x): entry.model, entry.device, entry.status, - getattr(entry, "speed", ""), - getattr(entry, "pcc", ""), + entry.speed if entry.speed else "", + entry.pcc if entry.pcc else "", entry.output, ] col_widths = [20, 10, 10, 20, 10, 10, max_x - 85] # Adjusted widths to accommodate the PCC column @@ -803,8 +799,7 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): if speed is not None: entry.speed = f"{speed:.1f}" if pcc is not None: - current_pcc = getattr(entry, "pcc", None) - if current_pcc is None or float(pcc) < float(current_pcc): + if entry.pcc is None or float(pcc) < float(entry.pcc): entry.pcc = pcc screen_needs_update.set() @@ -827,7 +822,7 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): entry.status = "Error" if exception_name: entry.output = exception_name - reset_device_async(entry, screen_lock, screen_needs_update) + reset_device_async(entry, screen_lock, screen_needs_update) else: entry.status = "Finished" entry.process = None From c7f2446ee98a6821cd6a9c1d3aa593ac1b8d2296 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 8 Nov 2024 12:21:54 +0000 Subject: [PATCH 6/6] #0: Add X to clear all --- models/demos/llama3/lt | 54 +++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 83e6f9b4145..70beb4e6e98 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -439,19 +439,24 @@ def main(stdscr): entry_index = current_line - len(input_fields) if entry_index < len(output_entries): entry = output_entries[entry_index] - with entry["lock"]: - if entry["process"] and entry["process"].poll() is None: - # Cancel the running process - entry["stop_event"].set() - terminate_process_tree(entry["process"].pid) - entry["status"] = "Terminating" - elif entry["status"] != "Resetting": - # Remove the entry if it's already cancelled - output_entries.pop(entry_index) - total_lines = len(input_fields) + len(output_entries) - if current_line >= total_lines: - current_line = total_lines - 1 + if cancel_entry(entry): + output_entries.pop(entry_index) + total_lines = len(input_fields) + len(output_entries) + if current_line >= total_lines: + current_line = total_lines - 1 screen_needs_update.set() + elif c == ord("X") and current_line >= len(input_fields): # Shift-X to clear all entries + to_remove = [] + for entry in output_entries: + if cancel_entry(entry): + to_remove.append(entry) + for entry in to_remove: + entry_index = output_entries.index(entry) + output_entries.pop(entry_index) + screen_needs_update.set() + total_lines = len(input_fields) + len(output_entries) + if current_line >= total_lines: + current_line = total_lines - 1 elif c == 9: # Tab key total_lines = len(input_fields) + len(output_entries) current_line = (current_line + 1) % total_lines @@ -714,7 +719,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): # Define command shortcuts command_shortcuts = { "demo": "pytest models/demos/llama3/demo/demo.py -k instruct_weights-1", - "demo_1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer", + "demo-1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer", "attention": "pytest models/demos/llama3/tests/test_llama_attention.py", "attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py", "mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py", @@ -998,13 +1003,28 @@ def draw_help_bar(stdscr, current_line, num_input_fields, num_output_entries): def get_help_text(current_line, num_input_fields, num_output_entries): if current_line == 0: - return "Shortcuts: demo, attention, mlp, decoder, decoder-prefill, model, model-prefill, model-quick | Enter: Submit | ↑↓: Navigate fields | Esc: Exit" + return "Shortcuts: demo, demo-1layer, attention, mlp, rmsnorm, decoder, model, model-quick, 'help' for full list | Enter: Submit | ↑↓: Navigate fields | Esc: Exit" elif current_line <= num_input_fields - 1: return "Enter: Next field | ↑↓: Navigate fields | Esc: Exit" else: - return ( - "Enter: View log | Backspace/x: Cancel/remove entry | r: Restart entry | ↑↓: Navigate entries | Esc: Exit" - ) + return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | ↑↓: Navigate entries | Esc: Exit" + + +def cancel_entry(entry): + """Handle removal of a single entry, returning True if entry was removed""" + with entry["lock"]: + if entry["process"] and entry["process"].poll() is None: + # Cancel the running process + entry["stop_event"].set() + terminate_process_tree(entry["process"].pid) + entry["status"] = "Terminating" + # Entry is still running, so don't remove it + return False + elif entry["status"] != "Resetting": + # Safe to remove the entry if it's already cancelled + return True + # Entry is running/resetting, so don't remove it + return False if __name__ == "__main__":