diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index debcf42ab38fad..d81fe3efd998e5 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -19,8 +19,10 @@
 import unittest
 
 import pytest
+from packaging import version
 
 from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,
@@ -640,3 +642,56 @@ def test_speculative_generation(self):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @slow
+    def test_export_static_cache(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        from transformers.integrations.executorch import (
+            TorchExportableModuleWithStaticCache,
+            convert_and_export_with_cache,
+        )
+
+        qwen_model = "Qwen/Qwen2.5-0.5B"
+
+        tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
+        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% sugar. I have a jar of 1000 grams of sugar. I use"]
+        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
+            "input_ids"
+        ].shape[-1]
+
+        # Load model
+        device = "cpu"
+        dtype = torch.bfloat16
+        cache_implementation = "static"
+        attn_implementation = "sdpa"
+        batch_size = 1
+        model = Qwen2ForCausalLM.from_pretrained(
+            qwen_model,
+            device_map=device,
+            torch_dtype=dtype,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_generation_length,
+                cache_config={
+                    "batch_size": batch_size,
+                    "max_cache_len": max_generation_length,
+                },
+            ),
+        )
+
+        prompt = ["My favourite condiment is "]
+        prompt_tokens = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
+        prompt_token_ids = prompt_tokens["input_ids"]
+        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
+
+        # Static Cache + export
+        exported_program = convert_and_export_with_cache(model)
+        ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
+            exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
+        )
+        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
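
Note (not part of the patch): the flow exercised by test_export_static_cache can be reproduced standalone. The sketch below reuses only the APIs that appear in the diff (GenerationConfig with a static cache, convert_and_export_with_cache, TorchExportableModuleWithStaticCache.generate); the checkpoint, prompt, and max_length are illustrative, and it assumes torch >= 2.4 plus a transformers build that ships transformers.integrations.executorch.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.configuration_utils import GenerationConfig
    from transformers.integrations.executorch import (
        TorchExportableModuleWithStaticCache,
        convert_and_export_with_cache,
    )

    model_id = "Qwen/Qwen2.5-0.5B"  # illustrative checkpoint, same as the test
    max_length = 32                 # illustrative static-cache capacity

    tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_length,
            cache_config={"batch_size": 1, "max_cache_len": max_length},
        ),
    )

    # Export the model with its static KV cache, then decode through the exported program.
    prompt_token_ids = tokenizer(["My favourite condiment is "], return_tensors="pt", padding=True)["input_ids"]
    exported_program = convert_and_export_with_cache(model)
    generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program,
        prompt_token_ids=prompt_token_ids,
        max_new_tokens=max_length - prompt_token_ids.shape[-1],
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))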