diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml
index 8e02bd551..7bb8947ab 100644
--- a/.github/workflows/test_ipex.yml
+++ b/.github/workflows/test_ipex.yml
@@ -18,6 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [3.8, 3.9]
+        transformers-version: [4.39.0, 4.41.2]
         os: [ubuntu-latest]
 
     runs-on: ${{ matrix.os }}
@@ -32,6 +33,7 @@ jobs:
         python -m pip install --upgrade pip
         pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
         pip install .[ipex,tests]
+        pip install transformers==${{ matrix.transformers-version }}
     - name: Test with Pytest
       run: |
         pytest tests/ipex/
diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 60ff3b721..0d87a5fd6 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -23,6 +23,7 @@
 from optimum.intel.utils.import_utils import is_ipex_version
 
 from .modeling_utils import (
+    _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _IPEXLlamaDecoderLayerRef,
     _llama_attn_forward,
     _llama_layer_norm_forward,
@@ -62,10 +63,12 @@ def patch_op(m, target_m, new_op_name, new_op):
 
 
 def _patch_llama_model(model):
-    if is_ipex_version("<", "2.5.0"):
-        raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
+    if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+        raise ImportError(
+            f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports RotaryEmbedding and IndirectAccessKVCacheAttention"
+        )
 
-    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
+    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding
 
     ipex_rope = RotaryEmbedding(
         model.config.max_position_embeddings,
@@ -73,7 +76,7 @@
         model.config.rope_theta,
         model.config.architectures[0],
     )
-    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
+    ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
     patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
     patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index f75e559ea..a2b73e74a 100644
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -19,9 +19,15 @@
 from torch import nn
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import repeat_kv
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
+
+
+# Please also update setup.py and .github/workflows/test_ipex.yml if you change the transformers version
+_TRANSFORMERS_MIN_VERSION = "4.39.0"
+_TRANSFORMERS_MAX_VERSION = "4.41.2"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
@@ -51,27 +57,27 @@ def _llama_attn_forward(
     query = query.view(bsz, q_len, self.num_heads, self.head_dim)
     key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
     value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    # Use ipex op to rotary position embedding more efficient.
-    key = self.ipex_rope(
-        key,
-        position_ids,
-        self.num_key_value_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-    query = self.ipex_rope(
-        query,
-        position_ids,
-        self.num_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
 
     if use_cache:
+        # Use the ipex rope op to compute the rotary position embeddings more efficiently.
+        key = self.ipex_rope(
+            key,
+            position_ids,
+            self.num_key_value_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+        query = self.ipex_rope(
+            query,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
         # This ipex op pre-allocates buffers for past_key_values and use beam index history
         # which to decide which beam should be used to make attention scale dot more efficient.
         (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
@@ -87,6 +93,8 @@
         value_states = value.transpose(1, 2)
         query_states = query.transpose(1, 2)
         key_states = key.transpose(1, 2)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
         kv_seq_len = key_states.shape[-2]
 
         past_key_value = None
@@ -219,8 +227,16 @@ def _llama_model_forward(
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayerRef(nn.Module):
     def __init__(self, module, config, distributed=False):
-        if is_ipex_version("<", "2.5.0"):
-            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
+        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+            raise ImportError(
+                f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul and LinearAdd"
+            )
+        if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
+            ">", _TRANSFORMERS_MAX_VERSION
+        ):
+            raise ImportError(
+                f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
+            )
 
         from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
@@ -278,7 +294,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        if not self.distributed:
+        if hasattr(self, "mha_linear_add"):
             hidden_states = self.mha_linear_add(hidden_states, residual)
         else:
             hidden_states = self.self_attn.o_proj(hidden_states)
@@ -288,12 +304,15 @@
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
 
-        mlp_gate = self.linear_silu_mul(hidden_states)
-
-        if not self.distributed:
-            hidden_states = self.mlp_linear_add(mlp_gate, residual)
+        if hasattr(self, "linear_silu_mul"):
+            mlp_gate = self.linear_silu_mul(hidden_states)
+            if hasattr(self, "mlp_linear_add"):
+                hidden_states = self.mlp_linear_add(mlp_gate, residual)
+            else:
+                hidden_states = self.mlp.down_proj(mlp_gate)
+                hidden_states = residual + hidden_states
         else:
-            hidden_states = self.mlp.down_proj(mlp_gate)
+            hidden_states = self.mlp(hidden_states)
             hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index e929a4ddb..3750d5622 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -18,7 +18,7 @@
 import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import intel_extension_for_pytorch as ipex
 import torch
@@ -50,7 +50,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
-from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
+from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
@@ -60,10 +60,11 @@
 
 
 _IPEX_SUPPORT_MODEL_TYPES = ("llama",)
+_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 
 
 def _is_patched_with_ipex(model, task):
-    if is_ipex_version("<", "2.5.0"):
+    if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         return False
 
     if isinstance(model, torch.jit.ScriptModule):
@@ -73,7 +74,12 @@ def _is_patched_with_ipex(model, task):
             return True
         return False
     else:
-        return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK
+        # The ipex IAKV op in the patched model requires a hidden size of at least 64
+        return (
+            model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
+            and task in _IPEX_EXPORTED_TASK
+            and model.config.hidden_size >= 64
+        )
 
 
 def ipex_jit_trace(model, task, use_cache):
@@ -83,6 +89,7 @@
 
     if _is_patched_with_ipex(model, task):
         model = _patch_model(model)
+    # TODO: integrate into prepare_jit_inputs.
     sample_inputs = get_dummy_input(model, return_dict=True)
     # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
     _enable_tpp()
@@ -92,9 +99,10 @@
 
     model.config.return_dict = False
 
-    if "past_key_values" in sample_inputs and use_cache:
-        # Make sure the model will output past_key_values in generation tasks
-        model.config.use_cache = True
+    if "past_key_values" in sample_inputs:
+        model.config.use_cache = use_cache
+        if not use_cache:
+            sample_inputs.pop("past_key_values")
 
     model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
     # Disable repack while jit tracing to reduce the memory
@@ -522,6 +530,23 @@ def _prepare_past_key_values(self, input_ids):
 
         return past_key_values
 
+    # Temporary fix, to be removed once https://github.com/huggingface/transformers/pull/31226 is released.
+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
+        past_length = 0
+        if "past_key_values" in model_kwargs:
+            past_length = model_kwargs["past_key_values"][0][0].shape[-2]
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        else:
+            cur_len = input_ids.shape[-1]
+        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+        return model_kwargs
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -561,6 +586,25 @@
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
+    def _prepare_generation_config(
+        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
+        generation_method = generation_config.get_generation_mode().value
+        if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
+            raise ValueError(
+                f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+            )
+
+        return generation_config, model_kwargs
+
+    def generate(self, *args, **kwargs):
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            raise ValueError(
+                f"Assisted decoding is not supported for patched models for now, supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+            )
+        return super().generate(*args, **kwargs)
+
 
 def _prepare_inputs_for_generation_for_llama(
     input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
diff --git a/setup.py b/setup.py
index b8869f46a..cfb28db87 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@
     "neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"],
     "openvino": ["openvino>=2023.3", "nncf>=2.10.0", "openvino-tokenizers[transformers]"],
     "nncf": ["nncf>=2.10.0"],
-    "ipex": ["intel-extension-for-pytorch", "transformers>=4.36.0,<4.39.0"],
+    "ipex": ["intel-extension-for-pytorch", "transformers>=4.39.0,<=4.41.2"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 2a2f18f6f..8664b99ce 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -14,6 +14,7 @@
 
 # ruff: noqa
 
+import tempfile
 import time
 import unittest
 
@@ -87,10 +88,16 @@ def test_compare_to_transformers(self, model_arch):
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
         outputs = ipex_model(**tokens)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
+            loaded_model_outputs = loaded_model(**tokens)
         # Compare tensor outputs
         for output_name in {"logits", "last_hidden_state"}:
             if output_name in transformers_outputs:
                 self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))
+                self.assertTrue(torch.equal(outputs[output_name], loaded_model_outputs[output_name]))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
@@ -139,11 +146,19 @@ def test_compare_to_transformers(self, model_arch):
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
         outputs = ipex_model(**tokens)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
+            loaded_model_outputs = loaded_model(**tokens)
+
         self.assertIn("start_logits", outputs)
         self.assertIn("end_logits", outputs)
         # Compare tensor outputs
         self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4))
         self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4))
+        self.assertTrue(torch.equal(outputs.start_logits, loaded_model_outputs.start_logits))
+        self.assertTrue(torch.equal(outputs.end_logits, loaded_model_outputs.end_logits))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
@@ -171,14 +186,14 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "gpt2",
         "gpt_neo",
         "gpt_neox",
+        "mistral",
         "llama",
         "llama2",
-        "mistral",
         # "phi",
         "mpt",
         "opt",
     )
-    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",)
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2",)
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0
 
@@ -204,8 +219,14 @@ def test_compare_to_transformers(self, model_arch):
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
+            loaded_model_outputs = loaded_model(**inputs)
         # Compare tensor outputs
         self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
@@ -219,18 +240,23 @@
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
+    # The highly optimized (patched) llama model does not support assisted decoding for now.
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_assisted_decoding(self, model_arch):
+        if model_arch == "llama2":
+            return
         model_id = MODEL_NAMES[model_arch]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
-        ipex_output = ipex_model.generate(**tokens, do_sample=False)
-        ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model)
-        transformers_output = transformers_model.generate(**tokens, do_sample=False)
+        ipex_output = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=4)
+        ipex_output_assisted = ipex_model.generate(
+            **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4
+        )
+        transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4)
         transformers_output_assisted = transformers_model.generate(
-            **tokens, do_sample=False, assistant_model=ipex_model
+            **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
         )
         self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
         self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))
@@ -243,24 +269,25 @@
             }
         )
     )
-    @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
+    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
     def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
-        self.assertEqual(model.use_cache, use_cache)
         trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.assertEqual(model.use_cache, use_cache)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         # Test with batch_size is 1 and 2.
texts = ["This is a sample", ["This is the first input", "This is the second input"]] generation_configs = ( - GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True), - GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6), - GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0), + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False), + GenerationConfig( + max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id + ), ) for text in texts: tokens = tokenizer(text, padding=True, return_tensors="pt") @@ -268,7 +295,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): outputs = model.generate(**tokens, generation_config=generation_config) transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) - self.assertEqual(outputs, transformers_outputs) + self.assertTrue(torch.equal(outputs, transformers_outputs)) def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" @@ -326,8 +353,14 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**inputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3)) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): @@ -366,9 +399,16 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**inputs) + self.assertIn("logits", outputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch):