Skip to content

Commit

Permalink
Update RoPE tests (#1746)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Sep 26, 2024
1 parent 60ec41a commit d36d9ec
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ def test_rope():
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
position_ids = torch.arange(seq_len).unsqueeze(0)

theirs = GPTNeoXRotaryEmbedding(head_size, seq_len)
theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)

ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
# their rope cache has 2 added dimensions and the cos/sin is duplicated
torch.testing.assert_close(ours_cos_cached, theirs.cos_cached.squeeze())
torch.testing.assert_close(ours_sin_cached, theirs.sin_cached.squeeze())
torch.testing.assert_close(ours_cos_cached, theirs_cos.squeeze())
torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())

ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
theirs_x_rope, _ = apply_rotary_pos_emb(x, x, theirs.cos_cached, theirs.sin_cached, position_ids)
theirs_x_rope, _ = apply_rotary_pos_emb(x, x, theirs_cos, theirs_sin, position_ids)
torch.testing.assert_close(ours_x_rope, theirs_x_rope)

0 comments on commit d36d9ec

Please sign in to comment.