ipex 2.3 released #725

Merged — 50 commits merged into main from rename on Jun 6, 2024

Changes from 43 commits

Commits (50)
5351f4a
ipex 2.3 released
jiqing-feng May 23, 2024
1f98d6d
skip tests
jiqing-feng May 27, 2024
b2b93bb
skip testing without pkv
jiqing-feng May 27, 2024
64dcde4
add tests skip
jiqing-feng May 27, 2024
945f6b6
only llama2 with at least 64 head size support IAKV
jiqing-feng May 27, 2024
c8922f3
cannot assert same outputs cause do_sample=True
jiqing-feng May 27, 2024
2ddfa7a
rm tiny-llama model testing cause it not work for IAKV
jiqing-feng May 27, 2024
f4e887d
fix code style
jiqing-feng May 28, 2024
d96ea58
fix style
jiqing-feng May 28, 2024
ec24d5a
rm tiny llama on test pipeline
jiqing-feng May 28, 2024
871de7b
fix tests
jiqing-feng May 30, 2024
d0c8951
support use_cache=False
jiqing-feng May 30, 2024
537f0aa
rm use_cache in model_kwargs
jiqing-feng May 30, 2024
5a71790
set use_cache
jiqing-feng May 30, 2024
bde814e
Update optimum/intel/ipex/modeling_base.py
jiqing-feng May 31, 2024
4a81ea9
fix spelling error
jiqing-feng May 31, 2024
3a61e84
fix style
jiqing-feng May 31, 2024
fd69407
add transformers version warning
jiqing-feng May 31, 2024
1032a26
add compare resultes
jiqing-feng May 31, 2024
c8e7969
add warning
jiqing-feng May 31, 2024
afdc8d7
set pad_token_id
jiqing-feng May 31, 2024
1d1df34
limited transformers
jiqing-feng Jun 3, 2024
aaaa4c3
fix transformers version
jiqing-feng Jun 3, 2024
f6b8010
update transformers version
jiqing-feng Jun 4, 2024
51e47b6
fix version
jiqing-feng Jun 4, 2024
5204b24
temporary fix for multi-query model
jiqing-feng Jun 4, 2024
8f2f025
fix code styke
jiqing-feng Jun 4, 2024
8dc5ad5
add transformers version tests
jiqing-feng Jun 4, 2024
e482e58
Update .github/workflows/test_ipex.yml
jiqing-feng Jun 5, 2024
d366b80
check geenration method
jiqing-feng Jun 5, 2024
3948cad
Update optimum/intel/ipex/modeling_base.py
jiqing-feng Jun 5, 2024
d1b63ef
fix use_cache
jiqing-feng Jun 5, 2024
ea4d3e2
add hidden size limitation for patch
jiqing-feng Jun 5, 2024
bcb2b5a
add llama in tests
jiqing-feng Jun 5, 2024
f5f1af8
add re-load tests
jiqing-feng Jun 5, 2024
c08c957
fix hidden size check
jiqing-feng Jun 5, 2024
51e6f3d
rm norm config
jiqing-feng Jun 5, 2024
d06123b
add version variable
jiqing-feng Jun 5, 2024
641e8f9
fix import
jiqing-feng Jun 5, 2024
50c1059
rm useless logger
jiqing-feng Jun 5, 2024
a961746
rm useless logging
jiqing-feng Jun 5, 2024
c2253a8
fix last round review
jiqing-feng Jun 6, 2024
e29ea58
Merge branch 'huggingface:main' into rename
jiqing-feng Jun 6, 2024
caa27c3
Update .github/workflows/test_ipex.yml
echarlaix Jun 6, 2024
78498ab
Update optimum/intel/ipex/modeling_base.py
echarlaix Jun 6, 2024
97f7876
Update optimum/intel/ipex/modeling_base.py
echarlaix Jun 6, 2024
cf3525a
Update setup.py
echarlaix Jun 6, 2024
8ba602d
Update optimum/exporters/ipex/modeling_utils.py
echarlaix Jun 6, 2024
f15a1f5
fix
jiqing-feng Jun 6, 2024
36ae751
limit the new tokens of assisted decoding tests
jiqing-feng Jun 6, 2024
2 changes: 2 additions & 0 deletions .github/workflows/test_ipex.yml
@@ -18,6 +18,7 @@ jobs:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
transformers-version: [4.38.0, 4.41.2]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
@@ -32,6 +33,7 @@ jobs:
python -m pip install --upgrade pip
pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests]
pip install transformers==${{ matrix.transformers-version }}
- name: Test with Pytest
run: |
pytest tests/ipex/
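The new matrix entry runs the suite against both the minimum (4.38.0) and maximum (4.41.2) verified transformers releases. Below is a minimal sketch of how an individual test can gate itself on the same bounds using the `is_transformers_version` helper already imported elsewhere in this PR; the test class, test name, and skip reason are illustrative, not taken from the repository.

```python
import unittest

from optimum.intel.utils.import_utils import is_transformers_version


class IPEXVersionGatedTest(unittest.TestCase):
    @unittest.skipIf(
        is_transformers_version("<", "4.38.0") or is_transformers_version(">", "4.41.2"),
        "only transformers 4.38.0 ~ 4.41.2 are verified for the IPEX backend",
    )
    def test_generation_runs(self):
        # Illustrative placeholder; the real tests live in tests/ipex/.
        self.assertTrue(True)
```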
11 changes: 7 additions & 4 deletions optimum/exporters/ipex/model_patcher.py
@@ -23,6 +23,7 @@
from optimum.intel.utils.import_utils import is_ipex_version

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXLlamaDecoderLayerRef,
_llama_attn_forward,
_llama_layer_norm_forward,
@@ -62,18 +63,20 @@ def patch_op(m, target_m, new_op_name, new_op):


def _patch_llama_model(model):
if is_ipex_version("<", "2.5.0"):
raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports RotaryEmbedding and IndirectAccessKVCacheAttention"
)

from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding

ipex_rope = RotaryEmbedding(
model.config.max_position_embeddings,
model.config.hidden_size // model.config.num_attention_heads,
model.config.rope_theta,
model.config.architectures[0],
)
ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

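For context, the `patch_op` helper used above attaches the IPEX op (`RotaryEmbedding`, `IndirectAccessKVCacheAttention`) as a new attribute on every matching submodule, so the patched attention forward can call `self.ipex_rope` and `self.ipex_scale_dot_product`. The sketch below shows that pattern; it is a simplified rendition and the actual implementation in `model_patcher.py` may differ in details.

```python
from torch import nn


def patch_op(m: nn.Module, target_m: type, new_op_name: str, new_op: nn.Module):
    """Recursively attach `new_op` as attribute `new_op_name` on every submodule of type `target_m`."""
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_op_name, new_op)
        # Recurse so nested decoder layers are patched as well.
        patch_op(sub_m, target_m, new_op_name, new_op)
```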
77 changes: 48 additions & 29 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,9 +19,15 @@
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version
from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.38.0"
_TRANSFORMERS_MAX_VERSION = "4.41.2"
_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
@@ -51,27 +57,27 @@ def _llama_attn_forward(
query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# Use ipex op to rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
# Use ipex op to rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
# This ipex op pre-allocates buffers for past_key_values and use beam index history
# which to decide which beam should be used to make attention scale dot more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
@@ -87,6 +93,8 @@ def _llama_attn_forward(
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
kv_seq_len = key_states.shape[-2]

past_key_value = None
@@ -219,8 +227,16 @@ def _llama_model_forward(
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
if is_ipex_version("<", "2.5.0"):
raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul and LinearAdd"
)
if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
">", _TRANSFORMERS_MAX_VERSION
):
raise ImportError(
f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
)

from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

@@ -278,7 +294,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
)
if not self.distributed:
if hasattr(self, "mha_linear_add"):
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
@@ -288,12 +304,15 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

mlp_gate = self.linear_silu_mul(hidden_states)

if not self.distributed:
hidden_states = self.mlp_linear_add(mlp_gate, residual)
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
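The decoder-layer change above replaces the `not self.distributed` checks with `hasattr` probes for the fused IPEX ops, so the forward path only uses `mha_linear_add`, `linear_silu_mul`, and `mlp_linear_add` when they were actually attached at patch time. A minimal, self-contained sketch of that dispatch pattern follows; the class and parameter names are illustrative and not taken from the PR.

```python
from typing import Optional

import torch
from torch import nn


class ResidualProjectionBlock(nn.Module):
    """Use a fused linear+residual-add op when one was attached at patch time,
    otherwise fall back to the plain output projection followed by the residual add."""

    def __init__(self, o_proj: nn.Linear, fused_linear_add: Optional[nn.Module] = None):
        super().__init__()
        self.o_proj = o_proj
        if fused_linear_add is not None:  # e.g. only attached when not running distributed
            self.mha_linear_add = fused_linear_add

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "mha_linear_add"):
            # Fused path: projection and residual add in one op.
            return self.mha_linear_add(hidden_states, residual)
        # Fallback path: vanilla projection, then residual add.
        return residual + self.o_proj(hidden_states)
```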
58 changes: 51 additions & 7 deletions optimum/intel/ipex/modeling_base.py
@@ -18,7 +18,7 @@
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import intel_extension_for_pytorch as ipex
import torch
@@ -50,7 +50,7 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
from ..generation.modeling import prepare_jit_inputs
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
@@ -60,10 +60,11 @@


_IPEX_SUPPORT_MODEL_TYPES = ("llama",)
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")


def _is_patched_with_ipex(model, task):
if is_ipex_version("<", "2.5.0"):
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
return False

if isinstance(model, torch.jit.ScriptModule):
@@ -73,7 +74,12 @@ def _is_patched_with_ipex(model, task):
return True
return False
else:
return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK
# The ipex IAKV op in patched model requires the hidden size at least 64
return (
model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
and task in _IPEX_EXPORTED_TASK
and model.config.hidden_size >= 64
)


def ipex_jit_trace(model, task, use_cache):
@@ -83,6 +89,7 @@ def ipex_jit_trace(model, task, use_cache):

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
# TODO: integrate in prepare_jit_inputs.
sample_inputs = get_dummy_input(model, return_dict=True)
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()
@@ -92,9 +99,10 @@

model.config.return_dict = False

if "past_key_values" in sample_inputs and use_cache:
# Make sure the model will output past_key_values in generation tasks
model.config.use_cache = True
if "past_key_values" in sample_inputs:
model.config.use_cache = use_cache
if not use_cache:
sample_inputs.pop("past_key_values")

model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
# Disable repack while jit tracing to reduce the memory
@@ -522,6 +530,23 @@ def _prepare_past_key_values(self, input_ids):

return past_key_values

# Temporary fix, will delete when https://github.com/huggingface/transformers/pull/31226 release.
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
if not model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = None
return model_kwargs

past_length = 0
if "past_key_values" in model_kwargs:
past_length = model_kwargs["past_key_values"][0][0].shape[-2]
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -561,6 +586,25 @@ def forward(

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
raise ValueError(
f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)

return generation_config, model_kwargs

def generate(self, **kwargs):
if self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
return super().generate(**kwargs)


def _prepare_inputs_for_generation_for_llama(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
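A short usage sketch of how these restrictions surface to a user of `IPEXModelForCausalLM`: patching only applies to supported llama checkpoints with hidden size >= 64, generation is limited to the methods in `_IPEX_EXPORTED_GENERATION_METHODS`, and passing an `assistant_model` to a patched model raises a `ValueError`. The checkpoint name below is illustrative, and the `export=True` / `use_cache=True` arguments are assumed from the existing optimum-intel IPEX API rather than this diff.

```python
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("What does IPEX optimize?", return_tensors="pt")

# Greedy search is one of the supported generation methods.
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Assisted decoding is rejected for patched models:
# model.generate(**inputs, assistant_model=draft_model)  # -> ValueError
```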
2 changes: 1 addition & 1 deletion setup.py
@@ -62,7 +62,7 @@
"neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"],
"openvino": ["openvino>=2023.3", "nncf>=2.10.0", "openvino-tokenizers[transformers]"],
"nncf": ["nncf>=2.10.0"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.36.0,<4.39.0"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.38.0,<=4.41.2"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
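A quick way to confirm a local environment matches the tightened pin (not part of the PR; uses only the standard library plus `packaging`):

```python
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
msg = f"transformers {installed} is outside the verified range for the ipex extra"
# Mirrors the setup.py constraint: transformers>=4.38.0,<=4.41.2
assert Version("4.38.0") <= installed <= Version("4.41.2"), msg
```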