feat: NilArray for fast size propagation #811

Merged
avik-pal merged 4 commits into main from ap/nilarray on Aug 18, 2024

Conversation

@avik-pal (Member) commented Aug 1, 2024

This PR doesn't introduce any public API. All changes made here remain private until we tag a 1.0 release, after which outputsize will use this implementation.

All the dependencies added in this PR were already being installed in Lux one way or another.

@github-actions bot (Contributor) left a comment

Benchmark Results

Columns: Benchmark suite, Current (8629a8b), Previous (b6171a6), Ratio (Current / Previous; values below 1 mean the current commit is faster)
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3633.125 ns 3675.625 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 6646.714285714285 ns 8093.5 ns 0.82
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20879 ns 21210 ns 0.98
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9842.4 ns 9748.2 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9184.75 ns 9167.2 ns 1.00
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4571 ns 4470.875 ns 1.02
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 5064.4375 ns 4956.875 ns 1.02
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1040.125786163522 ns 2373.4 ns 0.44
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1061.2422360248447 ns 2270.3 ns 0.47
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1809.2745098039215 ns 1790.017543859649 ns 1.01
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.2808510638298 ns 179.70239774330042 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17202 ns 17562.5 ns 0.98
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 12814 ns 24787 ns 0.52
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37360 ns 38393 ns 0.97
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29505 ns 29025 ns 1.02
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19877 ns 21590 ns 0.92
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17022 ns 17092 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 25859 ns 25648 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 1442.7 ns 20248 ns 0.07
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 1487.8 ns 14448 ns 0.10
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4936.357142857143 ns 4846.285714285715 ns 1.02
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1654.1 ns 1659.2 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 87422394 ns 77690170 ns 1.13
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 49068510 ns 76782338 ns 0.64
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 148852073 ns 155414925 ns 0.96
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 173980105 ns 167638289.5 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 161133250 ns 142842293.5 ns 1.13
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11689961.5 ns 11557321.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 164052470.5 ns 199234044.5 ns 0.82
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 6018522.5 ns 15528408.5 ns 0.39
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6044988 ns 15540189 ns 0.39
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 38421814.5 ns 30661456 ns 1.25
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6414706 ns 6376663 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 1061301570.5 ns 1064055959.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2930247441 ns 2970205700 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 220745078 ns 178121161 ns 1.24
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 1413195702 ns 1320655778 ns 1.07
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 4014711404 ns 3516351096 ns 1.14
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 401902295 ns 344809509 ns 1.17
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 1535181279 ns 1431616033 ns 1.07
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 4081079728 ns 4058579611 ns 1.01
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 481251223 ns 436008182 ns 1.10
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 419700776 ns 381866129 ns 1.10
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 929157584.5 ns 905256978 ns 1.03
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 48962456 ns 54567006.5 ns 0.90
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 405274961 ns 382293897 ns 1.06
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 912404034.5 ns 870357323.5 ns 1.05
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29924166 ns 54472914.5 ns 0.55
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 501445895 ns 551222188 ns 0.91
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 1419467769 ns 1387168504 ns 1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 176639583 ns 164122645 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1242390168.5 ns 1180058919 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1582686085 ns 1610297742 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2420171712 ns 2289727615.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2706252208 ns 2640437136 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 2184605134.5 ns 2193753011.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 2329743859 ns 2122924359 ns 1.10
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 286789405 ns 282003619 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 286283817 ns 286261947 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 484850885.5 ns 437257287 ns 1.11
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11816331 ns 11806435 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 15283871.5 ns 34527638 ns 0.44
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 16457481.5 ns 16364743 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 21131640 ns 21004093 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 15324828.5 ns 15284140 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1168357 ns 1148921.5 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 19286160 ns 35777843.5 ns 0.54
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 1872865 ns 4500694 ns 0.42
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 1881051 ns 4506207 ns 0.42
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2067470 ns 2045686 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 201267 ns 196300 ns 1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 411114.5 ns 378068 ns 1.09
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 209902 ns 314462 ns 0.67
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 383542.5 ns 377972 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 533248.5 ns 520691 ns 1.02
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 292938 ns 289716 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 406301 ns 401777 ns 1.01
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 429348.5 ns 425321 ns 1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 56887 ns 157406 ns 0.36
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 57648 ns 162456 ns 0.35
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 92303 ns 91953 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 105066 ns 104407 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 337073468 ns 297649242 ns 1.13
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 272045813.5 ns 287837994 ns 0.95
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 560383862.5 ns 545531151.5 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 649394013 ns 655809148 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 610291545 ns 554893727 ns 1.10
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 340241373 ns 316084028.5 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 621345235 ns 583442251.5 ns 1.06
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 38492275.5 ns 40159465 ns 0.96
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 38495211 ns 40173961.5 ns 0.96
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 104690381 ns 96663497 ns 1.08
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28089279 ns 28321531 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 23382291 ns 21078472 ns 1.11
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19449597 ns 17393481 ns 1.12
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 22746955 ns 22657728 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 26574643.5 ns 28019412 ns 0.95
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19364727 ns 19298592.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 20907247 ns 20720819 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6595981 ns 6086608 ns 1.08
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6478264 ns 6101998 ns 1.06
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6549491 ns 6509879.5 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal force-pushed the ap/nilarray branch 3 times, most recently from debf54b to 5c81b59 on August 18, 2024 03:49
@avik-pal marked this pull request as ready for review on August 18, 2024 18:59
@github-actions bot (Contributor) commented Aug 18, 2024

Benchmark Results (ASV)

main 8629a8b... main/8629a8bb6553b7...
basics/overhead 0.0903 ± 0.0037 μs 0.0903 ± 0.0018 μs 1
time_to_load 1.04 ± 0.015 s 1.05 ± 0.0012 s 0.99

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@avik-pal force-pushed the ap/nilarray branch 2 times, most recently from 4fcf073 to c10bb8c on August 18, 2024 20:40
@avik-pal mentioned this pull request on Aug 18, 2024
@avik-pal changed the title from "feat: NilArray for check size propagation" to "feat: NilArray for fast size propagation" on Aug 18, 2024
@codecov bot commented Aug 18, 2024

Codecov Report

Attention: Patch coverage is 49.00%, with 51 lines in your changes missing coverage. Please review.

Project coverage is 93.78%. Comparing base (3e77701) to head (8629a8b).
Report is 19 commits behind head on main.

Files Patch % Lines
src/helpers/size_propagator.jl 49.47% 48 Missing ⚠️
src/utils.jl 40.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #811      +/-   ##
==========================================
- Coverage   95.34%   93.78%   -1.56%     
==========================================
  Files          58       59       +1     
  Lines        2855     2959     +104     
==========================================
+ Hits         2722     2775      +53     
- Misses        133      184      +51     


@avik-pal merged commit dcac536 into main on Aug 18, 2024
64 of 76 checks passed
@avik-pal deleted the ap/nilarray branch on August 18, 2024 23:09