remove comment, GPTNeoX caching, renaming
HeegyuKim committed Sep 23, 2023
1 parent cca5e7c commit 4ef7120
Showing 2 changed files with 15 additions and 9 deletions.
23 changes: 15 additions & 8 deletions src/transformers/models/gpt_neox/modeling_flax_gpt_neox.py
@@ -125,21 +125,28 @@ def rotate_half(x):
     return jnp.concatenate((-x2, x1), axis=-1)
 
 
-class RotaryEmbedding(nn.Module):
+class FlaxGPTNeoXRotaryEmbedding(nn.Module):
     dim: int
     max_seq_len_cached: int
     base: int = 10000
 
     def setup(self):
         self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2).astype(jnp.float32) / self.dim))  # dim
+        self._init_cache(self.max_seq_len_cached)
 
-    def __call__(self, x=None, seq_len=None):
-        t = jnp.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
+    def _init_cache(self, seq_len: int):
+        t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
         freqs = jnp.outer(t, self.inv_freq)
         emb = jnp.concatenate((freqs, freqs), axis=-1)
-        cos_cached = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
-        sin_cached = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
-        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
+        self.cos_cached = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
+        self.sin_cached = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
+
+    def __call__(self, seq_len=None):
+        if seq_len > self.max_seq_len_cached:
+            self._init_cache(seq_len)
+            self.max_seq_len_cached = seq_len
+
+        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
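For context, the caching pattern this hunk introduces can be sketched outside Flax as plain functions; build_rope_cache and RopeCache below are illustrative names, not part of the commit:

import jax.numpy as jnp


def build_rope_cache(dim, seq_len, base=10000):
    # Same math as _init_cache: angles from outer(t, inv_freq), duplicated to cover
    # the full rotary dimension, then turned into cos/sin lookup tables.
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2).astype(jnp.float32) / dim))
    t = jnp.arange(seq_len, dtype=inv_freq.dtype)
    emb = jnp.concatenate((jnp.outer(t, inv_freq),) * 2, axis=-1)
    return jnp.cos(emb), jnp.sin(emb)  # each of shape (seq_len, dim)


class RopeCache:
    # Grow-on-demand wrapper mirroring the new __call__: rebuild the tables only
    # when a longer sequence than the cached one is requested.
    def __init__(self, dim, max_seq_len, base=10000):
        self.dim, self.base, self.max_seq_len = dim, base, max_seq_len
        self.cos, self.sin = build_rope_cache(dim, max_seq_len, base)

    def __call__(self, seq_len):
        if seq_len > self.max_seq_len:
            self.cos, self.sin = build_rope_cache(self.dim, seq_len, self.base)
            self.max_seq_len = seq_len
        return self.cos[:seq_len], self.sin[:seq_len]


cos, sin = RopeCache(dim=64, max_seq_len=2048)(seq_len=128)  # cos.shape == (128, 64)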
@@ -178,7 +185,7 @@ def setup(self):
 
         self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
 
-        self.rotary_emb = RotaryEmbedding(
+        self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
             dim=self.rotary_ndims,
             max_seq_len_cached=config.max_position_embeddings,
             base=config.rotary_emb_base,
@@ -240,14 +247,14 @@ def __call__(
         key = qkv[..., self.head_size : 2 * self.head_size]  # .permute(0, 2, 1, 3)
         value = qkv[..., 2 * self.head_size :]  # .permute(0, 2, 1, 3)
 
+        cos, sin = self.rotary_emb(seq_len)
         if self.rotary_ndims is not None:
             k_rot = key[..., : self.rotary_ndims]
             k_pass = key[..., self.rotary_ndims :]
 
             q_rot = query[..., : self.rotary_ndims]
             q_pass = query[..., self.rotary_ndims :]
 
-            cos, sin = self.rotary_emb(value, seq_len=seq_len)
             q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
 
             key = jnp.concatenate([k_rot, k_pass], axis=-1)
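The body of apply_rotary_pos_emb is not shown in this diff; a conventional application of the cos/sin tables looks roughly like the sketch below. apply_rope is a hypothetical stand-in, assuming q and k of shape (batch, heads, seq, dim) and cos/sin tables of shape (max_seq, dim):

import jax.numpy as jnp


def rotate_half(x):
    # Split the last dimension in two and swap the halves with a sign flip.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)


def apply_rope(q, k, cos, sin, position_ids):
    # Gather the per-token rows of the tables, add a broadcast axis for heads,
    # then rotate: x * cos + rotate_half(x) * sin.
    cos = cos[position_ids][:, None, :, :]  # (batch, 1, seq, dim)
    sin = sin[position_ids][:, None, :, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin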
1 change: 0 additions & 1 deletion tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -95,7 +95,6 @@ def prepare_config_and_inputs(self):
 
         input_mask = None
         if self.use_input_mask:
-            # input_mask = random_attention_mask([self.batch_size, self.seq_length])
             input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
 
         token_labels = None
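For reference, the torch.tril mask kept here gives each batch row a different prefix of visible tokens (a minimal illustration, not part of the test):

import torch

# Row i of tril(ones) has i + 1 leading ones, so each example in the batch
# effectively gets a different amount of right padding.
print(torch.tril(torch.ones(3, 4)))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.]])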
