A super brief README
Install JAX with GPU support by following the instructions at https://github.com/google/jax
Install Flax (a neural network library built on JAX):
pip install flax
Install TensorFlow, which provides TensorBoard for visualizing the training process:
pip install tensorflow
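Before training, a quick check like the following can confirm that the packages above import correctly and that JAX sees the GPUs. This is a sketch that assumes python3 is the interpreter pip installed into (set PYTHON to override). The check_pkg helper is hypothetical and not part of this repository.

```shell
# Assumption: python3 is the interpreter pip installed into; override with PYTHON=...
PYTHON="${PYTHON:-python3}"

# Hypothetical helper: report whether a Python module can be imported.
check_pkg() {
  if "$PYTHON" -c "import $1" 2>/dev/null; then
    echo "$1: OK"
  else
    echo "$1: missing"
  fi
}

# Check the three packages installed above.
for pkg in jax flax tensorflow; do
  check_pkg "$pkg"
done

# List the devices JAX can see; a GPU build should show GPU devices, not only CPU.
"$PYTHON" -c "import jax; print(jax.devices())" 2>/dev/null || echo "jax: not importable"
```

Any line ending in "missing" points back to the corresponding install step. If JAX only lists CPU devices, the GPU-enabled JAX install likely needs to be redone.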
Download the data from here: [Google Drive]
The training command below assumes two GPUs (ids 0 and 1):
CUDA_VISIBLE_DEVICES=0,1 python3 train.py --config zaragoza_bunny_all_grid
Rendering uses a single GPU (id 0):
CUDA_VISIBLE_DEVICES=0 python3 render.py --config zaragoza_bunny_all_grid
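In both commands, CUDA_VISIBLE_DEVICES is a comma-separated list of GPU ids, and the process only sees those devices. Training therefore gets two GPUs and rendering gets one. Running on a different number of GPUs means changing this list. Whether the training script handles other GPU counts (for example, batch sizes that must divide evenly across devices) is not stated here and is an assumption. The minimal sketch below counts the ids in such a list. The count_gpus helper is hypothetical. It only parses the string and does not check that those GPUs exist.

```shell
# Hypothetical helper: count the GPU ids in a CUDA_VISIBLE_DEVICES-style list.
# Empty entries (e.g. from a trailing comma) are ignored; the hardware is not queried.
count_gpus() {
  printf '%s\n' "$1" | tr ',' '\n' | grep -c '[^[:space:]]'
}

TRAIN_GPUS="0,1"  # list used by the training command above
RENDER_GPUS="0"   # list used by the rendering command above

# Prints "training:  2 GPU(s) [0,1]" and "rendering: 1 GPU(s) [0]"
echo "training:  $(count_gpus "$TRAIN_GPUS") GPU(s) [$TRAIN_GPUS]"
echo "rendering: $(count_gpus "$RENDER_GPUS") GPU(s) [$RENDER_GPUS]"
```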