diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index dd67a7735a647..e310941afacf3 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -24,14 +24,14 @@
 from .conftest import run_greedy_equality_correctness_test
 
 # main model
-MAIN_MODEL = "ibm-granite/granite-3b-code-instruct"
+MAIN_MODEL = "JackFram/llama-160m"
 
 # speculative model
-SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
+SPEC_MODEL = "ibm-fms/llama-160m-accelerator"
 
 # max. number of speculative tokens: this corresponds to
 # n_predict in the config.json of the speculator model.
-MAX_SPEC_TOKENS = 5
+MAX_SPEC_TOKENS = 3
 
 # precision
 PRECISION = "float32"