Skip to content

16-fold memory access reduction with nearly no loss

License

Notifications You must be signed in to change notification settings

MachineLearningSystem/DoubleSparse

 
 

Repository files navigation

Double Sparsity

[paper]

TL;DR

We introduce Double Sparsity, a technique to accelerate LLM inference by reducing memory access to the KV cache. It predicts important tokens using a subset of channels and computes attention using these important tokens. Without any fine-tuning, Double Sparsity achieves bandwidth-efficient attention for the KV cache with almost no loss.

Install

  1. Clone this repo and setup the environment
git clone https://github.com/andy-yang-1/DoubleSparse.git
cd DoubleSparse
conda create -n sparse python=3.9 -y
conda activate sparse
pip install -r requirement.txt
  1. Install torch for offloading feature
# no offloading 
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

or

# offloading 
pip3 install torch==2.1
pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html

Accuracy Evaluation

  1. Wikitext-2 Perplexity
python3 evaluation/perplexity_eval.py --model_path meta-llama/Llama-2-7b-hf --heavy_const 128 --group_factor 4
  1. MMLU
python3 evaluation/mmlu.py
  1. Key-Value Retrieval
python3 evaluation/retrieval_eval.py --model-name-or-path meta-llama/Llama-2-7b-chat-hf

Speedup Evaluation

  1. Prepare Weight
cd path/to/DoubleSparse
python3 offloading/scripts/convert_hf_checkpoint.py --checkpoint_dir ~/checkpoints/meta-llama/Llama-2-7b-chat-hf --model_name meta-llama/Llama-2-7b-chat-hf
  1. Attention Operator Speedup
bash scripts/run_attn.sh
  1. End-to-End Inference Speedup
# no offloading
cd models/
python3 generate.py --checkpoint_path path/to/weight/model.pth --max_new_tokens 2048 --batch_size 4
# offloading
cd offloading/
python3 generate.py --checkpoint_path path/to/weight/model_offloading.pth --max_new_tokens 2048 --batch_size 4

Chat with Double Sparsity

Try to chat

python3 evaluation/chat.py --model_name meta-llama/Llama-2-7b-chat-hf

About

16-fold memory access reduction with nearly no loss

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%