scattermoe

Triton-based implementation of Sparse Mixture-of-Experts (SMoE) on GPUs. ScatterMoE builds upon existing implementations, overcoming some of their limitations to improve inference speed, training speed, and memory footprint. It achieves this by avoiding padding and excessive copying of the input. We also fuse the expert linear transforms and reordering operations with ParallelLinear, a module that can be used to extend the concept of SMoEs.
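
To make concrete what this fusion computes, here is a purely illustrative PyTorch reference of the SMoE forward pass (not the scattermoe API; the function and weight names below are hypothetical): each token is transformed by its top-k experts and the results are combined using the router weights.

# Illustrative reference only, NOT the scattermoe API.
import torch

def naive_smoe(X, W1, W2, k_weights, k_idxs, activation):
    # X:         [N, d_in]       input tokens
    # W1:        [E, d_in, d_h]  per-expert up-projection
    # W2:        [E, d_h, d_in]  per-expert down-projection
    # k_weights: [N, k]          router weights of each token's top-k experts
    # k_idxs:    [N, k]          indices of each token's top-k experts
    Y = torch.zeros_like(X)
    for e in range(W1.shape[0]):                      # loop over experts
        token_pos, slot = (k_idxs == e).nonzero(as_tuple=True)
        if token_pos.numel() == 0:
            continue
        h = activation(X[token_pos] @ W1[e])          # expert MLP on its tokens
        Y.index_add_(0, token_pos,
                     k_weights[token_pos, slot].unsqueeze(1) * (h @ W2[e]))
    return Y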

This implementation is lightweight (~700 lines). It will work within an FSDP or pipeline-parallel framework, but does not include any additional multi-node training infrastructure code. You can find the report here: https://arxiv.org/abs/2403.08245

Installation

# Check all is working well.
PYTHONPATH=. pytest tests
# Install editable. This will allow you to modify scattermoe in this directory.
pip install -e .

Usage

import torch.nn as nn

from scattermoe.mlp import MLP

# Initialise module...
mlp = MLP(
    input_size=x_dim, hidden_size=h_dim,
    activation=nn.GELU(),
    num_experts=E, top_k=k
)

# Calling module...
Y = mlp(
    X,         # input tensor
    k_weights, # top-k weights from router
    k_idxs     # top-k indices from router
)
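
The MLP does not include a router; k_weights and k_idxs are expected to come from your own gating network. Below is a minimal sketch of a standard top-k softmax gate that produces them (the TopKRouter class is illustrative and not part of scattermoe):

import torch
import torch.nn as nn

class TopKRouter(nn.Module):
    # Illustrative top-k softmax gate; not part of scattermoe.
    def __init__(self, input_size, num_experts, top_k):
        super().__init__()
        self.gate = nn.Linear(input_size, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, X):
        logits = self.gate(X)                                      # [N, E]
        probs = torch.softmax(logits, dim=-1)
        k_weights, k_idxs = torch.topk(probs, self.top_k, dim=-1)  # [N, k] each
        return k_weights, k_idxs

# Putting it together with the MLP above (x_dim, E, k as before):
router = TopKRouter(x_dim, num_experts=E, top_k=k)
k_weights, k_idxs = router(X)        # X: [num_tokens, x_dim]
Y = mlp(X, k_weights, k_idxs)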

Bibtex

If you use ScatterMoE in your project, cite us!

@article{tan2024scattered,
  title={Scattered Mixture-of-Experts Implementation},
  author={Tan, Shawn and Shen, Yikang and Panda, Rameswar and Courville, Aaron},
  journal={arXiv preprint arXiv:2403.08245},
  year={2024}
}

Enjoy!

Version 0.2.0

  • Made compilable (see the sketch below).
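
If this refers to compatibility with torch.compile (our assumption; the changelog does not say), usage might look like:

import torch

# Assumption: "made compilable" means the module works under torch.compile.
compiled_mlp = torch.compile(mlp)    # mlp from the Usage section above
Y = compiled_mlp(X, k_weights, k_idxs)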

More examples

  1. Integration into HuggingFace Mixtral
