You will need Python 3.7 or later.
- Clone the repository:
> git clone https://github.com/matthias-wright/flaxmodels.git
- Go into the directory:
> cd flaxmodels/training/resnet
- Install JAX with CUDA support (see the note after this list).
- Install requirements:
> pip install -r requirements.txt
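
Note: the exact JAX install command depends on your CUDA and cuDNN versions, so check the official JAX installation guide for the wheel that matches your setup. With recent JAX releases, a CUDA-enabled install typically looks like this:
> pip install --upgrade "jax[cuda12]"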
Start a training run on a single GPU:
> CUDA_VISIBLE_DEVICES=0 python main.py

The script will automatically use all visible GPUs for distributed training:
> CUDA_VISIBLE_DEVICES=0,1 python main.py

To train with mixed precision, add the `--mixed_precision` flag:
> CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision
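
To check which devices JAX picked up, and as a minimal illustration of the single-host data parallelism pattern the script relies on (this is not the repo's actual training code), you can run:

```python
import jax
import jax.numpy as jnp

# JAX enumerates the GPUs exposed through CUDA_VISIBLE_DEVICES.
print(jax.local_device_count())  # e.g. 2 when CUDA_VISIBLE_DEVICES=0,1

# jax.pmap replicates a function across all local devices; data-parallel
# training shards each batch along the leading device axis like this.
per_device_sum = jax.pmap(lambda x: jnp.sum(x))
batch = jnp.ones((jax.local_device_count(), 128))
print(per_device_sum(batch))  # one partial result per device
```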
The following command-line options are available:
- `--work_dir` - Path to the directory for logging and checkpoints (str).
- `--data_dir` - Path for storing the dataset (str).
- `--name` - Name of the training run (str).
- `--group` - Group name of the training run (str).
- `--arch` - Architecture (str). Options: resnet18, resnet34, resnet50, resnet101, resnet152.
- `--resume` - Resume training from the best checkpoint (bool).
- `--num_epochs` - Number of epochs (int).
- `--learning_rate` - Learning rate (float).
- `--warmup_epochs` - Number of warmup epochs with a lower learning rate (int).
- `--batch_size` - Batch size (int).
- `--num_classes` - Number of classes (int).
- `--img_size` - Image size (int).
- `--img_channels` - Number of image channels (int).
- `--mixed_precision` - Use mixed precision training (bool).
- `--random_seed` - Random seed (int).
- `--wandb` - Use Weights & Biases for logging (bool).
- `--log_every` - Log every `log_every` steps (int).
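
For reference, flags like these are typically wired up with argparse. The following is a minimal sketch; the defaults are illustrative, not necessarily what main.py uses:

```python
import argparse

parser = argparse.ArgumentParser()
# Paths and run metadata (defaults here are illustrative).
parser.add_argument('--work_dir', type=str, default='logs')
parser.add_argument('--data_dir', type=str, default='data')
parser.add_argument('--name', type=str, default='run0')
# Model and training hyperparameters.
parser.add_argument('--arch', type=str, default='resnet18',
                    choices=['resnet18', 'resnet34', 'resnet50',
                             'resnet101', 'resnet152'])
parser.add_argument('--num_epochs', type=int, default=90)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--warmup_epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=256)
# Boolean switches.
parser.add_argument('--resume', action='store_true')
parser.add_argument('--mixed_precision', action='store_true')
parser.add_argument('--wandb', action='store_true')
args = parser.parse_args()
```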
ResNet18 was trained on the Imagenette dataset and reaches a validation accuracy of around 90%.
- Images were resized to 256x256 (random crops for training, center crops for evaluation).
- Data augmentation: flipping, brightness, hue, and contrast (see the preprocessing sketch after this list).
- Learning rate schedule: cosine annealing (see the schedule sketch after this list).
- Training was done from scratch, with no transfer learning.
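
A minimal sketch of the preprocessing and augmentation described above, written with `tf.image` (the crop size and augmentation magnitudes are illustrative assumptions, not values taken from this repo):

```python
import tensorflow as tf

def preprocess(image, train, crop_size=224):  # crop_size is an assumption
    # Resize to 256x256 as described above.
    image = tf.image.resize(image, [256, 256])
    if train:
        # Random crop plus the listed augmentations for training.
        image = tf.image.random_crop(image, [crop_size, crop_size, 3])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=0.2)
        image = tf.image.random_hue(image, max_delta=0.1)
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    else:
        # Center crop for evaluation.
        image = tf.image.resize_with_crop_or_pad(image, crop_size, crop_size)
    return image
```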
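
And a sketch of a cosine annealing schedule, here combined with the warmup phase implied by `--warmup_epochs`, using optax's built-in helper (all numbers are illustrative):

```python
import optax

steps_per_epoch = 100            # illustrative
num_epochs, warmup_epochs = 90, 3

# Linear warmup to the peak learning rate, then cosine annealing to zero.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,             # corresponds to --learning_rate
    warmup_steps=warmup_epochs * steps_per_epoch,
    decay_steps=num_epochs * steps_per_epoch,
)
optimizer = optax.sgd(learning_rate=schedule, momentum=0.9)
```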