Skip to content

Commit

Permalink
Merge pull request #21 from leor-c/main
Browse files Browse the repository at this point in the history
Slightly more efficient / cleaner implementation of the chunkwise relative pos. enc.
  • Loading branch information
fkodom authored Nov 9, 2023
2 parents 3cf9797 + 1537965 commit 7d9c1a7
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions yet_another_retnet/retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,9 @@ def forward_chunkwise(
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)

if self.relative_position:
# global (cross-chunk) relative position embedding
# global (cross-chunk) + intra-chunk relative position embedding
assert self.thetas is not None
thetas = rearrange(self.thetas, "d -> () () () d")
angles = start_idx * thetas
sin = torch.sin(angles)
cos = torch.cos(angles)
q = _theta_shift(q, sin, cos)
k = _theta_shift(k, sin, cos)

# intra-chunk relative position encoding
indices = torch.arange(q.size(2), device=q.device, dtype=q.dtype)
indices = torch.arange(start_idx, start_idx + q.size(2), device=q.device, dtype=q.dtype)
indices = rearrange(indices, "n -> () () n ()")
thetas = rearrange(self.thetas, "d -> () () () d")
angles = indices * thetas
Expand Down

0 comments on commit 7d9c1a7

Please sign in to comment.