Commit 710afb4: Update required dims for testing

pmarkovicTT committed Feb 4, 2025 (parent: 0e2ab4b)
Showing 2 changed files with 5 additions and 4 deletions.
forge/test/mlir/llama/test_llama_inference.py (4 additions, 4 deletions)

```diff
@@ -134,19 +134,19 @@ def test_llama_inference_cache_cpu(model_path):
         pytest.param("meta-llama/Llama-3.2-1B", marks=pytest.mark.xfail(reason="Unsupported Op: repeat_interleave")),
     ],
 )
-@pytest.mark.parametrize("seq_len", [1, 2, 4, 7, 8, 16, 28, 32, 63, 64, 99, 117, 128, 256, 341, 512, 1024, 1790, 2048])
-@pytest.mark.skip(reason="No need to run in CI as it takes a long time to run.")
+@pytest.mark.parametrize("seq_len", [128, 512, 2048])
 def test_llama_input_sequence_lengths(model_path, seq_len):
     # Load Model and Tokenizer
-    framework_model, tokenizer = load_model(model_path, seq_len=seq_len)
+    framework_model, tokenizer = load_model(model_path, num_hidden_layers=1)
 
     # Adjust tokenizer for max sequence length padding
-    tokenizer.pad_token = "<pad>"
+    tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
+    tokenizer.model_max_length = seq_len
 
     prompt = "Q: What is the largest animal?\nA:"
     input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
     input_ids = input_ids.to(torch.int32)
 
     # Compile the model and run fwd pass
     compiled_model = forge.compile(framework_model, input_ids)
```
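For context, here is a minimal standalone sketch of the tokenizer setup the updated test performs before compilation. It assumes `transformers` and `torch` are installed and that the `openlm-research/open_llama_3b` tokenizer (one of the test's parametrized model paths) is reachable; it only illustrates the padding behavior, not the Forge compile step.

```python
# Sketch: right-pad a short prompt to a fixed seq_len, reusing EOS as the
# pad token instead of a literal "<pad>" string, as the updated test does.
import torch
from transformers import AutoTokenizer

seq_len = 128  # one of the parametrized values [128, 512, 2048]
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")

tokenizer.pad_token = tokenizer.eos_token  # EOS doubles as the pad token
tokenizer.padding_side = "right"
tokenizer.model_max_length = seq_len       # padding="max_length" pads to this

prompt = "Q: What is the largest animal?\nA:"
input_ids = tokenizer(
    prompt, padding="max_length", truncation=True, return_tensors="pt"
).input_ids
input_ids = input_ids.to(torch.int32)  # cast to int32, as in the test

print(input_ids.shape)  # torch.Size([1, 128])
```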
forge/test/mlir/llama/utils/utils.py (1 addition, 0 deletions)

```diff
@@ -16,6 +16,7 @@ def load_model(model_path="openlm-research/open_llama_3b", **kwargs):
     config.use_cache = kwargs.get("use_cache", False)
     config.output_attentions = kwargs.get("output_attentions", False)
     config.output_hidden_states = kwargs.get("output_hidden_states", False)
+    config.num_hidden_layers = kwargs.get("num_hidden_layers", 26)
 
     # Load the model
     framework_model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", config=config)
```
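The new `num_hidden_layers` kwarg lets callers shrink the model for fast tests: the test above passes `num_hidden_layers=1`, while the default of 26 matches open_llama_3b's full depth. Below is a minimal sketch of the same config-override pattern in plain `transformers` (the model path is just the repo's default checkpoint, used here for illustration); `from_pretrained` loads only the layers the config declares and warns that the remaining checkpoint weights are unused.

```python
# Sketch: truncate a Llama checkpoint to a single hidden layer by overriding
# its config before loading, yielding a much smaller, faster model for tests.
from transformers import LlamaConfig, LlamaForCausalLM

model_path = "openlm-research/open_llama_3b"
config = LlamaConfig.from_pretrained(model_path)
config.num_hidden_layers = 1  # full model has 26 layers

model = LlamaForCausalLM.from_pretrained(model_path, config=config)
print(model.config.num_hidden_layers)  # 1
```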
