From ceda8c0bbd2eb3233c4db80d08c0bb1cc1d322c9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Dec 2024 21:31:04 +0530 Subject: [PATCH] refactor: centralize the CIFAR10 examples --- docs/src/.vitepress/config.mts | 4 +- docs/src/tutorials/index.md | 6 +- examples/{ConvMixer => CIFAR10}/Project.toml | 5 +- examples/{ConvMixer => CIFAR10}/README.md | 65 +++++------ .../{ConvMixer/main.jl => CIFAR10/common.jl} | 104 ++++++++---------- examples/CIFAR10/conv_mixer.jl | 50 +++++++++ examples/CIFAR10/mlp_mixer.jl | 6 + examples/CIFAR10/simple_cnn.jl | 36 ++++++ 8 files changed, 174 insertions(+), 102 deletions(-) rename examples/{ConvMixer => CIFAR10}/Project.toml (89%) rename examples/{ConvMixer => CIFAR10}/README.md (79%) rename examples/{ConvMixer/main.jl => CIFAR10/common.jl} (56%) create mode 100644 examples/CIFAR10/conv_mixer.jl create mode 100644 examples/CIFAR10/mlp_mixer.jl create mode 100644 examples/CIFAR10/simple_cnn.jl diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index f785f6a31..bdd870e08 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -243,8 +243,8 @@ export default defineConfig({ link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM", }, { - text: "ConvMixer on CIFAR-10", - link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + text: "Different Vision Models on CIFAR-10", + link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", }, ], }, diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 75c45f7b9..6b01da2b7 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -97,10 +97,10 @@ const large_models = [ desc: "Train a Diffusion Model to generate images from Gaussian noises." }, { - href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", src: "https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp", - caption: "ConvMixer on CIFAR-10", - desc: "Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes." + caption: "Vision Models on CIFAR-10", + desc: "Train differnt vision models on CIFAR-10 to 90% accuracy within 10 minutes." } ]; diff --git a/examples/ConvMixer/Project.toml b/examples/CIFAR10/Project.toml similarity index 89% rename from examples/ConvMixer/Project.toml rename to examples/CIFAR10/Project.toml index 125c6612d..540774f4e 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/CIFAR10/Project.toml @@ -12,9 +12,8 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" +ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -34,9 +33,7 @@ MLDatasets = "0.7.14" MLUtils = "0.4.4" OneHotArrays = "0.2.5" Optimisers = "0.4.1" -PreferenceTools = "0.1.2" Printf = "1.10" -ProgressBars = "1.5.1" Random = "1.10" Reactant = "0.2.11" Statistics = "1.10" diff --git a/examples/ConvMixer/README.md b/examples/CIFAR10/README.md similarity index 79% rename from examples/ConvMixer/README.md rename to examples/CIFAR10/README.md index 54d7d1f94..6e1841663 100644 --- a/examples/ConvMixer/README.md +++ b/examples/CIFAR10/README.md @@ -1,6 +1,35 @@ -# Train ConvMixer on CIFAR-10 +# Train Vision Models on CIFAR-10 - โœˆ๏ธ ๐Ÿš— ๐Ÿฆ ๐Ÿˆ ๐ŸฆŒ ๐Ÿ• ๐Ÿธ ๐ŸŽ ๐Ÿšข ๐Ÿšš +โœˆ๏ธ ๐Ÿš— ๐Ÿฆ ๐Ÿˆ ๐ŸฆŒ ๐Ÿ• ๐Ÿธ ๐ŸŽ ๐Ÿšข ๐Ÿšš + +We have the following scripts to train vision models on CIFAR-10: + +1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers. +2. `mlp_mixer.jl`: MLP-Mixer model. +3. `conv_mixer.jl`: ConvMixer model. + +To get the options for each script, run the script with the `--help` flag. + +> [!NOTE] +> To train the model using Reactant.jl pass in `--backend=reactant` to the script. This is +> the recommended approch to train the models present in this directory. + +## Simple CNN + +```bash +julia --startup-file=no \ + --project=. \ + --threads=auto \ + simple_cnn.jl \ + --backend=reactant +``` + +On a RTX 4050 6GB Laptop GPU the training takes approximately 3 mins and the final training +and test accuracies are 97% and 65%, respectively. + +## MLP-Mixer + +## ConvMixer > [!NOTE] > This code has been adapted from https://github.com/locuslab/convmixer-cifar10 @@ -11,14 +40,11 @@ for new experiments on small datasets. You can get around **90.0%** accuracy in just **25 epochs** by running the script with the following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2. -> [!NOTE] -> To train the model using Reactant.jl pass in `--backend=reactant` to the script. - ```bash julia --startup-file=no \ --project=. \ --threads=auto \ - main.jl \ + conv_mixer.jl \ --lr-max=0.05 \ --weight-decay=0.0001 \ --backend=reactant @@ -54,32 +80,7 @@ Epoch 24: Learning Rate 8.29e-04, Train Acc: 99.99%, Test Acc: 90.79%, Time: 21. Epoch 25: Learning Rate 4.12e-04, Train Acc: 100.00%, Test Acc: 90.83%, Time: 21.32 ``` -## Usage - -```bash - main [options] [flags] - -Options - - --batchsize <512::Int> - --hidden-dim <256::Int> - --depth <8::Int> - --patch-size <2::Int> - --kernel-size <5::Int> - --weight-decay <0.01::Float64> - --seed <42::Int> - --epochs <25::Int> - --lr-max <0.01::Float64> - --backend - -Flags - --clip-norm - - -h, --help Print this help message. - --version Print version. -``` - -## Notes +### Notes 1. To match the results from the original repo, we need more augmentation strategies, that are currently not implemented in DataAugmentation.jl. diff --git a/examples/ConvMixer/main.jl b/examples/CIFAR10/common.jl similarity index 56% rename from examples/ConvMixer/main.jl rename to examples/CIFAR10/common.jl index 9d1c6cb5e..84647e8ae 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/CIFAR10/common.jl @@ -1,9 +1,6 @@ -using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA, - MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, - Statistics, Zygote -using Reactant, Enzyme - -CUDA.allowscalar(false) +using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays, + Printf, ProgressTables, Random +using LuxCUDA, Reactant @concrete struct TensorDataset dataset @@ -18,7 +15,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac return stack(parent โˆ˜ itemdata โˆ˜ Base.Fix1(apply, ds.transform), img), y end -function get_dataloaders(batchsize; kwargs...) +function get_cifar10_dataloaders(batchsize; kwargs...) cifar10_mean = (0.4914, 0.4822, 0.4465) cifar10_std = (0.2471, 0.2435, 0.2616) @@ -38,35 +35,6 @@ function get_dataloaders(batchsize; kwargs...) return trainloader, testloader end -function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) - #! format: off - return Chain( - Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), - BatchNorm(dim), - [ - Chain( - SkipConnection( - Chain( - Conv( - (kernel_size, kernel_size), dim => dim, gelu; - groups=dim, pad=SamePad() - ), - BatchNorm(dim) - ), - + - ), - Conv((1, 1), dim => dim, gelu), - BatchNorm(dim) - ) - for _ in 1:depth - ]..., - GlobalMeanPool(), - FlattenLayer(), - Dense(dim => 10) - ) - #! format: on -end - function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 cdev = cpu_device() @@ -79,41 +47,37 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, - patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.005, - clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, - backend::String="gpu_if_available") - rng = Random.default_rng() - Random.seed!(rng, seed) - +function get_accelerator_device(backend::String) if backend == "gpu_if_available" - accelerator_device = gpu_device() + return gpu_device() elseif backend == "gpu" - accelerator_device = gpu_device(; force=true) + return gpu_device(; force=true) elseif backend == "reactant" - accelerator_device = reactant_device(; force=true) + return reactant_device(; force=true) elseif backend == "cpu" - accelerator_device = cpu_device() + return cpu_device() else error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \ `reactant`, and `cpu`.") end +end +function train_model( + model, opt, scheduler=nothing; + backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25 +) + rng = Random.default_rng() + Random.seed!(rng, seed) + + accelerator_device = get_accelerator_device(backend) kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : () - trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device + trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |> + accelerator_device - model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) ps, st = Lux.setup(rng, model) |> accelerator_device - opt = AdamW(; eta=lr_max, lambda=weight_decay) - clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) - train_state = Training.TrainState(model, ps, st, opt) - lr_schedule = linear_interpolation( - [0, epochs * 2 รท 5, epochs * 4 รท 5, epochs + 1], [0, lr_max, lr_max / 20, 0] - ) - adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote() if backend == "reactant" @@ -128,16 +92,32 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: loss_fn = CrossEntropyLoss(; logits=Val(true)) + pt = ProgressTable(; + header=[ + "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)" + ], + widths=[24, 24, 24, 24, 24], + format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"], + color=[:normal, :normal, :blue, :blue, :normal], + border=true, + alignment=[:center, :center, :center, :center, :center] + ) + @printf "[Info] Training model\n" + initialize(pt) + for epoch in 1:epochs stime = time() lr = 0 for (i, (x, y)) in enumerate(trainloader) - lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) - train_state = Optimisers.adjust!(train_state, lr) - (_, _, _, train_state) = Training.single_train_step!( + if scheduler !== nothing + lr = scheduler((epoch - 1) + (i + 1) / length(trainloader)) + train_state = Optimisers.adjust!(train_state, lr) + end + (_, loss, _, train_state) = Training.single_train_step!( adtype, loss_fn, (x, y), train_state ) + isnan(loss) && error("NaN loss encountered!") end ttime = time() - stime @@ -150,8 +130,10 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: Lux.testmode(train_state.states), testloader ) * 100 - @printf "[Train] Epoch %2d: Learning Rate %.6f, Train Acc: %.4f%%, Test Acc: \ - %.4f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime + scheduler === nothing && (lr = NaN32) + next(pt, [epoch, lr, train_acc, test_acc, ttime]) end + + finalize(pt) @printf "[Info] Finished training\n" end diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl new file mode 100644 index 000000000..55f0b20da --- /dev/null +++ b/examples/CIFAR10/conv_mixer.jl @@ -0,0 +1,50 @@ +using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +@isdefined(includet) ? includet("common.jl") : include("common.jl") + +CUDA.allowscalar(false) + +function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) + #! format: off + return Chain( + Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), + BatchNorm(dim), + [ + Chain( + SkipConnection( + Chain( + Conv( + (kernel_size, kernel_size), dim => dim, gelu; + groups=dim, pad=SamePad() + ), + BatchNorm(dim) + ), + + + ), + Conv((1, 1), dim => dim, gelu), + BatchNorm(dim) + ) + for _ in 1:depth + ]..., + GlobalMeanPool(), + FlattenLayer(), + Dense(dim => 10) + ) + #! format: on +end + +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, + patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, + backend::String="reactant") + model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) + + opt = AdamW(; eta=lr_max, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + lr_schedule = linear_interpolation( + [0, epochs * 2 รท 5, epochs * 4 รท 5, epochs + 1], [0, lr_max, lr_max / 20, 0] + ) + + return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs) +end diff --git a/examples/CIFAR10/mlp_mixer.jl b/examples/CIFAR10/mlp_mixer.jl new file mode 100644 index 000000000..1132d0991 --- /dev/null +++ b/examples/CIFAR10/mlp_mixer.jl @@ -0,0 +1,6 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +CUDA.allowscalar(false) + +@isdefined(includet) ? includet("common.jl") : include("common.jl") + diff --git a/examples/CIFAR10/simple_cnn.jl b/examples/CIFAR10/simple_cnn.jl new file mode 100644 index 000000000..23dd51051 --- /dev/null +++ b/examples/CIFAR10/simple_cnn.jl @@ -0,0 +1,36 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +@isdefined(includet) ? includet("common.jl") : include("common.jl") + +CUDA.allowscalar(false) + +function SimpleCNN() + return Chain( + Conv((3, 3), 3 => 16, gelu; stride=2, pad=1), + BatchNorm(16), + Conv((3, 3), 16 => 32, gelu; stride=2, pad=1), + BatchNorm(32), + Conv((3, 3), 32 => 64, gelu; stride=2, pad=1), + BatchNorm(64), + Conv((3, 3), 64 => 128, gelu; stride=2, pad=1), + BatchNorm(128), + GlobalMeanPool(), + FlattenLayer(), + Dense(128 => 64, gelu), + BatchNorm(64), + Dense(64 => 10) + ) +end + +Comonicon.@main function main(; + batchsize::Int=512, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003, + backend::String="reactant" +) + model = SimpleCNN() + + opt = AdamW(; eta=lr, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + return train_model(model, opt, nothing; backend, batchsize, seed, epochs) +end