[Hardware][intel GPU] bump up ipex version to 2.3 #8365
Changes from all commits
@@ -27,29 +27,27 @@ def _reshape_activation_tensor(
     @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)

     @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)

     @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)

     @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)

     @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
Review comment on lines +41 to +46: Same for gelu_fast and gelu_new.
-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)

     @staticmethod
     def paged_attention_v1(
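Side note (not part of the diff): the fused *_and_mul ops consume a tensor that packs the gate and up projections along the last dimension, which is what the removed _reshape_activation_tensor helper used to split apart. A minimal PyTorch sketch of the expected silu_and_mul semantics, assuming that packed layout:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate, up] along the last dimension.
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    # SiLU(gate) * up, which the fused ipex kernel writes into `out`.
    return F.silu(gate) * up

x = torch.randn(4, 16)
out = silu_and_mul_ref(x)  # shape (4, 8): half the last dim of x
```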
@@ -160,67 +158,26 @@ def rotary_embedding(
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
     @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
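Side note (not part of the diff): both rotary functions previously expanded the cached cos/sin values in Python before calling the kernel, with a different layout per rotary style. A sketch of just that expansion step, reconstructed from the removed lines; with ipex 2.3, rotary_embedding_batched appears to select the layout itself from the is_neox flag, so this prep (and the unsqueeze for 1-D positions) is no longer needed on the Python side:

```python
import torch

def expand_cos_sin(cos_sin: torch.Tensor, is_neox: bool):
    # cos_sin: [batch, seq, rotary_dim] with the cos and sin halves concatenated.
    cos, sin = cos_sin.chunk(2, dim=-1)
    if is_neox:
        # NeoX rotate-half layout: the two halves are duplicated end to end.
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        # GPT-J interleaved layout: each value is repeated pairwise.
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    # Both results broadcast against [batch, seq, num_heads, rotary_dim].
    return cos, sin
```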
     @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)

     @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
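Side note (not part of the diff): returning the kernel's output directly drops the extra copy_ into a preallocated out tensor. For reference, a minimal PyTorch sketch of the standard RMSNorm computation that ipex.llm.functional.rms_norm is expected to implement:

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + epsilon) * weight
```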
@@ -246,11 +203,14 @@ def varlen_attention(
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)

Review comment on lines +206 to +208 (the .contiguous() calls), reply: Yes, ipex 2.3 requires q, k, and v to be contiguous for this API.
     @staticmethod
     def reshape_and_cache(
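Side note (not part of the diff): a PyTorch tensor produced by slicing or transposing can be non-contiguous, and .contiguous() materializes a dense copy only when needed, so the guard is cheap on the common path. A small sketch of the behavior the updated call relies on:

```python
import torch

x = torch.randn(4, 8, 16)
v = x.transpose(0, 1)       # a view with permuted strides
print(v.is_contiguous())    # False: strides no longer match row-major order

vc = v.contiguous()         # dense row-major copy the kernel can consume
print(vc.is_contiguous())   # True
print(x.contiguous() is x)  # True: already-contiguous tensors are returned as-is
```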
Review comment on gelu_and_mul and gelu_tanh_and_mul: It seems you use the same function for gelu and gelu tanh. Is this intended? This might cause a subtle accuracy drop for some models.

Reply: Yes, they are the same kernel in the ipex implementation. cc @ganyi1996ppo
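For context on the accuracy concern: in plain PyTorch the exact (erf-based) and tanh-approximate GELU curves differ slightly, which is why mapping both gelu_and_mul and gelu_tanh_and_mul onto one kernel could shift logits for models trained with the tanh variant. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=1001)
exact = F.gelu(x)                       # erf-based GELU
approx = F.gelu(x, approximate="tanh")  # tanh approximation

# Small but nonzero gap between the two formulations (~1e-3 or less).
print((exact - approx).abs().max())
```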