Slax is a JAX library built on top of Flax. It provides a wide range of learning rules and gradient approximations, with a heavy focus on training spiking neural networks. Some of these algorithms also apply to recurrent networks in general.
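Spiking neurons fire through a step function, and its derivative is zero almost everywhere, so gradient-based training needs an approximation. One common approach is a surrogate gradient: keep the exact spike in the forward pass and use a smooth stand-in derivative in the backward pass. The sketch below shows this in plain JAX with `jax.custom_vjp`. It is an illustration only and is not Slax's API. The names `spike`, `_spike_fwd`, and `_spike_bwd`, the fast-sigmoid surrogate, and the slope `k = 10` are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical illustration in plain JAX; this is NOT Slax's API.
@jax.custom_vjp
def spike(v):
    """Heaviside step: emit 1.0 where the membrane potential v exceeds 0."""
    return (v > 0).astype(v.dtype)

def _spike_fwd(v):
    # Forward pass: exact spikes; save v as the residual for the backward pass.
    return spike(v), v

def _spike_bwd(v, g):
    # Backward pass: replace the step's zero derivative with a fast-sigmoid
    # surrogate 1 / (1 + k|v|)^2. The slope k = 10 is an assumed value.
    k = 10.0
    surrogate = 1.0 / (1.0 + k * jnp.abs(v)) ** 2
    return (g * surrogate,)

spike.defvjp(_spike_fwd, _spike_bwd)

v = jnp.array([-0.5, 0.0, 0.5])
spikes = spike(v)                                  # [0., 0., 1.]
grads = jax.grad(lambda x: spike(x).sum())(v)      # peaks at v = 0
```

The surrogate peaks at the firing threshold and decays away from it. As a result, neurons close to firing receive the largest learning signal, while the forward pass still produces binary spikes.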
Slax was originally written with the Flax linen API and is now being rewritten with the NNX API. Until the conversion is finished, functionality and documentation are limited and may be unreliable. An alpha release based on linen is already on PyPI (pip install slax), but check back soon for the next version!