-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: partial NNlib.gather support + better indexing support #252
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reactant.jl Benchmarks
Benchmark suite | Current: 041cdc1 | Previous: 3ba7c3e | Ratio |
---|---|---|---|
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
7424670654 ns |
5787425685 ns |
1.28 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant |
5264966268 ns |
5292258390 ns |
0.99 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
5264317193 ns |
6086056532 ns |
0.86 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
7514686090 ns |
7587601119 ns |
0.99 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux |
33317417750 ns |
28087750784 ns |
1.19 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
1619340991 ns |
1563822331 ns |
1.04 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant |
1598882918 ns |
1543677512 ns |
1.04 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1584562639 ns |
1553822136 ns |
1.02 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
3441713443 ns |
3309603029 ns |
1.04 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux |
2874671246 ns |
3236551447 ns |
0.89 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
2215854691 ns |
2198150190 ns |
1.01 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant |
2204071574 ns |
2155687426 ns |
1.02 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
2219471987 ns |
2192886728 ns |
1.01 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
4048913838 ns |
3908194881 ns |
1.04 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux |
5896617175 ns |
5993416352 ns |
0.98 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
1474593639 ns |
1406808783.5 ns |
1.05 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant |
1490931566 ns |
1407299141 ns |
1.06 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1472645239 ns |
1410969730 ns |
1.04 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
3300872585 ns |
3156311368 ns |
1.05 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux |
1075581601 ns |
1099155376.5 ns |
0.98 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
1781738074 ns |
1727787162 ns |
1.03 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant |
1775692962 ns |
1727804980 ns |
1.03 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
1774448644 ns |
1711663111 ns |
1.04 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
3633624848 ns |
3460051766 ns |
1.05 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux |
3023174554 ns |
3010659432 ns |
1.00 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
2239670715 ns |
2148427239 ns |
1.04 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant |
2223280066 ns |
2170426380 ns |
1.02 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
2223870589 ns |
2187259107 ns |
1.02 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
4122664875 ns |
3958804601 ns |
1.04 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux |
5808293868 ns |
6647100753 ns |
0.87 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
3080854019 ns |
3146044029 ns |
0.98 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant |
3097572316 ns |
3146912971 ns |
0.98 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
3060124717 ns |
3047329260 ns |
1.00 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
5012791931 ns |
4862728550 ns |
1.03 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux |
17226984911 ns |
12794226734 ns |
1.35 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
3256938221 ns |
3132478421 ns |
1.04 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant |
3540998740 ns |
3179953038 ns |
1.11 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
3525641867 ns |
3185074336 ns |
1.11 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
5203636130 ns |
5092564084 ns |
1.02 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux |
13915624571 ns |
12253319305 ns |
1.14 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
2033777204 ns |
1855345054 ns |
1.10 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant |
2005229554 ns |
1849809131 ns |
1.08 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1874882089 ns |
1855337197 ns |
1.01 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
3710154672 ns |
3604644289 ns |
1.03 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux |
3394500738 ns |
5868629461.5 ns |
0.58 |
This comment was automatically generated by workflow using github-action-benchmark.
function NNlib.gather!( | ||
dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray | ||
) where {T1,T2,N} | ||
dims = NNlib.scatter_dims(src, dst, idxs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we emit a warning here at least in the interim?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont think this is the right way to go. Even for a small testcase (nanoGPT) it takes forever to compile. Let me try to understand the stablehlo gather and get it fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 , if you'd like feel free to separate the dynamic_slice stuff into a different PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optimized the common cases and printing a warning for the other cases.
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} | ||
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)] | ||
indices = map(enumerate(indices)) do (idx, i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I think this should be doable with gather. That part I'm less confident we have all the optimization rules to lower into dynamic slice
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
# instead of unrolling the loop (the case for AbstractArray can just use | ||
# `stablehlo.gather`). See above for the special case implementation that is optimized. | ||
function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray) | ||
@warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you'd like you can put this behind a global to make sure it's only printed once (I think other indexing does that).
Though also I'm fine with having it always warn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah that's what maxlog does!
yeah it looks like we never setup a 32bit jll anyways, but we should still try to make things match when possible |
…D#252) * feat: unbreak NNlib.gather * feat: use dynamic slicing * chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * feat: add an overload of Base.Tuple * fix: ambiguity error * feat: special case `gather!` for the most common cases * feat: optimize the special case of indexing with unitranges * test: dynamic slice test * test: port NNlib gather tests over * chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: mark length as Int64 * fix: use the C API for dimension numbers
ideally we should be using
stablehlo.gather
, but this at least lets the currently fallback workFeatures
gather!
andgather
. These are not the most efficient implementations but at least they work now.getindex
assumes contiguous indexing #242getindex
assumes static indexing #243. We emitstablehlo.dynamic_slice
now