Skip to content

Commit

Permalink
[RetNet] Fix error heads
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs authored May 18, 2024
1 parent 54bc132 commit b466c32
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions fla/layers/multiscale_retention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -196,19 +196,21 @@ def forward(
if self.feature_map_fn is not None:
q, k = map(self.feature_map_fn, (q, k))

seqlen_offset = 0
seqlen_offset, max_seqlen = 0, None
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset
if attention_mask is not None:
# to eliminate the offsets of padding tokens
seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
q, k = self.rotary(q, k, seqlen_offset, q.shape[1] + max(seqlen_offset))
max_seqlen = q.shape[1] + max(seqlen_offset)
q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
q = q.transpose(1, 2)
if self.num_kv_groups > 1:
k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
else:
k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d')
k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)

state = last_state[-1] if use_cache else None
if mode == 'chunk':
Expand Down

0 comments on commit b466c32

Please sign in to comment.