diff --git a/yet_another_retnet/retention.py b/yet_another_retnet/retention.py index 96fce50..0d662b0 100644 --- a/yet_another_retnet/retention.py +++ b/yet_another_retnet/retention.py @@ -493,17 +493,9 @@ def forward_chunkwise( v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) if self.relative_position: - # global (cross-chunk) relative position embedding + # global (cross-chunk) + intra-chunk relative position embedding assert self.thetas is not None - thetas = rearrange(self.thetas, "d -> () () () d") - angles = start_idx * thetas - sin = torch.sin(angles) - cos = torch.cos(angles) - q = _theta_shift(q, sin, cos) - k = _theta_shift(k, sin, cos) - - # intra-chunk relative position encoding - indices = torch.arange(q.size(2), device=q.device, dtype=q.dtype) + indices = torch.arange(start_idx, start_idx + q.size(2), device=q.device, dtype=q.dtype) indices = rearrange(indices, "n -> () () n ()") thetas = rearrange(self.thetas, "d -> () () () d") angles = indices * thetas