Enzyme Testing + Caching in compute_gradients #640

Merged: 14 commits merged into main from ap/more_enzyme_tests on May 15, 2024

Conversation

@avik-pal (Member) commented May 12, 2024

  • Add testing for normalization functions for Enzyme
  • Caching in Training (see the sketch after this list)
    • Inplace update for Optimisers
    • Enzyme Training Utilities
  • Rewrite the caching for Enzyme
  • Needs some tests
    • Train a simple enough MLP in the tests
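
As a rough illustration of the caching and in-place-update ideas above, here is a hedged sketch with made-up names; it is not the PR's actual internals and assumes Enzyme's make_zero/make_zero! and Optimisers' update! APIs:

using Enzyme, Optimisers

# Toy parameters standing in for a Lux parameter NamedTuple.
ps = (; weight = randn(Float32, 4, 4), bias = zeros(Float32, 4))
x  = randn(Float32, 4, 16)

loss(ps, x) = sum(abs2, ps.weight * x .+ ps.bias)

# Shadow (gradient) storage allocated once and reused on every call,
# instead of reallocating it for each gradient computation.
dps = Enzyme.make_zero(ps)

# Optimiser state is also set up once; update! then mutates its buffers in place.
opt_state = Optimisers.setup(Adam(0.001f0), ps)

function train_steps!(ps, dps, opt_state, x; nsteps = 10)
    for _ in 1:nsteps
        Enzyme.make_zero!(dps)  # reset the reused gradient buffers
        Enzyme.autodiff(Reverse, loss, Active, Duplicated(ps, dps), Const(x))
        opt_state, ps = Optimisers.update!(opt_state, ps, dps)
    end
    return ps, opt_state
end

ps, opt_state = train_steps!(ps, dps, opt_state, x)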

@github-actions (bot) left a comment


Benchmark Results

| Benchmark suite | Current: 8bdde08 | Previous: 64ba96d | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3639.25 ns | 3633 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7171.75 ns | 7103.166666666667 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 21029 ns | 20759 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9728 ns | 9595.8 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 9046.75 ns | 8806 ns | 1.03 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4457.125 ns | 4427 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1203.7539682539682 ns | 1206.5289256198348 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1116.9872611464968 ns | 1119.1708860759493 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1188.3161764705883 ns | 1198.7238805970148 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1795 ns | 1795.396551724138 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 178.88664812239222 ns | 178.75070028011206 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17423 ns | 17362 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 17092 ns | 17443 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 37139 ns | 37119 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28333 ns | 28172 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19967 ns | 19957 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 16712 ns | 16821 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4343.714285714285 ns | 4306.571428571428 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3899.75 ns | 3846 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 4003.75 ns | 3947.25 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4996.428571428572 ns | 4824.714285714285 ns | 1.04 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1658 ns | 1656.1 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 38849589.5 ns | 38507163 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 57463704.5 ns | 57582497 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 76284433.5 ns | 75605722.5 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88942467 ns | 88334752 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 72781539 ns | 72169093 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11978857 ns | 11692461 ns | 1.02 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 8337439 ns | 17394732.5 ns | 0.48 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7019228 ns | 6995759 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 6969104 ns | 6978091 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 10049824 ns | 9930897 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6377212 ns | 6387632 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 713976867 ns | 693404244 ns | 1.03 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2834760052 ns | 2833937581 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 145850873 ns | 156241063 ns | 0.93 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 842710339 ns | 834588032 ns | 1.01 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2574059116 ns | 2548330832 ns | 1.01 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 202158213 ns | 178313969 ns | 1.13 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 655838421 ns | 678237525.5 ns | 0.97 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2765113753 ns | 2822783216 ns | 0.98 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 131452603 ns | 120430437.5 ns | 1.09 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 172815028.5 ns | 175244675 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 645535092 ns | 651623284 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 34669482 ns | 45831746 ns | 0.76 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 165298740 ns | 165271861 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 641045050 ns | 639340636 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30403729 ns | 30230425.5 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 187925737 ns | 186370118 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 717992556 ns | 708821435 ns | 1.01 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 37845546 ns | 35641646.5 ns | 1.06 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1276726758 ns | 1269088947 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1854655259 ns | 1867861740 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2195954604 ns | 1985497479 ns | 1.11 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2320413681 ns | 2378938010.5 ns | 0.98 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1888979051.5 ns | 1858802371.5 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 347336582 ns | 550456586.5 ns | 0.63 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 320737784 ns | 321344700 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 321591947.5 ns | 323696598.5 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 385026390 ns | 365630240.5 ns | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11903504 ns | 11811649.5 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17779926 ns | 17740770.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19142173 ns | 19084948 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23833763 ns | 23808418.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17808550 ns | 17848657.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1153773 ns | 1164698.5 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 2501924.5 ns | 5670763 ns | 0.44 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2048820 ns | 2044411 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2027415 ns | 2030189 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2075285 ns | 2070069.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 197198 ns | 209019 ns | 0.94 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 293927 ns | 293681.5 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 267428 ns | 268068 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 370820 ns | 370370 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 411265.5 ns | 412428 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 275784 ns | 276023 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 408981 ns | 413690 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 84137 ns | 83986 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 82734 ns | 81692 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 84428 ns | 82072 ns | 1.03 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 87764 ns | 87042 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104554 ns | 104775 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 199995225 ns | 186685434.5 ns | 1.07 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 326496003.5 ns | 321228142.5 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 400994732 ns | 392606286 ns | 1.02 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 462902977 ns | 460952246 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 371403208 ns | 370351810 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 335761862 ns | 340754848 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 51360283.5 ns | 99666130 ns | 0.52 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 44195754 ns | 43812770.5 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 43921733.5 ns | 43647694 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 50184364.5 ns | 49549909.5 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28887089.5 ns | 28547342 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 19203788 ns | 19074425 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19668172.5 ns | 19527795.5 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23871830.5 ns | 23388473 ns | 1.02 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24346770.5 ns | 24094753 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19745369 ns | 19700860 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6532224 ns | 6506694 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6544657 ns | 6506504 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6497849 ns | 6500898 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6497188 ns | 6496878.5 ns | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.

Review thread on ext/LuxEnzymeExt.jl (outdated, resolved)
@avik-pal (Member, Author)

We should also be caching the parameter gradients, the compiled trace of the loss function, and so on, but this should be a good initial version; we need a redesign of the training API later anyway.

Review thread on ext/LuxEnzymeExt.jl (outdated, resolved)
@avik-pal force-pushed the ap/more_enzyme_tests branch from a101eef to 475a8cc on May 12, 2024 22:38
@avik-pal (Member, Author)

We need to wait for SciML/SciMLSensitivity.jl#1046 before the doc build goes through.

@wsmoses (Contributor) commented May 12, 2024 via email

@avik-pal (Member, Author) commented May 12, 2024

> Structural in what way?

As in, I want to say "don't backpropagate with respect to this value." For Zygote I would put a `nothing` there.

@wsmoses (Contributor) commented May 12, 2024 via email

@avik-pal (Member, Author)

So how would I annotate the return type? I am returning a tuple containing a scalar, a NamedTuple, and an arbitrary object; we don't need to backpropagate through the last two.

@wsmoses (Contributor) commented May 12, 2024 via email

@avik-pal (Member, Author)

You mean something like:

function compute_gradients(........)
    st_new_outer = Ref()
    stats_outer = Ref()

    function wrapper_function(args...)
        y, st_new, stats = objective_function(args...)
        st_new_outer[] = st_new
        stats_outer[] = stats
        return y
    end

    .....
end
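
For concreteness, here is a self-contained version of that wrapper pattern. This is a hedged sketch: `wrapped_objective` and the toy `objective` are illustrative names rather than the PR's implementation, and the exact handling of the captured Refs may vary across Enzyme versions.

using Enzyme

# Lux-style objective: returns (loss, updated_state, stats).
objective(ps, st, x) = (sum(abs2, ps.w .* x), (; calls = st.calls + 1), (;))

ps = (; w = randn(4))
st = (; calls = 0)
x  = randn(4)

const st_ref    = Ref{Any}()
const stats_ref = Ref{Any}()

# Enzyme only differentiates the scalar this wrapper returns; the auxiliary
# outputs are stashed in the Refs and read back after the autodiff call.
function wrapped_objective(ps, st, x)
    loss, st_new, stats = objective(ps, st, x)
    st_ref[]    = st_new
    stats_ref[] = stats
    return loss
end

dps = Enzyme.make_zero(ps)
_, loss = Enzyme.autodiff(ReverseWithPrimal, wrapped_objective, Active,
    Duplicated(ps, dps), Const(st), Const(x))
dps.w, st_ref[], stats_ref[]  # gradient, new state, stats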

@wsmoses (Contributor) commented May 13, 2024 via email

@avik-pal changed the title from "More Enzyme test coverage" to "Enzyme Testing + Caching in compute_gradients" on May 13, 2024
Review thread on ext/LuxOptimisersExt.jl (outdated, resolved)
@avik-pal (Member, Author) commented May 13, 2024

using ADTypes, Lux, Random, Enzyme, Optimisers
using BenchmarkTools  # provides @btime

model = Chain(Conv((3, 3), 3 => 6), GroupNorm(6, 3, gelu), Conv((3, 3), 6 => 32),
    BatchNorm(32, gelu), GlobalMeanPool(), FlattenLayer(), Dense(32, 1))

x = rand(Float32, 32, 32, 3, 4);
tstate = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.001f0));

# Objective returns (loss, updated_state, stats), as expected by compute_gradients.
function obj_fn(model, ps, st, x)
    y, st_new = model(x, ps, st)
    return sum(abs2, y), st_new, (;)
end

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate);

grads, loss, stats, tstate_new = Lux.Experimental.compute_gradients(
    AutoEnzyme(), obj_fn, x, tstate_new);

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate);
# 14.726 ms (461 allocations: 9.75 MiB)

@btime Lux.Experimental.compute_gradients($AutoEnzyme(), $obj_fn, $x, $tstate_new);
# 14.233 ms (447 allocations: 9.74 MiB)

Caching seems to work correctly.
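
For context, a minimal sketch of how this would feed a training loop, assuming Lux.Experimental.apply_gradients as the update step; this is illustrative and not part of this PR's diff:

function train_loop(tstate, obj_fn, x; nepochs = 100)
    for _ in 1:nepochs
        # Reuses the cache carried inside the TrainState on every iteration.
        grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
            AutoEnzyme(), obj_fn, x, tstate)
        tstate = Lux.Experimental.apply_gradients(tstate, grads)
    end
    return tstate
end

tstate_trained = train_loop(tstate_new, obj_fn, x)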

@avik-pal force-pushed the ap/more_enzyme_tests branch 2 times, most recently from 43a3bef to 0f0559b on May 13, 2024 00:34
Review thread on ext/LuxEnzymeExt.jl (outdated, resolved)
@avik-pal force-pushed the ap/more_enzyme_tests branch from d21a9c0 to 7822166 on May 13, 2024 01:54
Review thread on ext/LuxEnzymeExt.jl (outdated, resolved)
@avik-pal (Member, Author)

Ok, I did something wrong; it segfaulted in the training test: https://github.com/LuxDL/Lux.jl/actions/runs/9056705562/job/24879628489?pr=640#step:6:739

@avik-pal (Member, Author)

Locally, things pass. Now we need to wait for the SciMLSensitivity compat bounds to be updated.

Review thread on ext/LuxEnzymeExt.jl (outdated, resolved)
@avik-pal force-pushed the ap/more_enzyme_tests branch 3 times, most recently from 0539e90 to 1f1de4c on May 13, 2024 15:00
@avik-pal (Member, Author)

CI is not picking up the latest SciMLSensitivity release.

@avik-pal force-pushed the ap/more_enzyme_tests branch 2 times, most recently from 8e2723c to f2f97ef on May 14, 2024 20:58
@avik-pal force-pushed the ap/more_enzyme_tests branch 2 times, most recently from 70f6e5d to ecd2b3d on May 14, 2024 23:17
@avik-pal force-pushed the ap/more_enzyme_tests branch from ecd2b3d to 8bdde08 on May 14, 2024 23:30
@codecov (bot) commented May 15, 2024

Codecov Report

Attention: Patch coverage is 74.19355%, with 16 lines in your changes missing coverage. Please review.

Project coverage is 87.16%. Comparing base (64ba96d) to head (8bdde08).

| Files | Patch % | Lines |
|---|---|---|
| src/utils.jl | 26.66% | 11 Missing ⚠️ |
| ext/LuxEnzymeExt.jl | 86.36% | 3 Missing ⚠️ |
| src/contrib/training.jl | 89.47% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #640      +/-   ##
==========================================
- Coverage   87.56%   87.16%   -0.40%     
==========================================
  Files          49       50       +1     
  Lines        2380     2439      +59     
==========================================
+ Hits         2084     2126      +42     
- Misses        296      313      +17     


@avik-pal merged commit 3df3cee into main on May 15, 2024 (55 of 57 checks passed)
@avik-pal deleted the ap/more_enzyme_tests branch on May 15, 2024 00:38