[sharktank] Add perplexity test with latest values #839

Closed · wants to merge 8 commits
3 changes: 2 additions & 1 deletion .github/workflows/ci_eval.yaml
@@ -7,6 +7,7 @@
name: CI - sharktank perplexity

on:
+pull_request:
workflow_dispatch:
schedule:
# Weekdays at 11:00 AM UTC = 03:00 AM PST / 04:00 AM PDT
@@ -65,7 +66,7 @@ jobs:
- name: Run perplexity test with IREE
run: |
source ${VENV_DIR}/bin/activate
-pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
+pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=50 --iree-device='local-task' --iree-hal-target-device=llvm-cpu --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
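Note on this hunk: the workflow gains a pull_request trigger, and the nightly pytest run moves from an MI300X GPU (--iree-device=hip://0 with --iree-hip-target=gfx942) to CPU (--iree-device='local-task' with --iree-hal-target-device=llvm-cpu), with the prompt count halved from 100 to 50. The tests read these flags back via config.getoption(...), which requires them to be registered in a conftest.py. A minimal registration sketch, assuming option names matching the flags above (sharktank's actual conftest may differ):

def pytest_addoption(parser):
    # Hypothetical sketch: exposes the flags that the skipif markers
    # later query through config.getoption(...).
    parser.addoption("--iree-device", action="store", default=None,
                     help="IREE device URI, e.g. hip://0 or local-task")
    parser.addoption("--iree-hal-target-device", action="store", default=None,
                     help="IREE HAL target, e.g. hip or llvm-cpu")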
2 changes: 1 addition & 1 deletion .github/workflows/ci_eval_short.yaml
@@ -64,4 +64,4 @@ jobs:
- name: Run perplexity test with vmfb
run: |
source ${VENV_DIR}/bin/activate
-pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
+pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=4 --iree-device='local-task' --iree-hal-target-device=llvm-cpu --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
4 changes: 3 additions & 1 deletion sharktank/sharktank/utils/export_artifacts.py
@@ -215,7 +215,6 @@ def compile_to_vmfb(
compile_args = [
f"iree-compile",
f"{mlir_path}",
f"--iree-hip-target={self.iree_hip_target}",
f"-o={vmfb_path}",
]
if self.tensor_parallelism_size > 1:
@@ -228,6 +227,9 @@
f"--iree-hal-target-device={self.iree_hal_target_device}"
]
compile_args += iree_hal_target_devices
+
+if self.iree_hal_target_device == "hip":
+compile_args += [f"--iree-hip-target={self.iree_hip_target}"]
if hal_dump_path:
compile_args += [
f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"
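The net effect of these two hunks: --iree-hip-target is no longer passed unconditionally, but only when the HAL target device is "hip", so CPU (llvm-cpu) compilation no longer receives a GPU-only flag. A standalone sketch of the resulting flag assembly, with hypothetical paths for illustration:

def build_compile_args(mlir_path, vmfb_path, hal_target_device, hip_target="gfx942"):
    # Mirrors compile_to_vmfb after this change: base flags first,
    # the HIP target flag only when the HAL target device is HIP.
    args = [
        "iree-compile",
        mlir_path,
        f"-o={vmfb_path}",
        f"--iree-hal-target-device={hal_target_device}",
    ]
    if hal_target_device == "hip":
        args.append(f"--iree-hip-target={hip_target}")
    return args

build_compile_args("model.mlir", "model.vmfb", "llvm-cpu")  # no HIP flag appended
build_compile_args("model.mlir", "model.vmfb", "hip")       # appends --iree-hip-target=gfx942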
55 changes: 55 additions & 0 deletions sharktank/tests/evaluate/baseline_perplexity_scores.json
@@ -314,5 +314,60 @@
16.469248
],
"mean_perplexity": 14.991893
},
"llama3_8B_f16_decomposed_iree_fused_rotary": {
"perplexities": [
32.642723,
22.781422,
16.205664,
1752.897827,
96.782982,
13.282438,
833.510254,
233.861618,
8.571332,
53.638924,
5395.990234,
15.051053,
10.305574,
24.451237,
1329.104614,
10.311046,
7.12526,
80.784698,
11.205793,
21.927229,
38.855869,
804.015503,
16.602699,
63.275284,
3480.645508,
19.889006,
20.02833,
13.375215,
406.924988,
447.162445,
9.940166,
102.297752,
58.567444,
10.406951,
3407.479004,
1408.818726,
9.500906,
589.322998,
9.515339,
8573.863281,
5842.658691,
290.635254,
3790.750244,
3.99186,
22.958128,
10.617723,
90238.078125,
1032.760864,
1493.486206,
35.754234
],
"mean_perplexity": 2644.452213
}
}
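The new baseline entry stores per-prompt perplexities for 50 prompts plus their overall mean. The tests do not compare against mean_perplexity directly; they recompute the mean over only the first --bs prompts (see the new test below). A sketch of that arithmetic, assuming the repository-relative JSON path from the file header above:

import json

import numpy as np

with open("sharktank/tests/evaluate/baseline_perplexity_scores.json") as f:
    baseline = json.load(f)

entry = baseline["llama3_8B_f16_decomposed_iree_fused_rotary"]
batch_size = 4  # matches --bs=4 in ci_eval_short.yaml
# Mean over the first batch_size prompts only, rounded to 6 places as in the test.
subset_mean = round(float(np.mean(entry["perplexities"][:batch_size])), 6)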
44 changes: 43 additions & 1 deletion sharktank/tests/evaluate/perplexity_iree_test.py
@@ -12,6 +12,7 @@
from sharktank.evaluate import perplexity_iree

is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
+is_cpu = pytest.mark.skipif("config.getoption('iree_device') != 'local-task'")
skipif_run_quick_llama_test = pytest.mark.skipif(
'not config.getoption("run-nightly-llama-tests")',
reason="Run large tests if --run-nightly-llama-tests is passed",
@@ -25,7 +26,6 @@
"baseline_perplexity_scores",
"batch_size",
)
-@is_mi300x
class PerplexityTest(unittest.TestCase):
def setUp(self):
self.current_perplexity_all = {}
@@ -34,6 +34,41 @@ def setUp(self):
with open(self.baseline_perplexity_scores, "r") as f:
self.baseline_perplexity = json.load(f)

+@is_cpu
+def test_llama3_8B_f16_decomposed_fused_rotary(self):
+
+# Llama 3.1 8B decomposed
+
+model_name = "llama3_8B_f16_decomposed_iree_fused_rotary"
+baseline_perplexity = self.baseline_perplexity[model_name]
+
+current_perplexity = perplexity_iree.main(
+[
+f"--irpa-file={self.llama3_8b_f16_model}",
+f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+f"--iree-device={self.iree_device}",
+f"--iree-hal-target-device={self.iree_hal_target_device}",
+f"--tensor-parallelism-size=1",
+f"--attention-kernel=decomposed",
+f"--num-prompts={self.batch_size}",
+]
+)
+
+baseline_mean_perplexity = round(
+np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+)
+current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+self.assertAlmostEqual(
+baseline_mean_perplexity,
+current_mean_perplexity,
+delta=self.delta,
+msg=f"Current perplexity deviates baseline by {perplexity_difference}",
+)
+
+@is_mi300x
@pytest.mark.xfail(reason="Runtime segfault", run=False)
def test_llama3_8B_f16_decomposed(self):

@@ -69,6 +104,7 @@ def test_llama3_8B_f16_decomposed(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_f16(self):
@@ -105,6 +141,7 @@ def test_llama3_8B_f16(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_fp8_decomposed(self):
@@ -141,6 +178,7 @@ def test_llama3_8B_fp8_decomposed(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_8B_fp8(self):
@@ -177,6 +215,7 @@ def test_llama3_8B_fp8(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(
reason="Sharding is unsupported",
@@ -215,6 +254,7 @@ def test_llama3_405B_f16_decomposed(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_f16(self):
@@ -251,6 +291,7 @@ def test_llama3_405B_f16(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_fp8_decomposed(self):
@@ -287,6 +328,7 @@ def test_llama3_405B_fp8_decomposed(self):
msg=f"Current perplexity deviates baseline by {perplexity_difference}",
)

+@is_mi300x
@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error")
def test_llama3_405B_fp8(self):
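Summary of the test changes: the class-level @is_mi300x gate moves onto each GPU test, and the new test_llama3_8B_f16_decomposed_fused_rotary runs under @is_cpu instead, so CPU CI can exercise it while the GPU tests skip. Both markers rely on pytest's string-form skipif, which evaluates the condition string with config in scope at collection time. A self-contained sketch of the pattern (assuming the options are registered in conftest.py as sketched earlier):

import pytest

# String conditions are evaluated by pytest with `config` available,
# so the skip decision tracks the command-line flags.
is_cpu = pytest.mark.skipif("config.getoption('iree_device') != 'local-task'")
is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")

@is_cpu
def test_runs_only_with_local_task():
    assert True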