
Efficient Attention

Here are some resources about efficient attention.

Taxonomy of Efficient Attention

Intro

This section collects the literature dedicated to optimizing attention mechanisms, focusing especially on the kernel operations that make the attention module the computational bottleneck of the Transformer.

This approach extends the effective context length of LLMs at inference time by making it affordable to directly increase the hyperparameter $L_{max}$ (the maximum number of tokens in one sample) during pre-training. We further categorize these methods into five distinct strategies, as shown in the taxonomy above, each with a specific focus.


Preliminaries

  • Attention Mechanism: As the core design of the Transformer, implemented in the Multi-Head Attention (MHA) layer, the self-attention mechanism computes a weighted representation of each token in the input sequence based on its relevance to every other token.

  • QKV Projection: More specifically, as illustrated in the Overview $\mathbf{(a)}$, the embedded token sequence $X \in \mathbb{R}^{L\times d_{in}}$, which concatenates long contexts and user prompts into a total length $L$, is mapped by a linear projection layer (see the equation below) to three embedding matrices: query $Q \in \mathbb{R}^{L\times d_q}$, key $K \in \mathbb{R}^{L\times d_k}$, and value $V \in \mathbb{R}^{L\times d_v}$.

$$ \begin{align} Q, K, V := \mathrm{split}\left( X \times W_{q,k,v} \right), \quad W_{q,k,v} \in \mathbb{R}^{d_{in}\times (d_q+d_k+d_v)} \end{align} $$
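The fused projection in the equation above can be sketched in NumPy as follows. This is a minimal illustration, assuming a single unbatched sequence and randomly initialized weights; the function name `qkv_projection` is hypothetical and not taken from any particular library. It shows that one matrix multiplication with the concatenated weight $W_{q,k,v}$, followed by a split along the feature axis, yields the same $Q, K, V$ as three separate projections.

```python
import numpy as np

def qkv_projection(x: np.ndarray, w_qkv: np.ndarray, d_q: int, d_k: int, d_v: int):
    """Hypothetical helper: project X (L x d_in) with a fused W_qkv (d_in x (d_q+d_k+d_v)), then split."""
    assert w_qkv.shape == (x.shape[1], d_q + d_k + d_v)
    qkv = x @ w_qkv  # (L, d_q + d_k + d_v), one matmul instead of three
    # split the feature axis at the boundaries d_q and d_q + d_k
    q, k, v = np.split(qkv, [d_q, d_q + d_k], axis=-1)
    return q, k, v

rng = np.random.default_rng(0)
L, d_in, d = 6, 16, 8                     # toy sizes, chosen for illustration only
x = rng.standard_normal((L, d_in))        # stands in for the embedded token sequence X
w_qkv = rng.standard_normal((d_in, 3 * d))
q, k, v = qkv_projection(x, w_qkv, d, d, d)
print(q.shape, k.shape, v.shape)          # (6, 8) (6, 8) (6, 8)
```

Fusing the three projections into one multiplication is a common implementation choice, since one larger matrix multiplication tends to use the hardware more efficiently than three smaller ones.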

  • Attention Kernel: Once $Q, K, V$ are prepared (after QKV projection, positional embeddings, KV cache updates, and other specific operations such as reshaping or head alignment), the attention kernel proceeds as follows:
    • (1) First, the unnormalized relevance matrix $P \in \mathbb{R}^{L\times L}$ is computed as the matrix product of $Q$ and $K^{\mathrm{T}}$, where each entry $P_{ij}$ scores the relevance between the $i$-th and $j$-th tokens.

$$ \begin{align} P := Q\times K^{\mathrm{T}} \end{align} $$

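    • (2) Then, the normalized attention score matrix $A \in \mathbb{R}^{L\times L}$ is obtained by scaling $P$ by $1/\sqrt{d_k}$, applying the mask $M \in \mathbb{R}^{L\times L}$ element-wise, and taking a row-wise softmax. Here $\odot M$ is shorthand for masking: in practice, masked positions are set to $-\infty$ before the softmax so that they receive exactly zero weight (multiplying a score by zero would still leave it a weight of $e^0 = 1$ before normalization).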

$$ \begin{align} A := \mathrm{softmax}[\cfrac{P}{\sqrt{d_k}}\odot M] \end{align} $$

    • (3) Finally, the output hidden states $O \in \mathbb{R}^{L\times d_o}$ are generated by taking, for each token, the sum of the rows of $V$ weighted by the corresponding row of $A$, usually followed by an extra linear transformation with $W_o$.

$$ \begin{align} O := (A\times V) \times W_o, \quad W_o \in \mathbb{R}^{d_v \times d_o} \end{align} $$
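Steps (1) to (3) can be sketched for a single head as below. It is a minimal NumPy sketch, not a reference implementation: `attention_kernel` is a hypothetical name, the mask is assumed to be a boolean array with `True` where attention is allowed, and the $\odot M$ masking is realized by setting disallowed scores to $-\infty$ before the softmax, which is a common convention. Note that it materializes the full $L\times L$ matrices $P$ and $A$.

```python
import numpy as np

def attention_kernel(q, k, v, w_o, mask=None):
    """Single-head attention: O = softmax(mask(Q K^T / sqrt(d_k))) V W_o.

    mask: optional boolean (L, L) array, True where attention is allowed (assumed convention).
    Returns the output O and the attention matrix A.
    """
    d_k = q.shape[-1]
    p = q @ k.T                                # (1) unnormalized relevance P, shape (L, L)
    s = p / np.sqrt(d_k)                       # (2a) scale by 1/sqrt(d_k)
    if mask is not None:
        s = np.where(mask, s, -np.inf)         # (2b) masked positions get -inf
    s = s - s.max(axis=-1, keepdims=True)      # subtract the row max for numerical stability
    e = np.exp(s)                              # exp(-inf) == 0, so masked weights vanish
    a = e / e.sum(axis=-1, keepdims=True)      # (2c) row-wise softmax gives A
    o = (a @ v) @ w_o                          # (3) weighted sum of V rows, then output projection
    return o, a

rng = np.random.default_rng(0)
L, d = 5, 8
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
w_o = rng.standard_normal((d, d))
causal = np.tril(np.ones((L, L), dtype=bool))  # token i may attend to tokens j <= i
o, a = attention_kernel(q, k, v, w_o, mask=causal)
print(o.shape)                                 # (5, 8)
print(np.allclose(a.sum(axis=-1), 1.0))        # True: each row of A is a distribution
print(np.allclose(np.triu(a, 1), 0.0))         # True: no weight on future tokens
```

Because `p` and `a` are both $L\times L$, this direct form carries the quadratic time and memory cost that the efficient attention methods collected here aim to reduce.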

  • Common Configuration: The embedding dimensions of $Q, K, V, O$ may or may not be the same, except that $d_q = d_k$ is required for $Q\times K^{\mathrm{T}}$ to be defined. Although subscripts are used above to distinguish them for generality, the common default configuration is simply $d = d_q = d_k = d_v = d_o$. The mask matrix $M$ is typically used to mask padding tokens so that all sequences in a batch are aligned, and also to apply the causal mask required by causal language modeling in generative LLMs. Furthermore, to capture diverse relationships, the model usually employs multi-head attention instead of a single head, running the attention process in parallel over differently weighted $Q_h, K_h, V_h$ sets. This is done by dividing learnable parameters such as $W_{q,k,v} \in \mathbb{R}^{d_{in}\times (3\times d)}$ into $W_{q,k,v}^{mh} \in \mathbb{R}^{d_{in}\times (3\times H\times d_{head})}$, where $H$ denotes the number of heads and typically $d = H\times d_{head}$. Like the embedding dimensions, the number of heads can differ between $Q$ and $K, V$, as in the MQA or GQA techniques used in LLMs such as GLM and PaLM (mentioned in Miscellaneous); this varies across LLMs, but by default the head counts are assumed to be equal.

  • Attention Complexity: In typical scenarios where $L \gg d$, MHA has time complexity $O(L^2 d)$, comprising $O(Ld^2)$ for the QKV projection, $O(L^2d)$ for computing $P$, $O(L^2)$ for the softmax that yields $A$, $O(L^2d)$ for the multiplication of $A$ and $V$, and $O(Ld^2)$ for the output projection of $O$. It also incurs $O(L^2)$ space complexity, comprising $O(Ld)$ for the embeddings $Q, K, V, O$ and an additional $O(L^2)$ for the buffers that store the intermediate matrices $P$ and $A$.

    Consequently, both time and memory costs grow quadratically with the sequence length $L$ (doubling $L$ roughly quadruples them), which burdens both training and inference and becomes the bottleneck for scaling the context window size.
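To make the complexity terms above concrete, the following rough cost model counts multiply-accumulate operations (MACs) term by term, along with the elements held in the activations. It is a back-of-the-envelope sketch under stated assumptions: a single head, $d_{in} = d$, constant factors ignored, and 2 bytes per element (fp16) for the memory figure. `mha_cost` is a hypothetical helper, and $d = 128$ is an illustrative value.

```python
def mha_cost(L: int, d: int):
    """Hypothetical rough cost model of one attention layer, following the listed terms.

    Assumes a single head and d_in = d; constant factors (e.g. 2 FLOPs per MAC) are ignored.
    Returns (MACs per term, stored elements per buffer).
    """
    time_terms = {
        "qkv_projection": L * d * 3 * d,  # O(L d^2): X (L x d) times W_qkv (d x 3d)
        "qk_matmul": L * L * d,           # O(L^2 d): P = Q K^T
        "softmax": L * L,                 # O(L^2):   A from P
        "av_matmul": L * L * d,           # O(L^2 d): A V
        "out_projection": L * d * d,      # O(L d^2): (A V) W_o
    }
    space_elems = {
        "q_k_v_o": 4 * L * d,             # O(L d):   Q, K, V, O
        "p_and_a": 2 * L * L,             # O(L^2):   the two L x L buffers P and A
    }
    return time_terms, space_elems


d = 128  # illustrative width, much smaller than L
for L in (4096, 8192):
    t, s = mha_cost(L, d)
    pa_mib = s["p_and_a"] * 2 / 2**20     # assumed fp16: 2 bytes per element
    print(f"L={L}: total MACs={sum(t.values()):.3e}, P and A in fp16={pa_mib:.0f} MiB")

t1, _ = mha_cost(4096, d)
t2, _ = mha_cost(8192, d)
print(f"time ratio when doubling L: {sum(t2.values()) / sum(t1.values()):.2f}")  # 3.88
```

Under these assumptions, doubling $L$ from 4096 to 8192 multiplies the $P$ and $A$ buffers by exactly 4 (64 MiB to 256 MiB), while the total operation count grows by slightly less than 4, because the $O(Ld^2)$ projection terms only double.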
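The multi-head setting from the Common Configuration bullet, including the case where $K$ and $V$ have fewer heads than $Q$ (MQA/GQA), can be sketched as follows. This is an illustrative NumPy version under assumptions: `grouped_query_attention` is a hypothetical name, query heads are assigned to K/V heads in consecutive groups (one common convention, not necessarily the one used by GLM or PaLM), and a causal mask is applied by setting future positions to $-\infty$. Setting `n_kv_heads` equal to `n_q_heads` recovers standard MHA, and `n_kv_heads = 1` gives MQA.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(x, w_q, w_k, w_v, w_o, n_q_heads, n_kv_heads, causal=True):
    """Hypothetical sketch of multi-head attention where K/V may have fewer heads than Q."""
    L = x.shape[0]
    d_head = w_q.shape[1] // n_q_heads
    assert n_q_heads % n_kv_heads == 0
    # project, then split the feature axis into heads: (heads, L, d_head)
    q = (x @ w_q).reshape(L, n_q_heads, d_head).transpose(1, 0, 2)
    k = (x @ w_k).reshape(L, n_kv_heads, d_head).transpose(1, 0, 2)
    v = (x @ w_v).reshape(L, n_kv_heads, d_head).transpose(1, 0, 2)
    # each K/V head is shared by a consecutive group of query heads
    group = n_q_heads // n_kv_heads
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    s = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)  # (heads, L, L) scaled scores
    if causal:
        s = np.where(np.tril(np.ones((L, L), dtype=bool)), s, -np.inf)
    a = softmax(s, axis=-1)
    o = a @ v                                        # (heads, L, d_head)
    o = o.transpose(1, 0, 2).reshape(L, n_q_heads * d_head)  # concatenate heads
    return o @ w_o

rng = np.random.default_rng(0)
L, d, H, H_kv = 6, 32, 8, 2                   # toy sizes: 8 query heads share 2 K/V heads
d_head = d // H
x = rng.standard_normal((L, d))
w_q = rng.standard_normal((d, H * d_head))
w_k = rng.standard_normal((d, H_kv * d_head))  # smaller K/V projections under GQA
w_v = rng.standard_normal((d, H_kv * d_head))
w_o = rng.standard_normal((H * d_head, d))
out = grouped_query_attention(x, w_q, w_k, w_v, w_o, H, H_kv)
print(out.shape)                              # (6, 32)
```

Sharing each K/V head across a group of query heads shrinks the K/V projections (and, during generation, the KV cache) by a factor of `n_q_heads // n_kv_heads`, while the $L\times L$ score computation per query head is unchanged.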

Other Relevant Surveys

Efficient Transformers: A Survey

paper link: arXiv:2009.06732

citation:

@misc{tay2022efficient,
  title={Efficient Transformers: A Survey},
  author={Yi Tay and Mostafa Dehghani and Dara Bahri and Donald Metzler},
  year={2022},
  eprint={2009.06732},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

Efficient attentions for long document summarization

paper link: arXiv:2104.02112

citation:

@article{huang2021efficient,
  title={Efficient attentions for long document summarization},
  author={Huang, Luyang and Cao, Shuyang and Parulian, Nikolaus and Ji, Heng and Wang, Lu},
  journal={arXiv preprint arXiv:2104.02112},
  year={2021}
}