Skip to content

Commit

Permalink
Refactor InfiniAttention class to return past_key_value in infiniAttention.py and modeling_qwen_transformers.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jlamprou committed Apr 22, 2024
1 parent c18eb55 commit 7581c10
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion infiniAttention.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def forward(
combined_output = combined_output.transpose(1, 2).contiguous()
combined_output = combined_output.view(bsz, q_len, self.hidden_size)
final_output = self.o_proj(combined_output)
return final_output, None, None, (M, z)
return final_output, None, past_key_value, (M, z)

def _retrieve_from_memory(self, Q, M, z):
# Retrieve context from compressive memory using linear attention (Eq. 3)
Expand Down
2 changes: 1 addition & 1 deletion modeling_qwen_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def forward(
combined_output = combined_output.transpose(1, 2).contiguous()
combined_output = combined_output.view(bsz, q_len, self.hidden_size)
final_output = self.o_proj(combined_output)
return final_output, None, None, (M, z)
return final_output, None, past_key_value, (M, z)

def _retrieve_from_memory(self, Q, M, z):
# Retrieve context from compressive memory using linear attention (Eq. 3)
Expand Down

0 comments on commit 7581c10

Please sign in to comment.