diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index c69b4113a84..70beb4e6e98 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -14,6 +14,7 @@ import re import time import signal import psutil +import json def ensure_less_installed(): @@ -52,6 +53,204 @@ def ensure_ttsmi_installed(): sys.exit(1) +class OutputEntryList: + def __init__(self): + self._entries = [] + # Create logs directory + os.makedirs("logs", exist_ok=True) + # Load existing state + self._load_state() + + def _load_state(self): + try: + with open("logs/state.json", "r") as f: + state = json.load(f) + for entry_data in state: + entry = Entry( + entry_data["command_name"], + entry_data["model"], + entry_data["device"], + entry_data["command_input"], + ) + # Restore saved attributes + entry.status = ( + "Cancelled" + if entry_data["status"] + in [ + "Running", + "Resetting", + "Initializing device", + "Starting", + "Prefill", + "Decode", + "Terminating", + "Exiting", + ] + else entry_data["status"] + ) + entry.output = entry_data["output"] + entry.log_id = entry_data["log_id"] + entry.speed = entry_data["speed"] + entry.pcc = entry_data["pcc"] + self._entries.append(entry) + except (FileNotFoundError, json.JSONDecodeError): + pass + + def save_state(self): + state = [] + for entry in self._entries: + entry_data = { + "command_name": entry.command_name, + "model": entry.model, + "device": entry.device, + "command_input": entry.command_input, + "status": entry.status, + "output": entry.output, + "log_id": entry.log_id, + } + if hasattr(entry, "speed"): + entry_data["speed"] = entry.speed + if hasattr(entry, "pcc"): + entry_data["pcc"] = entry.pcc + state.append(entry_data) + + with open("logs/state.json", "w") as f: + json.dump(state, f, indent=2) + + def __len__(self): + return len(self._entries) + + def __getitem__(self, index): + return self._entries[index] + + def __iter__(self): + return iter(self._entries) + + def append(self, entry): + if entry.log_id is None: + entry.log_id = self.next_log_id() + + # Remove any existing logs with same ID + for existing_file in os.listdir("logs"): + if existing_file.startswith(entry.log_prefix): + os.remove(os.path.join("logs", existing_file)) + + # Set parent list reference before appending + entry.set_parent_list(self) + self._entries.append(entry) + self.save_state() + + def pop(self, index): + result = self._entries.pop(index) + + try: + os.remove(result.get_log_filename()) + except (OSError, FileNotFoundError): + pass + + # Mark all subsequent entries as changed + for entry in self._entries[index:]: + entry.mark_changed() + + self.save_state() + return result + + def index(self, entry): + return self._entries.index(entry) + + def get_entries(self): + return self._entries + + def next_log_id(self): + # Fix potential issue with list comprehension on dictionary access + max_id = 0 + for entry in self._entries: + if entry.log_id > max_id: + max_id = entry.log_id + return max_id + 1 + + +class Entry: + def __init__(self, command_name, model, device, command_input): + self.command_name = command_name + self.model = model + self.device = device.upper() + self.command_input = command_input + self.status = "Waiting" + self.output = "" + self.process = None + self.log_file = None + self.stop_event = threading.Event() + self.lock = threading.Lock() + self.log_id = None # Will be set by OutputEntryList + self.speed = None + self.pcc = None + self.thread = None + self.changed = True # Initialize as changed to ensure first draw + self._parent_list = None # Reference to parent OutputEntryList + + @property + def log_prefix(self): + """Generate the log file prefix based on the entry's log ID""" + return f"{self.log_id:04d}-" + + def mark_changed(self): + self.changed = True + + def mark_drawn(self): + self.changed = False + + def __setattr__(self, name, value): + super().__setattr__(name, value) + # Mark as changed whenever any attribute is modified + # (except for 'changed' itself to avoid recursion) + if name != "changed" and hasattr(self, "changed"): + self.changed = True + # Save state if we have a parent list and this isn't an unpersisted attribute + if ( + hasattr(self, "_parent_list") + and self._parent_list + and name not in ["process", "log_file", "stop_event", "lock", "thread"] + ): + self._parent_list.save_state() + + def __getitem__(self, key): + # Support dictionary-style access for backward compatibility + return getattr(self, key) + + def __setitem__(self, key, value): + # Support dictionary-style assignment for backward compatibility + setattr(self, key, value) + + def get(self, key, default=None): + # Support dictionary-style get() for backward compatibility + return getattr(self, key, default) + + def get_log_filename(self): + """Generate log filename based on entry properties""" + command_name = self._get_command_name() + filename = f"{self.log_prefix}{self.device}-{self.model}-{command_name}.log" + return os.path.join("logs", filename.replace("/", "_")) + + def _get_command_name(self): + """Extract command name from command input""" + if "pytest" in self.command_input: + match = re.search(r"pytest\s+([\S]+)", self.command_input) + if match: + test_file = match.group(1) + return os.path.basename(test_file).split(".")[0] + return "pytest" + return os.path.basename(shlex.split(self.command_input)[0]) + + def open_log_file(self): + """Open and return log file for writing""" + self.log_file = open(self.get_log_filename(), "w") + return self.log_file + + def set_parent_list(self, parent_list): + self._parent_list = parent_list + + def main(stdscr): curses.curs_set(0) # Hide cursor curses.start_color() @@ -69,9 +268,8 @@ def main(stdscr): {"label": "Device (n150, n300, t3k) [all]", "value": "", "x": 0, "y": 2}, ] - output_entries = [] + output_entries = OutputEntryList() current_line = 0 # Index of the current line (input fields + output entries) - total_lines = len(input_fields) screen_lock = threading.Lock() screen_needs_update = threading.Event() # New event to signal screen updates @@ -201,26 +399,14 @@ def main(stdscr): # Create output entries for command, model, device in combinations: command_name = get_command_name(command) - entry = { - "command_name": command_name, - "model": model, - "device": device.upper(), - "status": "Waiting", - "output": "", - "process": None, - "log_file": None, - "index": len(output_entries), - "stop_event": threading.Event(), - "lock": threading.Lock(), - "command_input": command, # Save the individual command - } + entry = Entry(command_name, model, device, command) output_entries.append(entry) - # Update total_lines - total_lines = len(input_fields) + len(output_entries) + current_line = 0 screen_needs_update.set() else: # Otherwise if not the last field, move to next field + total_lines = len(input_fields) + len(output_entries) current_line = (current_line + 1) % total_lines screen_needs_update.set() else: @@ -228,13 +414,13 @@ def main(stdscr): entry_index = current_line - len(input_fields) if entry_index < len(output_entries): entry = output_entries[entry_index] - if entry["log_file"]: + if os.path.exists(entry.get_log_filename()): # Save current terminal state curses.def_prog_mode() # Exit curses temporarily curses.endwin() # Run less command - os.system(f"less -R {entry['log_file'].name}") + os.system(f"less -R {entry.get_log_filename()}") # Resume curses curses.reset_prog_mode() stdscr.refresh() @@ -253,20 +439,26 @@ def main(stdscr): entry_index = current_line - len(input_fields) if entry_index < len(output_entries): entry = output_entries[entry_index] - with entry["lock"]: - if entry["process"] and entry["process"].poll() is None: - # Cancel the running process - entry["stop_event"].set() - terminate_process_tree(entry["process"].pid) - entry["status"] = "Terminating" - elif entry["status"] != "Resetting": - # Remove the entry if it's already cancelled - output_entries.pop(entry_index) - total_lines -= 1 - if current_line >= total_lines: - current_line = total_lines - 1 + if cancel_entry(entry): + output_entries.pop(entry_index) + total_lines = len(input_fields) + len(output_entries) + if current_line >= total_lines: + current_line = total_lines - 1 screen_needs_update.set() + elif c == ord("X") and current_line >= len(input_fields): # Shift-X to clear all entries + to_remove = [] + for entry in output_entries: + if cancel_entry(entry): + to_remove.append(entry) + for entry in to_remove: + entry_index = output_entries.index(entry) + output_entries.pop(entry_index) + screen_needs_update.set() + total_lines = len(input_fields) + len(output_entries) + if current_line >= total_lines: + current_line = total_lines - 1 elif c == 9: # Tab key + total_lines = len(input_fields) + len(output_entries) current_line = (current_line + 1) % total_lines screen_needs_update.set() elif c == ord("r") and current_line >= len(input_fields): @@ -278,10 +470,8 @@ def main(stdscr): # Reset the entry to "Waiting" status entry["status"] = "Waiting" entry["output"] = "" - if "speed" in entry: - del entry["speed"] - if "pcc" in entry: - del entry["pcc"] + entry["speed"] = None + entry["pcc"] = None entry["process"] = None entry["log_file"] = None entry["stop_event"].clear() @@ -385,28 +575,24 @@ def draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_ output_start_y = header_y + 1 for idx, entry in enumerate(output_entries): y = output_start_y + idx - if y >= max_y - 3: # Change this line from max_y - 2 to max_y - 3 + if y >= max_y - 3: break - if ( - idx >= len(last_drawn_state["output_entries"]) - or entry != last_drawn_state["output_entries"][idx] - or current_line != last_drawn_state["current_line"] - ): + + # Only draw if entry has changed or selection state changed + if entry.changed or current_line != last_drawn_state["current_line"]: draw_output_entry(stdscr, entry, y, current_line == len(input_fields) + idx, max_x) + entry.mark_drawn() # Mark as drawn after updating # Clear any extra lines if output entries were removed for y in range( output_start_y + len(output_entries), - min( - output_start_y + len(last_drawn_state["output_entries"]), max_y - 3 - ), # Change this line from max_y - 2 to max_y - 3 + min(output_start_y + len(last_drawn_state["output_entries"]), max_y - 3), ): stdscr.move(y, 0) stdscr.clrtoeol() - # Update last_drawn_state - last_drawn_state["output_entries"] = [entry.copy() for entry in output_entries] last_drawn_state["current_line"] = current_line + last_drawn_state["output_entries"] = [{"log_id": entry.log_id} for entry in output_entries] def draw_input_field(stdscr, field, is_selected, max_x): @@ -423,13 +609,13 @@ def draw_input_field(stdscr, field, is_selected, max_x): def draw_output_entry(stdscr, entry, y, is_selected, max_x): cols = [ - entry["command_name"], - entry["model"], - entry["device"], - entry["status"], - entry.get("speed", ""), - entry.get("pcc", ""), - entry["output"], + entry.command_name, + entry.model, + entry.device, + entry.status, + entry.speed if entry.speed else "", + entry.pcc if entry.pcc else "", + entry.output, ] col_widths = [20, 10, 10, 20, 10, 10, max_x - 85] # Adjusted widths to accommodate the PCC column @@ -441,7 +627,7 @@ def draw_output_entry(stdscr, entry, y, is_selected, max_x): else: color = curses.color_pair(0) if i == 3: # Status column - status = entry["status"] + status = entry.status if status == "Waiting" or status == "Cancelled": color = COLOR_PAIR_WAITING elif status in ["Running", "Initializing device", "Prefill", "Decode", "Starting"] or status.startswith( @@ -527,10 +713,13 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): env["FAKE_DEVICE"] = entry["device"] env["LLAMA_DIR"] = get_llama_dir(entry["model"]) + # Open log file + entry.open_log_file() + # Define command shortcuts command_shortcuts = { "demo": "pytest models/demos/llama3/demo/demo.py -k instruct_weights-1", - "demo_1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer", + "demo-1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer", "attention": "pytest models/demos/llama3/tests/test_llama_attention.py", "attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py", "mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py", @@ -575,11 +764,6 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): # Prepare the command cmd_list = shlex.split(command_input) - # Open log file - log_filename = get_log_filename(entry["device"], entry["model"], command_input) - os.makedirs("logs", exist_ok=True) - entry["log_file"] = open(os.path.join("logs", log_filename), "w") - # If the command is invalid, write the output to the log file and return before trying to run the bad command if entry["status"] == "Error": entry["log_file"].write(entry["output"] + "\n") @@ -600,31 +784,28 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): def process_output(entry, screen_lock, output_entries, screen_needs_update): - process = entry["process"] - log_file = entry["log_file"] + process = entry.process + log_file = entry.log_file previous_line = "" try: for line in iter(process.stdout.readline, ""): - if entry["stop_event"].is_set(): - break # Write to log file log_file.write(line) log_file.flush() # Update status and output based on output - status, output, speed, pcc = parse_output_line(line, previous_line, entry["status"]) + status, output, speed, pcc = parse_output_line(line, previous_line, entry.status) previous_line = line.strip() - with entry["lock"]: - if status != entry["status"] or output or speed is not None or pcc is not None: - entry["status"] = status + with entry.lock: + if status != entry.status or output or speed is not None or pcc is not None: + entry.status = status # This will mark entry as changed via __setattr__ if output: - entry["output"] = output + entry.output = output if speed is not None: - entry["speed"] = f"{speed:.1f}" + entry.speed = f"{speed:.1f}" if pcc is not None: - current_pcc = entry.get("pcc") - if current_pcc is None or float(pcc) < float(current_pcc): - entry["pcc"] = pcc + if entry.pcc is None or float(pcc) < float(entry.pcc): + entry.pcc = pcc screen_needs_update.set() with screen_lock: @@ -637,17 +818,19 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update): # Wait for the process to fully terminate process.wait() - with entry["lock"]: + with entry.lock: if process.returncode != 0: - exception_name = find_exception_in_log(entry["log_file"].name) - entry["status"] = "Error" - if exception_name: - entry["output"] = exception_name + if entry.stop_event.is_set(): + entry.status = "Cancelled" + else: + exception_name = find_exception_in_log(entry.log_file.name) + entry.status = "Error" + if exception_name: + entry.output = exception_name reset_device_async(entry, screen_lock, screen_needs_update) - screen_needs_update.set() else: - entry["status"] = "Finished" - entry["process"] = None + entry.status = "Finished" + entry.process = None log_file.close() screen_needs_update.set() # Ensure screen is updated after process termination @@ -746,13 +929,6 @@ def get_command_name(command_input): return command_name -def get_log_filename(device, model, command_input): - command_name = get_command_name(command_input) - filename = f"{device}-{model}-{command_name}.log" - filename = filename.replace("/", "_") - return filename - - def find_exception_in_log(log_filename): exception_name = None with open(log_filename, "r") as f: @@ -807,12 +983,12 @@ def reset_device_async(entry, screen_lock, screen_needs_update): except subprocess.CalledProcessError as e: pass finally: - with entry["lock"]: - entry["status"] = previous_status + with entry.lock: + entry.status = previous_status screen_needs_update.set() - previous_status = entry["status"] - entry["status"] = "Resetting" + previous_status = entry.status + entry.status = "Resetting" reset_thread = threading.Thread(target=reset_thread) reset_thread.daemon = True reset_thread.start() @@ -827,13 +1003,28 @@ def draw_help_bar(stdscr, current_line, num_input_fields, num_output_entries): def get_help_text(current_line, num_input_fields, num_output_entries): if current_line == 0: - return "Shortcuts: demo, attention, mlp, decoder, decoder-prefill, model, model-prefill, model-quick | Enter: Submit | ↑↓: Navigate fields | Esc: Exit" + return "Shortcuts: demo, demo-1layer, attention, mlp, rmsnorm, decoder, model, model-quick, 'help' for full list | Enter: Submit | ↑↓: Navigate fields | Esc: Exit" elif current_line <= num_input_fields - 1: return "Enter: Next field | ↑↓: Navigate fields | Esc: Exit" else: - return ( - "Enter: View log | Backspace/x: Cancel/remove entry | r: Restart entry | ↑↓: Navigate entries | Esc: Exit" - ) + return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | ↑↓: Navigate entries | Esc: Exit" + + +def cancel_entry(entry): + """Handle removal of a single entry, returning True if entry was removed""" + with entry["lock"]: + if entry["process"] and entry["process"].poll() is None: + # Cancel the running process + entry["stop_event"].set() + terminate_process_tree(entry["process"].pid) + entry["status"] = "Terminating" + # Entry is still running, so don't remove it + return False + elif entry["status"] != "Resetting": + # Safe to remove the entry if it's already cancelled + return True + # Entry is running/resetting, so don't remove it + return False if __name__ == "__main__":