
sw32-seo/neuralODE


Neural ODE with Flax

This is the result of the project "Reproduce Neural ODE and SDE" from the Hugging Face Flax/JAX Community Week.

main.py trains either a ResNet or an OdeNet on the MNIST dataset, or a continuous normalizing flow (CNF) on a toy sample dataset.
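The link between the two MNIST models is that a residual block computes h + f(h), which is exactly one Euler step of size 1 for the ODE dh/dt = f(h, t). An ODE-Net instead integrates that ODE over a time interval. The stdlib-only sketch below is not taken from this repository: the dynamics f(h, t) = -h is a hypothetical stand-in for a learned Flax layer, and plain Euler steps stand in for whatever solver main.py actually uses.

```python
import math


def f(h, t):
    # Hypothetical dynamics standing in for a learned layer: dh/dt = -h.
    return -h


def resnet_block(h):
    # A residual block outputs h + f(h): one Euler step with step size 1.
    return h + f(h, 0.0)


def euler_odeint(h0, t0, t1, n_steps):
    # Fixed-step Euler integration of dh/dt = f(h, t) from t0 to t1.
    h, t = h0, t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        h = h + dt * f(h, t)
        t += dt
    return h


if __name__ == "__main__":
    print(resnet_block(1.0))                  # one residual update
    print(euler_odeint(1.0, 0.0, 1.0, 1))     # one Euler step of size 1: same value
    print(euler_odeint(1.0, 0.0, 1.0, 1000))  # many small steps approach the ODE solution
    print(math.exp(-1.0))                     # exact solution h(1) = exp(-1)
```

With many small steps the result approaches the exact solution, which is the "continuous depth" view an ODE-Net takes.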

Dependencies

JAX and Flax

For JAX installation, please follow the official JAX installation guide.

Or simply run:

pip install jax jaxlib

To install Flax:

pip install flax

TensorFlow Datasets (tensorflow-datasets) downloads the MNIST dataset into your environment.
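To confirm the packages above are importable before running training, a small stdlib-only check like the following can help. The module list (jax, jaxlib, flax, tensorflow_datasets) is an assumption about what main.py needs; extend it if the script imports more.

```python
import importlib.util

# Assumed import names for the dependencies listed in this README.
REQUIRED_MODULES = ["jax", "jaxlib", "flax", "tensorflow_datasets"]


def check_dependencies(modules=REQUIRED_MODULES):
    """Return {module_name: bool}, True when the module can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name:20s} {'found' if ok else 'MISSING'}")
```

importlib.util.find_spec only locates a top-level module without importing it, so the check stays fast even when JAX is installed.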

How to run training

To train a (small) ResNet:

python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64 

To train a Neural ODE (OdeNet):

python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
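In an ODE-Net, a stack of residual blocks is replaced by a single ODE solve whose dynamics share one set of parameters across all solver steps. The sketch below shows that forward pass with a classic fixed-step fourth-order Runge-Kutta (RK4) solver. It is a hypothetical illustration, not this repository's code: the scalar dynamics f(h, t, theta) = theta * h stands in for a learned Flax network, and the actual solver used by main.py (fixed-step or adaptive) is not stated in this README.

```python
import math


def f(h, t, theta):
    # Hypothetical stand-in for the learned dynamics network: dh/dt = theta * h.
    return theta * h


def rk4_odeint(f, h0, t0, t1, theta, n_steps=10):
    """Fixed-step classic RK4 integration of dh/dt = f(h, t, theta) from t0 to t1."""
    h, t = h0, t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        k1 = f(h, t, theta)
        k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt, theta)
        k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt, theta)
        k4 = f(h + dt * k3, t + dt, theta)
        h = h + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
    return h


def ode_block(h0, theta):
    # The same theta is reused at every solver step, unlike residual blocks
    # that each carry their own weights.
    return rk4_odeint(f, h0, 0.0, 1.0, theta)


if __name__ == "__main__":
    # Compare against the exact solution h(1) = h0 * exp(theta).
    print(ode_block(1.0, -1.0), math.exp(-1.0))
```

Because RK4 is fourth-order accurate, ten steps already match the exact solution closely here; a real ODE-Net typically trades solver steps against accuracy in the same way.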

To train a Continuous Normalizing Flow (CNF):

python main.py --model=cnf --sample_dataset=circles
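A CNF moves samples through an ODE dz/dt = f(z, t) and tracks their log-density with the instantaneous change of variables from the Neural ODE paper, d log p(z(t))/dt = -tr(∂f/∂z). The stdlib-only sketch below illustrates that rule under stated assumptions: the linear dynamics f(z) = A z is hypothetical, the Jacobian trace uses central finite differences, and integration uses Euler steps. This repository presumably uses a learned Flax network and JAX differentiation instead, which is an assumption.

```python
# Hypothetical linear dynamics dz/dt = A z; a real CNF learns f with a network.
A = [[0.5, 0.0],
     [0.0, -0.2]]


def f(z, t):
    return [A[0][0] * z[0] + A[0][1] * z[1],
            A[1][0] * z[0] + A[1][1] * z[1]]


def jacobian_trace(f, z, t, eps=1e-6):
    # Sum of diagonal entries d f_i / d z_i, each by central finite differences.
    tr = 0.0
    for i in range(len(z)):
        zp = list(z)
        zp[i] += eps
        zm = list(z)
        zm[i] -= eps
        tr += (f(zp, t)[i] - f(zm, t)[i]) / (2 * eps)
    return tr


def cnf_forward(z0, logp0, t0=0.0, t1=1.0, n_steps=100):
    # Jointly integrate, with Euler steps:
    #   dz/dt = f(z, t)
    #   d log p / dt = -tr(df/dz)
    z, logp, t = list(z0), logp0, t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        dz = f(z, t)
        dlogp = -jacobian_trace(f, z, t)
        z = [zi + dt * dzi for zi, dzi in zip(z, dz)]
        logp += dt * dlogp
        t += dt
    return z, logp
```

For this A the trace is the constant 0.3, so over unit time log p decreases by 0.3, matching log|det exp(A)| = tr(A) for the exact linear flow.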

The --sample_dataset flag accepts circles, moons, or scurve.
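This README does not show how main.py parses its arguments. As a hypothetical sketch, the flags used in the commands above could be declared with Python's standard argparse module as follows. The flag names and the model and dataset values come from those commands; the types, defaults, and use of choices are assumptions.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of main.py's command-line interface.
    parser = argparse.ArgumentParser(description="Train ResNet, OdeNet, or CNF models.")
    parser.add_argument("--model", choices=["resnet", "odenet", "cnf"], default="resnet")
    parser.add_argument("--lr", type=float, default=1e-4)       # learning rate
    parser.add_argument("--n_epoch", type=int, default=20)      # number of training epochs
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--sample_dataset",
                        choices=["circles", "moons", "scurve"],
                        default="circles")                      # used by the CNF model
    return parser
```

Declaring choices makes argparse reject a misspelled model or dataset name with a usage error instead of failing later during training.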

Sample Results

Figures: three CNF visualizations (cnf-viz).
