This project is a JAX implementation of the RWKV (Receptance Weighted Key Value) language model. RWKV is a novel architecture that combines the efficiency of RNNs with the expressive power of Transformers. The long-term goal of this JAX implementation is to match or exceed the speed of the official implementation.
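To illustrate the RNN-like efficiency mentioned above, here is a simplified sketch of RWKV's core WKV operator in its recurrent form, written with `jax.lax.scan`. This is an illustrative sketch, not this repository's actual kernel: the function name, the decay parameterization, and the lack of numerical stabilization are all simplifications (real RWKV kernels carry a running maximum to keep the `exp()` terms stable).

```python
import jax
import jax.numpy as jnp

def wkv_scan(w, u, ks, vs):
    """Recurrent (RNN-style) form of RWKV's WKV operator -- a sketch.

    w:  per-channel decay, shape (C,); used here directly as exp(-w)
        (actual RWKV versions use a log parameterization)
    u:  per-channel "bonus" applied to the current token, shape (C,)
    ks, vs: key/value sequences, shape (T, C)
    Returns the (T, C) sequence of WKV outputs.
    """
    decay = jnp.exp(-w)

    def step(carry, kv):
        num, den = carry
        k, v = kv
        # the current token gets the bonus u; history is already in num/den
        out = (num + jnp.exp(u + k) * v) / (den + jnp.exp(u + k))
        # fold the current token into the decayed history
        num = decay * num + jnp.exp(k) * v
        den = decay * den + jnp.exp(k)
        return (num, den), out

    init = (jnp.zeros_like(ks[0]), jnp.zeros_like(ks[0]))
    _, outs = jax.lax.scan(step, init, (ks, vs))
    return outs
```

Because the state is just the `(num, den)` pair, generation is O(1) memory per token, while the same weights can be trained in a parallel, Transformer-like formulation.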
pip install -r requirements.txt
Head to json2binidx to convert your dataset into the binidx format used for training.
Thanks to Howard-Hou's json2binidx_tool (original).
Just edit the config and start training/finetuning.
python train.py (after editing the config)
python generate.py --model_path your/path/model.rwkv
By default, generate.py uses the model config from config.yaml for generation, but you can specify one manually with '--config yourconfig.yaml'. Also check out the other args in generate.py.
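Generation scripts like this typically expose temperature and top-p (nucleus) sampling knobs. As a hedged illustration of what such args do (the actual flags and internals of generate.py may differ), here is a minimal top-p sampler in JAX:

```python
import jax
import jax.numpy as jnp

def sample_top_p(key, logits, temperature=1.0, top_p=0.9):
    """Temperature + nucleus (top-p) sampling over a 1-D logits vector.

    Hypothetical helper for illustration; not taken from generate.py.
    """
    probs = jax.nn.softmax(logits / temperature)
    # sort descending, keep the smallest prefix whose mass reaches top_p
    order = jnp.argsort(-probs)
    sorted_probs = probs[order]
    exclusive_cum = jnp.cumsum(sorted_probs) - sorted_probs
    keep = exclusive_cum < top_p          # the top token is always kept
    filtered = jnp.where(keep, sorted_probs, 0.0)
    filtered = filtered / filtered.sum()  # renormalize the kept mass
    idx = jax.random.categorical(key, jnp.log(filtered))
    return order[idx]
```

Lower `temperature` sharpens the distribution; lower `top_p` restricts sampling to fewer high-probability tokens.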
- Only data parallelism is implemented for multi-node training so far.
- Implement a custom CUDA kernel for time mixing for GPU training.
- Implement mixed-precision training (approach not yet decided).
- Add the other time-mixing versions.
- Implement state tuning.
- Write a conversion script for weights (.rwkv to .pth) to make them compatible with other RWKV projects.
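For the mixed-precision item above, one common JAX pattern is to keep fp32 master weights and cast them to bfloat16 for the forward pass, so gradients (and optimizer state) stay in fp32. This is a sketch of that general pattern, not this project's chosen design:

```python
import jax
import jax.numpy as jnp

def to_bf16(tree):
    """Cast every floating-point leaf of a parameter pytree to bfloat16."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16)
        if jnp.issubdtype(x.dtype, jnp.floating) else x,
        tree,
    )

def mixed_precision_loss(params_fp32, loss_fn, batch):
    """Evaluate loss_fn on a bf16 copy of the fp32 master weights.

    Gradients flow back through the cast, so jax.grad of this function
    still returns fp32 gradients for the fp32 masters.
    """
    return loss_fn(to_bf16(params_fp32), batch)
```

The cast halves activation memory and bandwidth on the forward pass while keeping the update path in full precision.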
Whether it's bug fixes, feature additions, or performance improvements, your input is valuable. Please feel free to open issues or submit pull requests. Thanks!
This implementation is based on the original RWKV model developed by BlinkDL.