Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikayk committed Apr 26, 2024
1 parent 50b8ce6 commit 4ee6aaf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchtune/models/phi3/_position_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
TODO: The implementation below can be made more efficient
for inference.
"""
# input tensor has shape [b, s, n_h, n_d]
# input tensor has shape [b, s, n_h, h_d]
seq_len = x.size(1)
head_dim = x.size(-1)

Expand All @@ -104,7 +104,7 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
rotated = torch.cat((-x2, x1), dim=-1)

# cos: [s, h_d]
# x: [b, s, n_h, n_d]
# x: [b, s, n_h, h_d]
# For the matrix multiplication to line up, transpose the input
# and the rotated input
x_out = (x.transpose(1, 2) * cos) + (rotated.transpose(1, 2) * sin)
Expand Down

0 comments on commit 4ee6aaf

Please sign in to comment.