Skip to content

Commit

Permalink
#0: 4-bit MLP and accuracy/perf comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 28, 2024
1 parent 4584bc3 commit 2ce545a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 1 deletion.
35 changes: 35 additions & 0 deletions models/demos/llama3/PERF.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Llama 3 model performance and accuracy

Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `accuracy,demo`) and pressing `m` whilst in the results section to export to markdown.

4-bit MLP:

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
| 1b | N150 | 79 | 98 | 90.5 |
| 1b | N300 | 81 | 98 | 101.7 |
| 1b | T3K | 81 | 98 | 97.5 |
| 3b | N150 | 85 | 96 | 49.0 |
| 3b | N300 | 88 | 97 | 56.9 |
| 3b | T3K | 88 | 97 | 54.5 |
| 8b | N150 | 86 | 98 | 28.4 |
| 8b | N300 | 84 | 98 | 38.6 |
| 8b | T3K | 84 | 98 | 52.6 |
| 11b | N300 | 86 | 97 | 38.6 |
| 70b | T3K | 95 | 100 | 14.3 |

Mixed-bit MLP (main):

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
| 1b | N150 | 77 | 96 | 85.8 |
| 1b | N300 | 80 | 98 | 98.6 |
| 1b | T3K | 78 | 98 | 97.2 |
| 3b | N150 | 88 | 98 | 44.1 |
| 3b | N300 | 88 | 98 | 53.9 |
| 3b | T3K | 88 | 98 | 54.8 |
| 8b | N150 | 89 | 98 | 23.5 |
| 8b | N300 | 90 | 98 | 34.1 |
| 8b | T3K | 88 | 97 | 49.9 |
| 11b | N300 | 90 | 97 | 33.8 |
| 70b | T3K | 95 | 100 | 14.5 |
4 changes: 4 additions & 0 deletions models/demos/llama3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,7 @@ pytest models/demos/llama3/demo/demo.py -k 'instruct and 1_batch'
# Run 2 continuous batches with general weights
pytest models/demos/llama3/demo/demo.py -k 'general and 2_batch'
```

### Expected performance and accuracy

See [PERF.md](PERF.md) for expected performance and accuracy.
47 changes: 47 additions & 0 deletions models/demos/llama3/lt
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@ def main(stdscr):
entry["log_file"] = None
entry["stop_event"].clear()
screen_needs_update.set()
elif c == ord("m") and current_line >= len(input_fields):
# Export results to markdown
export_results_to_markdown(output_entries, stdscr)
screen_needs_update.set()
else:
if current_line < len(input_fields) and not exiting:
current_field = current_line
Expand Down Expand Up @@ -1033,6 +1037,49 @@ def cancel_entry(entry):
return False


def export_results_to_markdown(output_entries, stdscr):
    """Export finished demo/accuracy results as a markdown table to PERF.md.

    Collects speed (t/s/u) from finished "demo" entries and Top-1/Top-5
    accuracy from finished "accuracy" entries, joins them on (model, device),
    writes the table to PERF.md in the current working directory, and shows a
    preview on screen until a key is pressed.

    Args:
        output_entries: iterable of run entries exposing .command_name,
            .status, .model, .device, .speed and .output attributes.
        stdscr: curses window used for the preview; this call blocks on
            stdscr.getch() until the user presses a key.
    """
    demo_results = {}
    accuracy_results = {}

    # Collect results from finished entries only.
    for entry in output_entries:
        if entry.status != "Finished":
            continue
        if entry.command_name == "demo":
            demo_results[(entry.model, entry.device)] = entry.speed
        elif entry.command_name == "accuracy":
            # Parse Top-1 and Top-5 percentages from the captured test output.
            top1 = top5 = "N/A"
            if entry.output:
                match = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", entry.output)
                if match:
                    top1, top5 = match.groups()
            accuracy_results[(entry.model, entry.device)] = (top1, top5)

    # Join on (model, device). Demo rows keep their insertion order; rows that
    # only have accuracy results are appended afterwards (previously these
    # were dropped because only demo_results was iterated).
    all_keys = list(demo_results)
    all_keys += [key for key in accuracy_results if key not in demo_results]

    markdown_lines = [
        "| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |",
        "|-------|--------|-----------|-----------|---------------|",
    ]
    for model, device in all_keys:
        speed = demo_results.get((model, device), "N/A")
        top1, top5 = accuracy_results.get((model, device), ("N/A", "N/A"))
        markdown_lines.append(f"| {model} | {device} | {top1} | {top5} | {speed} |")

    # Write to PERF.md
    with open("PERF.md", "w") as f:
        f.write("\n".join(markdown_lines) + "\n")

    # Clear screen, preview the table, and wait for a key press to return.
    stdscr.clear()
    stdscr.addstr(0, 0, "\n".join(markdown_lines))
    stdscr.addstr(len(markdown_lines) + 1, 0, f"Table written to {os.path.abspath('PERF.md')}")
    stdscr.addstr(len(markdown_lines) + 2, 0, "Press any key to return...")
    stdscr.refresh()
    stdscr.getch()  # Wait for a key press


if __name__ == "__main__":
os.environ["TERM"] = "xterm-256color"

Expand Down
3 changes: 2 additions & 1 deletion models/demos/llama3/tt/llama_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(
cache_file_name=cache_name(name),
)

self.four_bit_mlp = self.args.is_large_model
# Set to "self.args.is_large_model" for mixed-mode MLP which is slightly more accurate
self.four_bit_mlp = True

# Sharded weights
self.w1 = as_sharded_tensor(
Expand Down

0 comments on commit 2ce545a

Please sign in to comment.