
Assume memory usage for flash-attn dependent only on chunk size
turboderp committed Apr 18, 2024
1 parent b112b21 commit 6940f1f
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions exllamav2/model.py
@@ -26,7 +26,7 @@
from exllamav2.module import ExLlamaV2Module
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
from exllamav2.layernorm import ExLlamaV2LayerNorm
- from exllamav2.attn import ExLlamaV2Attention
+ from exllamav2.attn import ExLlamaV2Attention, has_flash_attn
from exllamav2.lora import ExLlamaV2Lora
from exllamav2.mlp import ExLlamaV2MLP
from exllamav2.moe_mlp import ExLlamaV2MoEMLP
@@ -675,12 +675,20 @@ def forward(self,

# Limit chunk_size to keep size of attention operation <= max_attention_size

- past_len = cache.current_seq_len
- attn_size = (past_len + remaining_q_len) * remaining_q_len
- max_a = self.config.max_attention_size
- if attn_size > max_a:
-     cs = (math.sqrt(past_len ** 2 + 4 * max_a) - past_len) / 2
-     chunk_size = min(chunk_size, math.floor(cs))
+ if has_flash_attn:
+
+     # Can't measure increase in VRAM usage with longer k_len, assume usage is constant
+     # for given chunk_size
+     pass
+
+ else:
+
+     past_len = cache.current_seq_len
+     attn_size = (past_len + remaining_q_len) * remaining_q_len
+     max_a = self.config.max_attention_size
+     if attn_size > max_a:
+         cs = (math.sqrt(past_len ** 2 + 4 * max_a) - past_len) / 2
+         chunk_size = min(chunk_size, math.floor(cs))

# Process chunk

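For reference, the cap computed in the retained (non-flash-attn) branch is the positive root of (past_len + q) * q = max_attention_size, i.e. q = (sqrt(past_len ** 2 + 4 * max_attention_size) - past_len) / 2; the new has_flash_attn branch skips this cap on the assumption, stated in the commit, that VRAM usage depends only on the chunk size. Below is a minimal standalone sketch of that calculation; the helper name and the example numbers are illustrative and not part of the exllamav2 API.

import math

def limit_chunk_size(chunk_size: int, past_len: int, max_attention_size: int) -> int:
    # Illustrative helper (not the library API). Naive attention over a chunk of
    # q query tokens against (past_len + q) cached keys materializes roughly
    # (past_len + q) * q scores. Solving q**2 + past_len * q - max_attention_size = 0
    # for the positive root gives the largest q that stays within the budget.
    cs = (math.sqrt(past_len ** 2 + 4 * max_attention_size) - past_len) / 2
    return min(chunk_size, math.floor(cs))

# Example: with 4096 tokens already in the cache and a 2048 * 2048 score budget,
# a 2048-token chunk is capped to 848 query tokens, since
# (4096 + 848) * 848 = 4,192,512 <= 4,194,304 while 849 would exceed it.
print(limit_chunk_size(2048, 4096, 2048 * 2048))  # -> 848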
