feat: more coverage for NNlib functions #258

Merged
merged 8 commits into from
Nov 11, 2024

Conversation

avik-pal
Collaborator

No description provided.

# _zero, _one, _inf = T(0), T(1), T(Inf)
# @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
#end
@trace if all(isfinite, max_)
Member

Already putting that to good use!

That said, I wonder how far we have diverged from the original NNlib implementation at this point, and whether
it would make sense to put @trace in the NNlib implementation itself:

https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/src/softmax.jl#L60
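
For readers without the linked file open, the branch under discussion can be sketched in plain Julia. This is only an illustration modeled on the commented-out lines quoted in the review context above, not the NNlib source and not the Reactant overload. The name `softmax_sketch` is hypothetical, and no `@trace` or `@fastmath` is applied.

```julia
# Hypothetical plain-Julia sketch of a softmax with an Inf-aware fallback.
# Assumption: mirrors the structure of the commented-out lines in the diff context.
function softmax_sketch(x::AbstractArray{T}; dims=1) where {T<:AbstractFloat}
    max_ = maximum(x; dims=dims)
    if all(isfinite, max_)
        # Fast path: subtracting the maximum keeps exp from overflowing.
        out = exp.(x .- max_)
    else
        # Inf-aware path: where the maximum is +Inf, put all mass on the +Inf
        # entries instead of evaluating exp(Inf - Inf), which would be NaN.
        _zero, _one, _inf = T(0), T(1), T(Inf)
        out = @. ifelse(isequal(max_, _inf), ifelse(isequal(x, _inf), _one, _zero), exp(x - max_))
    end
    return out ./ sum(out; dims=dims)
end
```

The `if all(isfinite, max_)` test is the data-dependent condition that the `@trace if` in this PR wraps. In a traced setting `max_` is not a concrete value, so an ordinary Julia `if` cannot branch on it.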

Collaborator Author

I think we are mostly there. I am trying to debug an issue with fastmath annotations inside @trace if; once that is resolved, this should be mostly done.

@avik-pal avik-pal force-pushed the ap/nnlib_inplace! branch 2 times, most recently from f1937fd to 3102dfa Compare November 10, 2024 16:12
@avik-pal avik-pal mentioned this pull request Nov 10, 2024
34 tasks
@avik-pal avik-pal marked this pull request as ready for review November 10, 2024 17:08
Base automatically changed from ap/gather to main November 10, 2024 17:16
@avik-pal
Collaborator Author

something https://github.com/EnzymeAD/Reactant.jl/actions/runs/11767109998/job/32775434246?pr=258#step:9:1126 seems to have been broken, let's hold merging this

@wsmoses
Member

wsmoses commented Nov 10, 2024

> something https://github.com/EnzymeAD/Reactant.jl/actions/runs/11767109998/job/32775434246?pr=258#step:9:1126 seems to have been broken, let's hold merging this

cc @Pangoraw

So this implies that reverse mode is happening, and we haven't fixed the caching infrastructure for it. I presume there is now reverse-mode AD of control flow here (which cannot be fully removed by optimizations)?

# end
inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
@trace if all(isfinite, max_)
Member

For now, can you remove the traced if here and above (marking each with a TODO)? I presume this will then work.

Contributor

@github-actions github-actions bot left a comment

Reactant.jl Benchmarks

| Benchmark suite | Current: 34b6e1e | Previous: 9d666f8 | Ratio |
| --- | --- | --- | --- |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 6852764800 ns | 6791483981 ns | 1.01 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant | 6697261766 ns | 5758935576 ns | 1.16 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 5294995398 ns | 5991010944 ns | 0.88 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 7431371705 ns | 7226896499 ns | 1.03 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux | 32637240532 ns | 35602806226 ns | 0.92 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1566470580 ns | 1554984111 ns | 1.01 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1546124252 ns | 1544583024 ns | 1.00 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1574353821 ns | 1548949983 ns | 1.02 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3300883461 ns | 3305789982 ns | 1.00 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux | 3386749570 ns | 2506900881 ns | 1.35 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 2147888797 ns | 2135647085 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant | 2152036908 ns | 2114392785 ns | 1.02 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 2165646884 ns | 2152785430 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 3898524209 ns | 3920703119 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux | 6063162560 ns | 6189728956.5 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1429171908 ns | 1457078887 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1424476406 ns | 1444716715 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1415807607 ns | 1436557810 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3140633666 ns | 3216204610 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux | 1150610587 ns | 1056965356.5 ns | 1.09 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 1702918284 ns | 1742579811 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant | 1718980070 ns | 1724456565 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 1714788165 ns | 1703597363 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3451135114 ns | 3527083552 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux | 2966206737 ns | 3168974882.5 ns | 0.94 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 2200141553 ns | 2194009586 ns | 1.00 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant | 2174766519 ns | 2185020760 ns | 1.00 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 2188699863 ns | 2145844912 ns | 1.02 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3916361982 ns | 3888344576 ns | 1.01 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux | 6507662228 ns | 5917571640 ns | 1.10 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 3006898599 ns | 2952461976 ns | 1.02 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant | 2998074765 ns | 3000541626 ns | 1.00 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 3028050464 ns | 3037769031 ns | 1.00 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 4914821059 ns | 5003447240 ns | 0.98 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux | 9900260894 ns | 16939295133 ns | 0.58 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 3160947079 ns | 3182179123 ns | 0.99 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant | 3152623721 ns | 3260047692 ns | 0.97 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 3203707923 ns | 3279955673 ns | 0.98 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 5050627005 ns | 5217808864 ns | 0.97 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux | 11506613809 ns | 13308548315 ns | 0.86 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1862317310 ns | 1848723326 ns | 1.01 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1868433814 ns | 1856209821 ns | 1.01 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 2255253810 ns | 1839420405 ns | 1.23 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3586114165 ns | 3618461240 ns | 0.99 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux | 4246493104 ns | 3120647639.5 ns | 1.36 |

This comment was automatically generated by workflow using github-action-benchmark.
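
The Ratio column appears to be the current time divided by the previous time, rounded to two decimals, so values above 1 mean the current commit was slower. That reading is inferred from the rows themselves rather than from the workflow's source, and `bench_ratio` is a hypothetical helper name. A minimal Julia sketch:

```julia
# Assumption: ratio = current / previous, rounded to two decimals
# (inferred from the benchmark rows, not from github-action-benchmark itself).
bench_ratio(current_ns, previous_ns) = round(current_ns / previous_ns; digits=2)

# ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant
bench_ratio(6697261766, 5758935576)

# ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux
bench_ratio(9900260894, 16939295133)
```

These two calls reproduce the 1.16 and 0.58 shown for those rows.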

@wsmoses wsmoses merged commit f2a91bf into main Nov 11, 2024
22 of 35 checks passed
@wsmoses wsmoses deleted the ap/nnlib_inplace! branch November 11, 2024 04:08
Pangoraw pushed a commit to Pangoraw/Reactant.jl that referenced this pull request Nov 11, 2024
* feat: use dynamic slicing

* feat: special case `gather!` for the most common cases

* feat: use `@trace` to implement softmax

* refactor: directly overload inplace conv routine from NNlib

* refactor: overload inplace pooling layers

* refactor: overload inplace batched matmul

* fix: reactant needs latest reactant core

* fix: temporarily avoid tracing in softmax and logsoftmax