A simple, minimal implementation of the Mamba SSM in a single PyTorch file, using logcumsumexp (Heisen sequence).

PeaBrane/mamba-tiny
mamba-tiny

Tiny implementation of Mamba in PyTorch.

Featuring:

  • Numerical output equivalent to the official implementation for both the forward and backward passes
  • Simplified, readable, annotated code
  • An alternative to a parallel scan (not available in PyTorch at the time of writing), implemented via cumsum and inspired by heisen_sequence
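
The cumsum idea in the last bullet can be sketched in a few lines. The example below is a minimal illustration in plain Python, not the repo's code: the function names are hypothetical, it works on scalars instead of tensors, and it assumes every decay coefficient a_t is positive so that its logarithm is real. In PyTorch, the two accumulate calls would roughly correspond to torch.cumsum, and a fully log-space variant would use torch.logcumsumexp instead of dividing by the running product.

```python
import math
from itertools import accumulate


def scan_sequential(a, b, x0=0.0):
    """Reference: x_t = a_t * x_{t-1} + b_t, computed one step at a time."""
    xs, x = [], x0
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        xs.append(x)
    return xs


def scan_via_cumsum(a, b, x0=0.0):
    """Same recurrence written with cumulative sums only (assumes a_t > 0).

    With A_t = a_1 * ... * a_t, the recurrence unrolls to
        x_t = A_t * (x0 + sum_{s <= t} b_s / A_s).
    """
    log_A = list(accumulate(math.log(a_t) for a_t in a))         # cumsum of log a_t
    scaled = [b_t * math.exp(-la) for b_t, la in zip(b, log_A)]  # b_s / A_s
    partial = list(accumulate(scaled))                           # cumsum over s
    return [math.exp(la) * (x0 + p) for la, p in zip(log_A, partial)]


if __name__ == "__main__":
    a = [0.9, 0.5, 0.8, 0.7]
    b = [1.0, -2.0, 0.5, 3.0]
    ref = scan_sequential(a, b)
    out = scan_via_cumsum(a, b)
    print(max(abs(r - o) for r, o in zip(ref, out)))  # difference should be tiny
```

Dividing by the running product can overflow or underflow for long sequences, which is one reason to prefer a log-space formulation in practice.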

Does NOT include:

  • Recurrent mode of the network intended for inference. The demo code (sentence generation) runs a full training-style forward pass over the whole sequence for every generated token, which is much slower than the recurrent mode.
  • Kernel fusion. This repo makes no attempt to fuse the selective scan operations with the other dense operations. As a result, all internal states of the model are explicitly materialized, and memory usage may become a bottleneck.
  • Proper parameter initialization (though this could be added without sacrificing readability)
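
To make the first point concrete, here is a minimal sketch of the difference between re-running the full pass per token and carrying a recurrent state. It uses a toy scalar state-space model with fixed, made-up coefficients a, b, c (Mamba's are input-dependent and tensor-valued), and the function names are hypothetical.

```python
def ssm_full(xs, a, b, c):
    """Training-style pass: process the whole sequence from a zero state."""
    h, ys = 0.0, []
    for x in xs:
        h = a * h + b * x   # state update
        ys.append(c * h)    # readout
    return ys


def ssm_step(h, x, a, b, c):
    """Recurrent mode: advance a carried state by a single input."""
    h = a * h + b * x
    return h, c * h


xs = [1.0, -0.5, 2.0, 0.25]
a, b, c = 0.9, 0.5, 2.0

# Re-run the full pass on each growing prefix and keep only the last output.
slow = [ssm_full(xs[: t + 1], a, b, c)[-1] for t in range(len(xs))]

# Carry the state forward instead: one update per new input.
h, fast = 0.0, []
for x in xs:
    h, y = ssm_step(h, x, a, b, c)
    fast.append(y)
```

Both lists hold the same values, but the prefix approach performs 1 + 2 + ... + T state updates for T tokens, while the recurrent approach performs T and only ever stores a single state.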

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

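# generate() is not defined in this snippet; it is assumed to be the
# text-generation helper provided in demo.ipynb.
generate(model, tokenizer, 'Mamba is the')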

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
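
For readers without the notebook at hand, the loop behind a helper like generate can be sketched as follows. This is an assumption-laden illustration, not the notebook's code: generate_greedy and toy_scores are hypothetical names, the "model" is a toy stand-in over a 5-token vocabulary, and greedy decoding is chosen here only for determinism (the actual helper may sample differently).

```python
def generate_greedy(next_token_scores, prompt_ids, n_new_tokens):
    """Append n_new_tokens ids, re-scoring the full sequence at every step.

    next_token_scores is a hypothetical callable that maps a list of token
    ids to a list of scores, one per vocabulary entry, for the next token.
    """
    ids = list(prompt_ids)
    for _ in range(n_new_tokens):
        scores = next_token_scores(ids)  # full pass over all ids so far
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
    return ids


def toy_scores(ids):
    """Toy stand-in model: favours (last id + 1) mod 5."""
    target = (ids[-1] + 1) % 5
    return [1.0 if v == target else 0.0 for v in range(5)]


print(generate_greedy(toy_scores, [3], 4))  # → [3, 4, 0, 1, 2]
```

With a real model, the prompt would first be converted to ids by the tokenizer and the resulting ids decoded back to text at the end.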

References

The Mamba architecture was introduced by Albert Gu and Tri Dao. The official implementation is here: https://github.com/state-spaces/mamba/tree/main

Related works using parallel scans in log-space:
