Skip to content

Commit

Permalink
#0: Add X to clear all
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 8, 2024
1 parent e47daaf commit c7f2446
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions models/demos/llama3/lt
Original file line number Diff line number Diff line change
Expand Up @@ -439,19 +439,24 @@ def main(stdscr):
entry_index = current_line - len(input_fields)
if entry_index < len(output_entries):
entry = output_entries[entry_index]
with entry["lock"]:
if entry["process"] and entry["process"].poll() is None:
# Cancel the running process
entry["stop_event"].set()
terminate_process_tree(entry["process"].pid)
entry["status"] = "Terminating"
elif entry["status"] != "Resetting":
# Remove the entry if it's already cancelled
output_entries.pop(entry_index)
total_lines = len(input_fields) + len(output_entries)
if current_line >= total_lines:
current_line = total_lines - 1
if cancel_entry(entry):
output_entries.pop(entry_index)
total_lines = len(input_fields) + len(output_entries)
if current_line >= total_lines:
current_line = total_lines - 1
screen_needs_update.set()
elif c == ord("X") and current_line >= len(input_fields): # Shift-X to clear all entries
to_remove = []
for entry in output_entries:
if cancel_entry(entry):
to_remove.append(entry)
for entry in to_remove:
entry_index = output_entries.index(entry)
output_entries.pop(entry_index)
screen_needs_update.set()
total_lines = len(input_fields) + len(output_entries)
if current_line >= total_lines:
current_line = total_lines - 1
elif c == 9: # Tab key
total_lines = len(input_fields) + len(output_entries)
current_line = (current_line + 1) % total_lines
Expand Down Expand Up @@ -714,7 +719,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
# Define command shortcuts
command_shortcuts = {
"demo": "pytest models/demos/llama3/demo/demo.py -k instruct_weights-1",
"demo_1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer",
"demo-1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer",
"attention": "pytest models/demos/llama3/tests/test_llama_attention.py",
"attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py",
"mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py",
Expand Down Expand Up @@ -998,13 +1003,28 @@ def draw_help_bar(stdscr, current_line, num_input_fields, num_output_entries):

def get_help_text(current_line, num_input_fields, num_output_entries):
if current_line == 0:
return "Shortcuts: demo, attention, mlp, decoder, decoder-prefill, model, model-prefill, model-quick | Enter: Submit | ↑↓: Navigate fields | Esc: Exit"
return "Shortcuts: demo, demo-1layer, attention, mlp, rmsnorm, decoder, model, model-quick, 'help' for full list | Enter: Submit | ↑↓: Navigate fields | Esc: Exit"
elif current_line <= num_input_fields - 1:
return "Enter: Next field | ↑↓: Navigate fields | Esc: Exit"
else:
return (
"Enter: View log | Backspace/x: Cancel/remove entry | r: Restart entry | ↑↓: Navigate entries | Esc: Exit"
)
return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | ↑↓: Navigate entries | Esc: Exit"


def cancel_entry(entry):
"""Handle removal of a single entry, returning True if entry was removed"""
with entry["lock"]:
if entry["process"] and entry["process"].poll() is None:
# Cancel the running process
entry["stop_event"].set()
terminate_process_tree(entry["process"].pid)
entry["status"] = "Terminating"
# Entry is still running, so don't remove it
return False
elif entry["status"] != "Resetting":
# Safe to remove the entry if it's already cancelled
return True
# Entry is running/resetting, so don't remove it
return False


if __name__ == "__main__":
Expand Down

0 comments on commit c7f2446

Please sign in to comment.