add IPEX-XPU support for Llama2 model Inference #703

Closed
wants to merge 38 commits into huggingface:main from faaany:ipex-xpu
Changes from 5 commits
Commits (38)
5c4d13f
add xpu patch to optimum intel (#7)
ganyi1996ppo Apr 22, 2024
b1d6989
can run but precision error
jiqing-feng Apr 25, 2024
f2de914
optimize optimum
ganyi1996ppo Apr 26, 2024
9295457
further optimize
ganyi1996ppo Apr 26, 2024
c55216a
finalize
faaany May 8, 2024
5b3b72d
fix version
faaany May 9, 2024
4897144
fix ipex version check
faaany May 11, 2024
5351f4a
ipex 2.3 released
jiqing-feng May 23, 2024
6289b57
change versions
faaany May 24, 2024
3824300
debug beam search
faaany May 24, 2024
872a3eb
remove reference elimination
faaany May 24, 2024
d1d0ca0
refactor IPEXLlamaAttention
faaany May 25, 2024
3b8900d
Merge branch 'ipex-cpu' into ipex-xpu
faaany May 26, 2024
815d238
Merge branch 'huggingface:main' into ipex-xpu
faaany May 26, 2024
89e10d6
add xpu port
faaany May 26, 2024
9acaba4
Fix llama and gemma modeling patching for openvino export (#714)
echarlaix May 23, 2024
2f4909c
Fix nncf quantization for decoder models (#727)
echarlaix May 24, 2024
17d02d3
Merge branch 'ipex-xpu' of https://github.com/faaany/optimum-intel in…
faaany May 26, 2024
f186ce7
remove
faaany May 26, 2024
1ff78b2
fix version
faaany May 26, 2024
ff7f785
bug fix
faaany May 26, 2024
e3dac89
change module
faaany May 26, 2024
8725f49
improve device
faaany May 26, 2024
57cfe11
remove
faaany May 26, 2024
ee78f95
simplify rmsnorm
faaany May 27, 2024
a930f31
Merge branch 'ipex-xpu' of https://github.com/faaany/optimum-intel in…
faaany May 27, 2024
6098943
style
faaany May 27, 2024
e0fb06e
fix group attention
faaany Jun 7, 2024
aa8d395
fix weight shape
faaany Jun 7, 2024
0a56b19
Merge branch 'main' into ipex-xpu
faaany Jun 7, 2024
548d83f
fix rebase bug
faaany Jun 7, 2024
68187e5
revert openvino
faaany Jun 7, 2024
efedca4
revert openvino
faaany Jun 7, 2024
bd03552
remove duplicates
faaany Jun 7, 2024
0d3930a
use the correct black
faaany Jun 7, 2024
b4ba6d0
Merge branch 'main' into ipex-xpu
faaany Sep 6, 2024
1fd464b
fix merge conflict
kaixuanliu Sep 10, 2024
6a52fdf
Merge pull request #1 from kaixuanliu/ipex-xpu
kaixuanliu Sep 10, 2024
49 changes: 29 additions & 20 deletions optimum/exporters/ipex/model_patcher.py
@@ -29,6 +29,7 @@
    _llama_model_forward,
)

from .modeling.modeling_llama import _IPEXLlamaDecoderLayer

_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
_IPEX_EXPORTED_TASK = ("text-generation",)
@@ -62,26 +63,34 @@ def patch_op(m, target_m, new_op_name, new_op):


def _patch_llama_model(model):
    if is_ipex_version("<", "2.5.0"):
        raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")

    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding

    ipex_rope = RotaryEmbedding(
        model.config.max_position_embeddings,
        model.config.hidden_size // model.config.num_attention_heads,
        model.config.rope_theta,
        model.config.architectures[0],
    )
    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
    convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)

    ipex_version = "2.1.0" if "xpu" in str(model.device) else "2.3.0"
    if is_ipex_version("<", ipex_version):
        raise ImportError(f"Only ipex version > {ipex_version} supports RotaryEmbedding and IndirectAccessKVCache")

    if "cpu" in str(model.device):
        from intel_extension_for_pytorch.llm.modules import RotaryEmbedding
        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache

        ipex_rope = RotaryEmbedding(
            model.config.max_position_embeddings,
            model.config.hidden_size // model.config.num_attention_heads,
            model.config.rope_theta,
            model.config.architectures[0],
        )
        ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)

        patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
        patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

        convert_functions(model, LlamaModel, "forward", _llama_model_forward)
        convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
        convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

        convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
    else:
        convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
        convert_functions(model, LlamaModel, "forward", _llama_model_forward)
    return model


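For reviewers who want to exercise the XPU branch above, here is a minimal usage sketch (not part of the diff). It assumes an XPU build of intel_extension_for_pytorch is installed and that the XPU submodules referenced by _IPEXLlamaDecoderLayer are available from the rest of this PR; the checkpoint id below is a placeholder.

import torch
import intel_extension_for_pytorch as ipex  # registers the "xpu" device with PyTorch
from transformers import LlamaForCausalLM

from optimum.exporters.ipex.model_patcher import _patch_llama_model

# Placeholder checkpoint id; any Llama-architecture model works.
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
model = model.to("xpu")  # the device string decides which patching branch is taken
model = _patch_llama_model(model)  # XPU path: decoder layers become _IPEXLlamaDecoderLayer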
1 change: 1 addition & 0 deletions optimum/exporters/ipex/modeling/__init__.py
@@ -0,0 +1 @@

155 changes: 155 additions & 0 deletions optimum/exporters/ipex/modeling/modeling_llama.py
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
from typing import Optional, Tuple
import intel_extension_for_pytorch as ipex


class _IPEXLlamaAttention(nn.Module):
    def __init__(self, module, config, distributed=False) -> None:
        super().__init__()
        self.module = module
        self.config = config
        self.distributed = distributed

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        residual: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
        """
        pass


class _IPEXLlamaMLP(nn.Module):
    def __init__(self, module, config, distributed=False) -> None:
        super().__init__()
        self.module = module
        self.config = config
        self.distributed = distributed

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor, **kwargs):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """
        pass


class _IPEXLlamaDecoderLayer(nn.Module):
    def __init__(self, module, config, distributed=False) -> None:
        super().__init__()
        self.layer_idx = module.self_attn.layer_idx
        module_device = str(module.self_attn.q_proj.weight.device)
        if "xpu" in module_device:
            from .xpu.xpu_modeling_llama import _IPEXLlamaAttentionXPU, _IPEXLlamaMLPXPU

            self.attn = _IPEXLlamaAttentionXPU(module.self_attn, config, distributed)
            self.mlp = _IPEXLlamaMLPXPU(module.mlp, config, distributed)
        else:
            self.attn = _IPEXLlamaAttention(module.self_attn, config, distributed)
            self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
        self.input_layernorm = ipex.llm.modules.RMSNorm(
            module.input_layernorm.weight, module.input_layernorm.variance_epsilon
        )
        self.post_attention_layernorm = ipex.llm.modules.RMSNorm(
            module.post_attention_layernorm.weight, module.post_attention_layernorm.variance_epsilon
        )

    def preprocess_for_optimize(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attention: Optional[bool] = True,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        return hidden_states, attention_mask, position_ids, past_key_value

    def postprocess_for_optimize(
        self, hidden_states, output_attention, use_cache, self_attn_weight, present_key_value, **kwargs
    ):
        outputs = (hidden_states,)
        if use_cache:
            outputs += (present_key_value,)
        if output_attention:
            outputs += (self_attn_weight,)

        return outputs

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        outputs = self.preprocess_for_optimize(
            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs
        )
        (hidden_states, attention_mask, position_ids, past_key_value) = outputs
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, present_key_value, self_attn_weight = self.attn(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            None,
            residual,
            **kwargs,
        )

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, residual, **kwargs)

        outputs = self.postprocess_for_optimize(
            hidden_states, output_attentions, use_cache, self_attn_weight, present_key_value, **kwargs
        )

        return outputs
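To make the control flow of the decoder layer easier to follow, below is a small self-contained sketch (toy modules, not the PR's classes) of the same pre-norm layout: the residual is handed to the attention/MLP wrappers instead of being added in the layer itself, presumably so IPEX can fuse the addition into its kernels. All _Toy* names are illustrative only.

import torch
import torch.nn as nn


class _ToyFusedSubmodule(nn.Module):
    # Stand-in for the fused IPEX attention/MLP: it receives the residual and adds it internally.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, residual):
        return self.proj(hidden_states) + residual  # the add a fused kernel would absorb


class _ToyDecoderLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(dim)  # toy stand-in for ipex.llm.modules.RMSNorm
        self.post_attention_layernorm = nn.LayerNorm(dim)
        self.attn = _ToyFusedSubmodule(dim)
        self.mlp = _ToyFusedSubmodule(dim)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.attn(self.input_layernorm(hidden_states), residual)
        residual = hidden_states
        return self.mlp(self.post_attention_layernorm(hidden_states), residual)


print(_ToyDecoderLayer(16)(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])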