@ Itai Shapira
This repository contains an implementation of the Random Speculative Sampling Algorithm proposed in the paper "Accelerating Large Language Model Decoding with Speculative Sampling." (https://arxiv.org/pdf/2302.01318.pdf).
This repository was created as a pset solution for Harvard CS229.
The transformer architecture has a significant advantage over earlier sequence models in terms of parallel processing. However, when generating new sequences with transformers during inference time, each output token is dependent on all previously generated tokens, making the models run serially for each token generated. This can be a slow and computationally expensive process, especially for large models with billions of parameters. Additionally, the inability to perform batched inference wastes computational resources as large model inference is typically limited by memory bandwidth rather than compute power. To address this issue, the authors of the paper above propose leveraging the fact that certain tokens are easier to predict than others and can be generated accurately with smaller and weaker but faster models. This technique is not limited to transformers and can be applied to any large autoregressive language model.
The paper proposes a technique called Speculative Sampling to speed up decoding in large language models. Decoding can be slow and computationally expensive in large models, as each token prediction requires running the model forward through multiple layers. Speculative Sampling reduces this computational cost by parallelizing the decoding process and speculatively predicting multiple tokens at once. This sampling scheme preserves the distribution of the target model.
The algorithm uses two models: a small-but-fast model (in this case, gpt2) and a large-but-slow model. The implementation uses the Hugging Face transformer models gpt2 (for the small model) and gpt2-large (for the large model).
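As a rough illustration, the two models might be loaded along the following lines. This is a minimal sketch assuming the Hugging Face transformers package; the names `draft_model` and `target_model` are illustrative, not necessarily the ones used in this repository.

```python
# Illustrative sketch: load the small draft model (gpt2) and the large target model (gpt2-large).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # both models share the GPT-2 vocabulary
draft_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
target_model = AutoModelForCausalLM.from_pretrained("gpt2-large").to(device).eval()
```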
Given a t-token prefix, the algorithm generates k provisional tokens sequentially using the small-but-fast model. Next, using the large model, we compute the next-token distributions at all of these positions in parallel, conditioned on the small model's provisional tokens. We then perform a kind of modified rejection sampling to combine the two sets of predictions, in a way that preserves the original distribution of the large model. The algorithm repeats this process until the end token is generated or a maximum length is reached. A minimal sketch of one such step appears below.
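The sketch below shows one speculative-sampling step following the paper's acceptance rule: accept a draft token x with probability min(1, p(x)/q(x)), otherwise resample from the normalized residual max(0, p - q). It reuses `draft_model` and `target_model` from the loading sketch above; the function name `speculative_step` is illustrative, and this is not necessarily the exact code in the repository.

```python
import torch


@torch.no_grad()
def speculative_step(prefix_ids: torch.Tensor, k: int = 4) -> torch.Tensor:
    """Extend prefix_ids (shape [1, t]) by between 1 and k + 1 tokens."""
    # 1) Draft: sample k provisional tokens from the small model, recording its distributions q.
    draft_ids = prefix_ids
    q_dists = []
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        q = torch.softmax(logits, dim=-1)
        tok = torch.multinomial(q, num_samples=1)
        q_dists.append(q)
        draft_ids = torch.cat([draft_ids, tok], dim=1)

    # 2) Score: one forward pass of the large model over prefix + drafts yields the
    #    target distributions p at every draft position (and one position past the last draft).
    p_all = torch.softmax(target_model(draft_ids).logits, dim=-1)

    # 3) Modified rejection sampling: accept draft token x_i with probability min(1, p(x_i) / q(x_i)).
    t = prefix_ids.shape[1]
    accepted = prefix_ids
    for i in range(k):
        x_i = draft_ids[0, t + i]
        p = p_all[:, t + i - 1, :]  # target distribution over the token at position t + i
        q = q_dists[i]
        if torch.rand(1).item() <= min(1.0, (p[0, x_i] / q[0, x_i]).item()):
            accepted = torch.cat([accepted, x_i.view(1, 1)], dim=1)
        else:
            # On rejection, resample from the residual distribution max(0, p - q), renormalized.
            # This is what keeps the overall samples distributed exactly as under the large model.
            residual = torch.clamp(p - q, min=0.0)
            residual = residual / residual.sum(dim=-1, keepdim=True)
            tok = torch.multinomial(residual, num_samples=1)
            return torch.cat([accepted, tok], dim=1)

    # 4) All k drafts accepted: sample one extra token from the large model's final distribution.
    tok = torch.multinomial(p_all[:, t + k - 1, :], num_samples=1)
    return torch.cat([accepted, tok], dim=1)
```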
This repository contains an implementation of the Random Speculative Sampling Algorithm in Python using the PyTorch library.
For maximal speedups, the small model should be at least an order of magnitude smaller than the large one. Yet, since the vocabularies of the two models need to be the same, we're stuck with gpt2 as the small model. It should be possible to fit both models on a single Colab GPU.
The implementation also contains autoregressive runtime measurements for the small model and the large model, and compares them to the runtime of the efficient inference (speculative sampling) algorithm.
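A hypothetical timing harness along these lines could be used for that comparison. The repository's own benchmark code may differ; `benchmark` and `autoregressive_generate` are illustrative names, and the code reuses `tokenizer`, `device`, `target_model`, and `speculative_step` from the sketches above.

```python
import time

import torch


@torch.no_grad()
def autoregressive_generate(model, prefix_ids, n_tokens):
    # Plain token-by-token sampling: one forward pass of `model` per generated token.
    ids = prefix_ids
    for _ in range(n_tokens):
        logits = model(ids).logits[:, -1, :]
        tok = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        ids = torch.cat([ids, tok], dim=1)
    return ids


def benchmark(prompt: str, n_tokens: int = 64, k: int = 4):
    prefix = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Baseline: autoregressive decoding with the large model alone.
    start = time.time()
    autoregressive_generate(target_model, prefix, n_tokens)
    baseline = time.time() - start

    # Speculative sampling: each step extends the prefix by 1 to k + 1 tokens.
    start = time.time()
    ids = prefix
    while ids.shape[1] < prefix.shape[1] + n_tokens:
        ids = speculative_step(ids, k=k)
    speculative = time.time() - start

    print(f"autoregressive (large model): {baseline:.2f}s")
    print(f"speculative sampling:         {speculative:.2f}s")
```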
If you are having trouble observing a speedup, use an extremely "predictable" prompt on which the large model and the small model agree, like "A B C D". Because most draft tokens are then accepted, each large-model forward pass yields several tokens, which is what lets the efficient inference algorithm skip executions of the large model.
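For example, with the hypothetical benchmark sketch above, such a prompt might be run as:

```python
# A highly predictable prompt: the small and large models largely agree on the continuation,
# so most draft tokens are accepted and the speedup is easier to observe.
benchmark("A B C D E F G H I J K L M N O P", n_tokens=64, k=4)
```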
Requirements: Python 3, PyTorch, Hugging Face Transformers.
"Accelerating Large Language Model Decoding with Speculative Sampling" (https://arxiv.org/pdf/2302.01318.pdf)
See also CS229 website: https://boazbk.github.io/mltheoryseminar/. Instructor: Boaz Barak. Teaching Fellows: Gustaf Ahdritz, Gal Kaplun