diff --git a/forge/test/mlir/llama/test_llama_inference.py b/forge/test/mlir/llama/test_llama_inference.py
index f70fc6bf3..7a75b141a 100644
--- a/forge/test/mlir/llama/test_llama_inference.py
+++ b/forge/test/mlir/llama/test_llama_inference.py
@@ -134,19 +134,19 @@ def test_llama_inference_cache_cpu(model_path):
         pytest.param("meta-llama/Llama-3.2-1B", marks=pytest.mark.xfail(reason="Unsupported Op: repeat_interleave")),
     ],
 )
-@pytest.mark.parametrize("seq_len", [1, 2, 4, 7, 8, 16, 28, 32, 63, 64, 99, 117, 128, 256, 341, 512, 1024, 1790, 2048])
-@pytest.mark.skip(reason="No need to run in CI as it takes a long time to run.")
+@pytest.mark.parametrize("seq_len", [128, 512, 2048])
 def test_llama_input_sequence_lengths(model_path, seq_len):
     # Load Model and Tokenizer
-    framework_model, tokenizer = load_model(model_path, seq_len=seq_len)
+    framework_model, tokenizer = load_model(model_path, num_hidden_layers=1)
 
     # Adjust tokenizer for max sequence length padding
-    tokenizer.pad_token = ""
+    tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
     tokenizer.model_max_length = seq_len
 
     prompt = "Q: What is the largest animal?\nA:"
     input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
+    input_ids = input_ids.to(torch.int32)
 
     # Compile the model and run fwd pass
     compiled_model = forge.compile(framework_model, input_ids)
diff --git a/forge/test/mlir/llama/utils/utils.py b/forge/test/mlir/llama/utils/utils.py
index dc33f5116..880470fb8 100644
--- a/forge/test/mlir/llama/utils/utils.py
+++ b/forge/test/mlir/llama/utils/utils.py
@@ -16,6 +16,7 @@ def load_model(model_path="openlm-research/open_llama_3b", **kwargs):
     config.use_cache = kwargs.get("use_cache", False)
     config.output_attentions = kwargs.get("output_attentions", False)
     config.output_hidden_states = kwargs.get("output_hidden_states", False)
+    config.num_hidden_layers = kwargs.get("num_hidden_layers", 26)
 
     # Load the model
     framework_model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", config=config)
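
Note: the sketch below is not part of the patch. It is a minimal standalone illustration of the pattern the utils.py change enables: overriding num_hidden_layers on a HuggingFace LlamaConfig to load a truncated model so the sequence-length tests run fast enough for CI, plus the eos-as-pad padding and int32 cast the test now uses. The checkpoint name, the seq_len value, and the use of AutoTokenizer are illustrative assumptions, not taken from the diff.

    # Minimal sketch, assuming the utils.py defaults above; exact load-time
    # warnings depend on your transformers version.
    import torch
    from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

    model_path = "openlm-research/open_llama_3b"  # default checkpoint in utils.py
    seq_len = 128  # one of the parametrized lengths

    config = LlamaConfig.from_pretrained(model_path)
    config.num_hidden_layers = 1  # keep one decoder layer; unused checkpoint weights are skipped with a warning

    framework_model = LlamaForCausalLM.from_pretrained(model_path, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_path)  # assumption: Auto* resolves to the Llama tokenizer

    # Pad to a fixed max length, reusing eos as the pad token (as the patch does)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    tokenizer.model_max_length = seq_len

    prompt = "Q: What is the largest animal?\nA:"
    input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    input_ids = input_ids.to(torch.int32)  # int32 ids, matching the new cast in the test

Truncating to one decoder layer keeps the compile-and-forward path identical in shape handling while cutting runtime, which is why the test can now run the three representative lengths in CI instead of being skipped.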