The implementation is inspired by the Infini-Transformer paper & dingo-actual.
The purpose of the Infini-Transformer is to scale the context length to infinitely long inputs with bounded memory & computation. It introduces Infini-attention, which works alongside the vanilla attention mechanism: both masked local attention and long-term linear attention are used, so context is not lost on longer inputs. Keeping the memory parameters bounded also enables fast inference in LLMs.
Transformer-based LLMs are limited by the nature of the attention mechanism:
- Constrained, context-dependent memory
- Quadratic complexity in both memory footprint & computation time
- Serving models with longer contexts is costly
Advantages of using compressive memory:
- Maintains a fixed number of parameters with limited costs
- New information is added by changing these parameters, with the objective that this information can be retrieved back later
Our main aim is to store bindings of key & value states in the compressive memory and retrieve them using the query vectors, similar to the concept of Meta-learned Neural Memory.
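As a toy illustration of this binding-and-retrieval idea (not the actual module in this repo; the dimensions and names below are made up), key-value pairs can be written into an associative matrix as outer products and read back with a query:

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # ELU + 1 keeps the features positive, matching the linear-attention form used later
    return F.elu(x) + 1.0

d_key, d_value, n = 64, 64, 10            # toy dimensions and number of bindings
K = torch.randn(n, d_key)
V = torch.randn(n, d_value)

# Bind: accumulate outer products of (mapped) keys and values into a single matrix
M = feature_map(K).T @ V                  # (d_key, d_value) associative matrix
z = feature_map(K).sum(dim=0)             # (d_key,) normalization term

# Retrieve: read the memory back with a query (here, one of the stored keys)
Q = K[:1]
A = (feature_map(Q) @ M) / (feature_map(Q) @ z).unsqueeze(-1)
print(A.shape)                            # (1, d_value): a smoothed estimate of V[0]
```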
This is the working of Transformer-XL:
The input sequence is broken down into segments to effectively attend to intricate details (refer to the paper for more details). Notice how Transformer-XL discards old contexts, since it caches the KV states for the last segment only. With the Infini-Transformer, however, one can carry forward the entire context history. Infini-attention computes both local and global context states and combines them for its output, effectively reusing old KV attention states. It maintains H parallel compressive memories per attention layer (H is the number of attention heads).
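Below is a minimal, single-head sketch of that segment-wise flow, assuming the ELU + 1 feature map and the Linear update rule from the paper; the helper names, shapes, and the loop driver are illustrative rather than this repo's actual code:

```python
import torch
import torch.nn.functional as F

def elu_plus_one(x):
    return F.elu(x) + 1.0

def infini_attention_segment(Q, K, V, M, z, beta):
    """One segment of single-head Infini-attention.
    Q, K, V: (seg_len, d); M: (d, d); z: (d,); beta: scalar gate logit."""
    d = Q.size(-1)
    # Masked local dot-product attention within the segment
    scores = (Q @ K.T) / d**0.5
    mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    A_dot = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1) @ V
    # Retrieve long-term content from the compressive memory of past segments
    sQ = elu_plus_one(Q)
    A_mem = (sQ @ M) / (sQ @ z).clamp_min(1e-6).unsqueeze(-1)
    # Gate local and long-term states, then write this segment's KV into the memory
    A = torch.sigmoid(beta) * A_mem + (1 - torch.sigmoid(beta)) * A_dot
    sK = elu_plus_one(K)
    M = M + sK.T @ V
    z = z + sK.sum(dim=0)
    return A, M, z

# Carry (M, z) across segments instead of discarding old KV states
d, seg_len = 64, 8
M, z, beta = torch.zeros(d, d), torch.zeros(d), torch.tensor(0.0)
for segment in torch.randn(4, seg_len, d):        # 4 segments of a long sequence
    Q = K = V = segment                           # projections omitted for brevity
    A, M, z = infini_attention_segment(Q, K, V, M, z, beta)
```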
This is the essence of the paper. The formulae below, taken from the paper, show how a bounded set of parameters is used to compress the long context for inference.
Instead of computing new memory entries for the compressive memory, the query, key and value states from the dot-product attention computation are reused. This helps with long-context adaptation and speeds up training & inference. A linear attention mechanism is used to cast the memory update and retrieval process.
Here, 's' is the segment number. For example, at segment 2, s-1 refers to the previous segment (segment 1) of tokens. The paper proposes a segment length of 2048 tokens; you can modify it according to your input sequence length.
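For concreteness, this is roughly how a tokenized input would be chunked into such segments (the sequence length and vocabulary size below are placeholders):

```python
import torch

segment_length = 2048                        # per the paper; adjust for your input lengths
tokens = torch.randint(0, 32000, (1, 6000))  # (batch, seq_len) dummy token ids
segments = tokens.split(segment_length, dim=1)
print([s.shape[1] for s in segments])        # [2048, 2048, 1904]
```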
New content (A) is retrieved from the memory (M) by using the query (Q) as:
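In the paper's notation, with σ(x) = ELU(x) + 1 as the activation and z as the normalization term:

```math
A_{\mathrm{mem}} = \frac{\sigma(Q)\,M_{s-1}}{\sigma(Q)\,z_{s-1}}
```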
The memory and the normalization terms are updated with the new KV entries and the next states are obtained as:
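Using the Linear update rule from the paper (the Delta variant additionally subtracts the already-retrievable value before writing):

```math
M_s = M_{s-1} + \sigma(K)^{\top} V, \qquad z_s = z_{s-1} + \sum_{t=1}^{N}\sigma(K_t)
```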
The local attention state (Ad) & the memory-retrieved content (Am, i.e. from the previous segments) are aggregated using a gating scalar, β:
A = sigmoid(β) * Am + (1 - sigmoid(β)) * Ad
β is a learnable scalar that determines the trade-off between the long-term and local information flows in the model.
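A minimal sketch of such a per-head learnable gate (module and parameter names here are illustrative, not the repo's API):

```python
import torch
import torch.nn as nn

class LongTermGate(nn.Module):
    """Learnable per-head trade-off between memory-retrieved (Am) and local (Ad) states."""
    def __init__(self, num_heads):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(num_heads, 1, 1))  # one gate logit per head

    def forward(self, A_mem, A_dot):
        # A_mem, A_dot: (num_heads, seg_len, d_value)
        g = torch.sigmoid(self.beta)
        return g * A_mem + (1 - g) * A_dot
```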
Infini-Transformer has a constant memory complexity of dkey × dvalue + dkey for storing the compressed context in Ms and zs for each head in a single layer. For the other models, memory complexity grows with the sequence dimension: in Transformer-XL, Compressive Transformers & Memorizing Transformers it depends on the cache size, while in RMT & AutoCompressors it depends on the prompt size.
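For example, with dkey = dvalue = 128, each head stores 128 × 128 + 128 = 16,512 values in Ms and zs, regardless of how long the input sequence grows.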
Refer to Associative Matrix to learn more about compressive memory.
Refer here for more on how the update rule & retrieval mechanism work.
- Scale the architecture to multiple Infini attention layers
- LLM pre-training on large datasets
- Perform the Book Summarization task by finetuning the LLM on the BookSum dataset