From 21675a85f82cd7f30da1e8160316d596426596f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 14 Sep 2023 15:12:59 +0200
Subject: [PATCH 1/6] remove unnecessary unsqueeze-squeeze in llama

---
 .../open_llama/modeling_open_llama.py         | 16 ++++++++--------
 .../models/llama/modeling_llama.py            | 19 ++++++++-----------
 .../models/persimmon/modeling_persimmon.py    | 19 ++++++++-----------
 3 files changed, 24 insertions(+), 30 deletions(-)

diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index 8469b86eb9d53e..9e7da2f70e7d6b 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -121,8 +121,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -130,8 +130,8 @@ def forward(self, x, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )

@@ -151,8 +151,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
@@ -178,8 +178,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5e7a879c07e88f..e3d7f3a9e992ea 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -111,8 +111,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -120,8 +120,8 @@ def forward(self, x, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )

@@ -140,8 +140,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -166,8 +166,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
@@ -178,9 +178,6 @@ def rotate_half(x):


 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 5c6cde7f8a6d44..9a43320c234209 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -94,8 +94,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -103,8 +103,8 @@ def forward(self, x, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )

@@ -124,8 +124,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
@@ -151,8 +151,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -165,9 +165,6 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)

From 768b8b0f609f65244b782d56861077ff0879b2b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Sat, 16 Sep 2023 09:58:53 +0200
Subject: [PATCH 2/6] correct other models

---
 .../models/falcon/modeling_falcon.py          | 17 +++++++++--------
 .../models/gpt_neox/modeling_gpt_neox.py      | 20 +++++++++-----------
 .../modeling_gpt_neox_japanese.py             |  8 ++++----
 .../models/idefics/modeling_idefics.py        | 19 +++++++++----------
 .../models/llama/modeling_llama.py            |  5 +++--
 5 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index c541fab0a253a7..4e0d0e9aa23545 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -94,8 +94,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
@@ -107,8 +107,8 @@ def cos_sin(
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
         # Gather cos, sin at the designated position ids
-        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
-        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        cos = self.cos_cached[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = self.sin_cached[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
         return cos, sin

     def forward(self, query, key, past_key_values_length, position_ids):
@@ -155,8 +155,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
@@ -189,8 +189,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
@@ -432,6 +432,7 @@ def forward(
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)

+        print("key_layer", key_layer.shape)
         _, kv_length, _ = key_layer.shape
         if use_cache:
             present = (key_layer, value_layer)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 6710892dc5c3af..f522a6ca9d5316 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -306,14 +306,14 @@ def _set_cos_sin_cache(self, seq_len, device):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
-        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
+        return self.cos_cached[:seq_len].to(x.device), self.sin_cached[:seq_len].to(x.device)


 class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
@@ -331,8 +331,8 @@ def _set_cos_sin_cache(self, seq_len, device):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()


 class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
@@ -357,8 +357,8 @@ def _set_cos_sin_cache(self, seq_len, device):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()


 def rotate_half(x):
@@ -369,10 +369,8 @@ def rotate_half(x):


 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index e66c51c1b4fb9b..55d9a7a993da12 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -239,7 +239,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):


 # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
-class RotaryEmbedding(torch.nn.Module):
+class GPTNeoXRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()

@@ -259,14 +259,14 @@ def _set_cos_sin_cache(self, seq_len, device):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
-        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
+        return self.cos_cached[:seq_len].to(x.device), self.sin_cached[:seq_len].to(x.device)


 def rotate_half(x):
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index db5cbb75fe5fd4..90cc5d59a65313 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -507,8 +507,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -519,11 +519,11 @@ def forward(self, x, seq_len=None):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            self.register_buffer("cos_cached", emb.cos(), persistent=False)
+            self.register_buffer("sin_cached", emb.sin(), persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )

@@ -534,11 +534,10 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index e3d7f3a9e992ea..0b70ddac1dd0fa 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -177,9 +177,10 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

From 4183f9f8089d5e5c585d4d1b49ef9c4acb247238 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Sat, 16 Sep 2023 10:02:32 +0200
Subject: [PATCH 3/6] fix

---
 src/transformers/models/falcon/modeling_falcon.py          | 5 ++---
 .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 4e0d0e9aa23545..8105142c6643bd 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -107,8 +107,8 @@ def cos_sin(
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
         # Gather cos, sin at the designated position ids
-        cos = self.cos_cached[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-        sin = self.sin_cached[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        cos = self.cos_cached[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached[position_ids]  # [bs, seq_len, dim]
         return cos, sin

     def forward(self, query, key, past_key_values_length, position_ids):
@@ -432,7 +432,6 @@ def forward(
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)

-        print("key_layer", key_layer.shape)
         _, kv_length, _ = key_layer.shape
         if use_cache:
             present = (key_layer, value_layer)
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 55d9a7a993da12..9583ff924d3d68 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -239,7 +239,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):


 # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
-class GPTNeoXRotaryEmbedding(torch.nn.Module):
+class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()

From e0bab593ec22b8d641cecc479a6329f563bb5087 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Sat, 16 Sep 2023 10:07:18 +0200
Subject: [PATCH 4/6] revert gpt_neox_japanese

---
 .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 9583ff924d3d68..e66c51c1b4fb9b 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -259,14 +259,14 @@ def _set_cos_sin_cache(self, seq_len, device):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
+        self.cos_cached = emb.cos()[None, None, :, :]
+        self.sin_cached = emb.sin()[None, None, :, :]

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
-        return self.cos_cached[:seq_len].to(x.device), self.sin_cached[:seq_len].to(x.device)
+        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


 def rotate_half(x):

From 4f89883b2b91cfe7de17da01c2ba3359efecafe5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 5 Oct 2023 11:12:44 +0200
Subject: [PATCH 5/6] fix copie

---
 .../deprecated/open_llama/modeling_open_llama.py |  7 ++-----
 .../models/gpt_neox/modeling_gpt_neox.py         | 12 +++++++-----
 .../modeling_gpt_neox_japanese.py                |  8 ++++----
 .../models/mistral/modeling_mistral.py           | 15 ++++++---------
 .../models/persimmon/modeling_persimmon.py       |  4 ++--
 5 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index 73a6dd149fa9b4..3558880f13a16b 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -188,11 +188,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 5a5e18b8d5d937..b421be9dc2750f 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -343,14 +343,15 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
     """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

-    def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device):
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len

         if seq_len > self.max_position_embeddings:
@@ -358,15 +359,15 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)

         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
@@ -376,6 +377,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
     sin = sin[position_ids].unsqueeze(1)
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 5e94096a6da35d..98753edeb544f8 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -261,8 +261,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -270,8 +270,8 @@ def forward(self, x, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 8edd0fc60e6299..a55f16a23d5b52 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -149,8 +149,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -158,8 +158,8 @@ def forward(self, x, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )

@@ -173,11 +173,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 2da9c2cd49e065..6445527b669cac 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -165,8 +165,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

From 69a4840d9748402dc7a8af93529da3c6409b40ed Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 5 Oct 2023 11:16:23 +0000
Subject: [PATCH 6/6] fix test

---
 tests/models/mistral/test_modeling_mistral.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 403f2cc7347041..df1143d2516afd 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -430,7 +430,8 @@ class MistralIntegrationTest(unittest.TestCase):
     def test_model_7b_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
         model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
-        out = model(torch.tensor([input_ids])).logits
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        out = model(input_ids).logits.cpu()
         # Expected mean on dim = -1
         EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
         torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
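
Note on the change, with a minimal standalone sketch: the series caches the rotary cos/sin tables as [seq_len, dim] instead of [1, 1, seq_len, dim], so apply_rotary_pos_emb can gather rows by position_ids directly instead of squeezing the singleton dimensions first. The sketch below is illustrative only and not code from the patches; the tensor construction mirrors the diffs, while the toy sizes and the cos_old/cos_new names are assumptions made for the example.

import torch

# Build the rotary embedding table as in _set_cos_sin_cache (toy sizes assumed).
dim, max_seq_len = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

# Before: cached with two leading singleton dims, squeezed away again at use time.
cos_old = emb.cos()[None, None, :, :]            # [1, 1, seq_len, dim]
cos_old = cos_old.squeeze(1).squeeze(0)          # back to [seq_len, dim]

# After: cached 2-D; gather rows by position ids, then add the head dim.
cos_new = emb.cos()                              # [seq_len, dim]
position_ids = torch.arange(4)[None, :]          # [bs=1, seq_len=4]
gathered = cos_new[position_ids].unsqueeze(1)    # [bs, 1, seq_len, dim]

# The unsqueeze-squeeze round trip changed nothing: the gathered values match.
assert torch.equal(cos_old[position_ids].unsqueeze(1), gathered)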