Skip to content

Help scripts for navigating the low level jax libraries

Notifications You must be signed in to change notification settings

ElleLeonne/jax_toolkit

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

13 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

A dumping ground for the various Jax utilities needed to continue my work. It's worth noting that many of these tools are tailor made for my specific use case. Not only may they require some adjustments, but posting them here often requires removal of proprietary data, which may have the unintended consequence of stripping away ease of use or showing off some more advanced functionality.

Jax is a low-level library, and is often better suited to custom built tools over generalizable libraries. Still, I hope this may still be useful to someone.

#Table of Contents

Migrator: A tool to convert Pytorch weights to Jax Pytrees. Strips off PT weights into a list, and leaves you to assign the leaves yourself.

Preprocessor: A class for preparing attention masks, position_ids, and labels. Pads, caches, and prepares data.

Checkpointer: A tool for saving model shards in Orbax. Built to handle low memory systems where weights may cause overflow issues.

Train & utils.trainer: Training pipeline tools for basic implementation of the flax_model. Designed to work with Preprocessor outputs.

About

Help scripts for navigating the low level jax libraries

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages