-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: NilArray
for fast size propagation
#811
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Benchmark suite | Current: 8629a8b | Previous: b6171a6 | Ratio |
---|---|---|---|
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) |
3633.125 ns |
3675.625 ns |
0.99 |
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) |
6646.714285714285 ns |
8093.5 ns |
0.82 |
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) |
20879 ns |
21210 ns |
0.98 |
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) |
9842.4 ns |
9748.2 ns |
1.01 |
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) |
9184.75 ns |
9167.2 ns |
1.00 |
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) |
4571 ns |
4470.875 ns |
1.02 |
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) |
5064.4375 ns |
4956.875 ns |
1.02 |
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) |
1040.125786163522 ns |
2373.4 ns |
0.44 |
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) |
1061.2422360248447 ns |
2270.3 ns |
0.47 |
Dense(2 => 2)/cpu/forward/Flux/(2, 128) |
1809.2745098039215 ns |
1790.017543859649 ns |
1.01 |
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) |
180.2808510638298 ns |
179.70239774330042 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) |
17202 ns |
17562.5 ns |
0.98 |
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) |
12814 ns |
24787 ns |
0.52 |
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) |
37360 ns |
38393 ns |
0.97 |
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) |
29505 ns |
29025 ns |
1.02 |
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) |
19877 ns |
21590 ns |
0.92 |
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) |
17022 ns |
17092 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) |
25859 ns |
25648 ns |
1.01 |
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) |
1442.7 ns |
20248 ns |
0.0712514816278151 |
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) |
1487.8 ns |
14448 ns |
0.10 |
Dense(20 => 20)/cpu/forward/Flux/(20, 128) |
4936.357142857143 ns |
4846.285714285715 ns |
1.02 |
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) |
1654.1 ns |
1659.2 ns |
1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) |
87422394 ns |
77690170 ns |
1.13 |
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) |
49068510 ns |
76782338 ns |
0.64 |
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) |
148852073 ns |
155414925 ns |
0.96 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) |
173980105 ns |
167638289.5 ns |
1.04 |
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) |
161133250 ns |
142842293.5 ns |
1.13 |
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) |
11689961.5 ns |
11557321.5 ns |
1.01 |
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) |
164052470.5 ns |
199234044.5 ns |
0.82 |
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) |
6018522.5 ns |
15528408.5 ns |
0.39 |
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) |
6044988 ns |
15540189 ns |
0.39 |
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) |
38421814.5 ns |
30661456 ns |
1.25 |
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) |
6414706 ns |
6376663 ns |
1.01 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) |
1061301570.5 ns |
1064055959.5 ns |
1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) |
2930247441 ns |
2970205700 ns |
0.99 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) |
220745078 ns |
178121161 ns |
1.24 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) |
1413195702 ns |
1320655778 ns |
1.07 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) |
4014711404 ns |
3516351096 ns |
1.14 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) |
401902295 ns |
344809509 ns |
1.17 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) |
1535181279 ns |
1431616033 ns |
1.07 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) |
4081079728 ns |
4058579611 ns |
1.01 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) |
481251223 ns |
436008182 ns |
1.10 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) |
419700776 ns |
381866129 ns |
1.10 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) |
929157584.5 ns |
905256978 ns |
1.03 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) |
48962456 ns |
54567006.5 ns |
0.90 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) |
405274961 ns |
382293897 ns |
1.06 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) |
912404034.5 ns |
870357323.5 ns |
1.05 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) |
29924166 ns |
54472914.5 ns |
0.55 |
vgg16/cpu/forward/Flux/(32, 32, 3, 16) |
501445895 ns |
551222188 ns |
0.91 |
vgg16/cpu/forward/Flux/(32, 32, 3, 64) |
1419467769 ns |
1387168504 ns |
1.02 |
vgg16/cpu/forward/Flux/(32, 32, 3, 2) |
176639583 ns |
164122645 ns |
1.08 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) |
1242390168.5 ns |
1180058919 ns |
1.05 |
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) |
1582686085 ns |
1610297742 ns |
0.98 |
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) |
2420171712 ns |
2289727615.5 ns |
1.06 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) |
2706252208 ns |
2640437136 ns |
1.02 |
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) |
2184605134.5 ns |
2193753011.5 ns |
1.00 |
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) |
2329743859 ns |
2122924359 ns |
1.10 |
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) |
286789405 ns |
282003619 ns |
1.02 |
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) |
286283817 ns |
286261947 ns |
1.00 |
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) |
484850885.5 ns |
437257287 ns |
1.11 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) |
11816331 ns |
11806435 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) |
15283871.5 ns |
34527638 ns |
0.44 |
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) |
16457481.5 ns |
16364743 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) |
21131640 ns |
21004093 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) |
15324828.5 ns |
15284140 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) |
1168357 ns |
1148921.5 ns |
1.02 |
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) |
19286160 ns |
35777843.5 ns |
0.54 |
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) |
1872865 ns |
4500694 ns |
0.42 |
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) |
1881051 ns |
4506207 ns |
0.42 |
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) |
2067470 ns |
2045686 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) |
201267 ns |
196300 ns |
1.03 |
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) |
411114.5 ns |
378068 ns |
1.09 |
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) |
209902 ns |
314462 ns |
0.67 |
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) |
383542.5 ns |
377972 ns |
1.01 |
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) |
533248.5 ns |
520691 ns |
1.02 |
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) |
292938 ns |
289716 ns |
1.01 |
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) |
406301 ns |
401777 ns |
1.01 |
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) |
429348.5 ns |
425321 ns |
1.01 |
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) |
56887 ns |
157406 ns |
0.36 |
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) |
57648 ns |
162456 ns |
0.35 |
Dense(200 => 200)/cpu/forward/Flux/(200, 128) |
92303 ns |
91953 ns |
1.00 |
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) |
105066 ns |
104407 ns |
1.01 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) |
337073468 ns |
297649242 ns |
1.13 |
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) |
272045813.5 ns |
287837994 ns |
0.95 |
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) |
560383862.5 ns |
545531151.5 ns |
1.03 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) |
649394013 ns |
655809148 ns |
0.99 |
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) |
610291545 ns |
554893727 ns |
1.10 |
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) |
340241373 ns |
316084028.5 ns |
1.08 |
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) |
621345235 ns |
583442251.5 ns |
1.06 |
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) |
38492275.5 ns |
40159465 ns |
0.96 |
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) |
38495211 ns |
40173961.5 ns |
0.96 |
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) |
104690381 ns |
96663497 ns |
1.08 |
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) |
28089279 ns |
28321531 ns |
0.99 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) |
23382291 ns |
21078472 ns |
1.11 |
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) |
19449597 ns |
17393481 ns |
1.12 |
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) |
22746955 ns |
22657728 ns |
1.00 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) |
26574643.5 ns |
28019412 ns |
0.95 |
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) |
19364727 ns |
19298592.5 ns |
1.00 |
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) |
20907247 ns |
20720819 ns |
1.01 |
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) |
6595981 ns |
6086608 ns |
1.08 |
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) |
6478264 ns |
6101998 ns |
1.06 |
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) |
6549491 ns |
6509879.5 ns |
1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
debf54b
to
5c81b59
Compare
Benchmark Results (ASV)
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
4fcf073
to
c10bb8c
Compare
NilArray
for fast size propagation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #811 +/- ##
==========================================
- Coverage 95.34% 93.78% -1.56%
==========================================
Files 58 59 +1
Lines 2855 2959 +104
==========================================
+ Hits 2722 2775 +53
- Misses 133 184 +51 ☔ View full report in Codecov by Sentry. |
This doesn't introduce any public API! All changes made here are private till we tag a 1.0 release, after which
outputsize
will use this implementation.All the dependencies added in this PR were already being installed in Lux one way or another