diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
new file mode 100644
index 000000000000..25835100b2b4
--- /dev/null
+++ b/models/demos/llama3/PERF.md
@@ -0,0 +1,35 @@
+# Llama 3 model performance and accuracy
+
+Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `accuracy,demo`) and pressing `m` whilst in the results section to export to markdown.
+
+4-bit MLP:
+
+| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
+|-------|--------|-----------|-----------|---------------|
+| 1b    | N150   | 79        | 98        | 90.5          |
+| 1b    | N300   | 81        | 98        | 101.7         |
+| 1b    | T3K    | 81        | 98        | 97.5          |
+| 3b    | N150   | 85        | 96        | 49.0          |
+| 3b    | N300   | 88        | 97        | 56.9          |
+| 3b    | T3K    | 88        | 97        | 54.5          |
+| 8b    | N150   | 86        | 98        | 28.4          |
+| 8b    | N300   | 84        | 98        | 38.6          |
+| 8b    | T3K    | 84        | 98        | 52.6          |
+| 11b   | N300   | 86        | 97        | 38.6          |
+| 70b   | T3K    | 95        | 100       | 14.3          |
+
+Mixed-bit MLP (main):
+
+| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
+|-------|--------|-----------|-----------|---------------|
+| 1b    | N150   | 77        | 96        | 85.8          |
+| 1b    | N300   | 80        | 98        | 98.6          |
+| 1b    | T3K    | 78        | 98        | 97.2          |
+| 3b    | N150   | 88        | 98        | 44.1          |
+| 3b    | N300   | 88        | 98        | 53.9          |
+| 3b    | T3K    | 88        | 98        | 54.8          |
+| 8b    | N150   | 89        | 98        | 23.5          |
+| 8b    | N300   | 90        | 98        | 34.1          |
+| 8b    | T3K    | 88        | 97        | 49.9          |
+| 11b   | N300   | 90        | 97        | 33.8          |
+| 70b   | T3K    | 95        | 100       | 14.5          |
diff --git a/models/demos/llama3/README.md b/models/demos/llama3/README.md
index 19826aa77c1d..43ad9556ab1a 100644
--- a/models/demos/llama3/README.md
+++ b/models/demos/llama3/README.md
@@ -84,3 +84,7 @@ pytest models/demos/llama3/demo/demo.py -k 'instruct and 1_batch'
 # Run 2 continuous batches with general weights
 pytest models/demos/llama3/demo/demo.py -k 'general and 2_batch'
 ```
+
+### Expected performance and accuracy
+
+See [PERF.md](PERF.md) for expected performance and accuracy.
diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index c75817ea3c5d..75e04befb9ce 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -477,6 +477,10 @@ def main(stdscr):
             entry["log_file"] = None
             entry["stop_event"].clear()
             screen_needs_update.set()
+        elif c == ord("m") and current_line >= len(input_fields):
+            # Export results to markdown
+            export_results_to_markdown(output_entries, stdscr)
+            screen_needs_update.set()
         else:
             if current_line < len(input_fields) and not exiting:
                 current_field = current_line
@@ -1033,6 +1037,49 @@ def cancel_entry(entry):
     return False
 
 
+def export_results_to_markdown(output_entries, stdscr):
+    demo_results = {}
+    accuracy_results = {}
+
+    # Collect results from entries
+    for entry in output_entries:
+        if entry.command_name == "demo" and entry.status == "Finished":
+            demo_results[(entry.model, entry.device)] = entry.speed
+        elif entry.command_name == "accuracy" and entry.status == "Finished":
+            # Parse Top-1 and Top-5 from output
+            top1 = "N/A"
+            top5 = "N/A"
+            if entry.output:
+                match = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", entry.output)
+                if match:
+                    top1, top5 = match.groups()
+            accuracy_results[(entry.model, entry.device)] = (top1, top5)
+
+    # Create markdown table
+    markdown_lines = [
+        "| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |",
+        "|-------|--------|-----------|-----------|---------------|",
+    ]
+
+    for key in demo_results.keys():
+        model, device = key
+        speed = demo_results.get(key, "N/A")
+        top1, top5 = accuracy_results.get(key, ("N/A", "N/A"))
+        markdown_lines.append(f"| {model} | {device} | {top1} | {top5} | {speed} |")
+
+    # Write to PERF.md
+    with open("PERF.md", "w") as f:
+        f.write("\n".join(markdown_lines) + "\n")
+
+    # Clear screen and show message
+    stdscr.clear()
+    stdscr.addstr(0, 0, "\n".join(markdown_lines))
+    stdscr.addstr(len(markdown_lines) + 1, 0, f"Table written to {os.path.abspath('PERF.md')}")
+    stdscr.addstr(len(markdown_lines) + 2, 0, "Press any key to return...")
+    stdscr.refresh()
+    stdscr.getch()  # Wait for a key press
+
+
 if __name__ == "__main__":
     os.environ["TERM"] = "xterm-256color"
diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index 1d18953b5d4c..a119825cff87 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -39,7 +39,8 @@ def __init__(
             cache_file_name=cache_name(name),
         )
 
-        self.four_bit_mlp = self.args.is_large_model
+        # Set to "self.args.is_large_model" for mixed-bit MLP, which is slightly more accurate
+        self.four_bit_mlp = True
 
         # Sharded weights
         self.w1 = as_sharded_tensor(
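
For reference, a minimal standalone sketch of the parsing step the exporter depends on: `export_results_to_markdown` only fills in the accuracy columns when the test output contains a summary in exactly the `Top-1: N% | Top-5: N%` form matched by the regex above. The sample line below is hypothetical; the real wording comes from `tests/test_llama_accuracy.py`.

```python
import re

# Hypothetical accuracy-test output line; the exporter's regex accepts only
# integer percentages in exactly this "Top-1: N% | Top-5: N%" shape.
sample_output = "Top-1: 86% | Top-5: 98%"

match = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", sample_output)
top1, top5 = match.groups() if match else ("N/A", "N/A")
print(top1, top5)  # prints: 86 98
```

Any output that does not match (e.g. a fractional percentage) falls back to `N/A` in the exported table.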