Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: partial NNlib.gather support + better indexing support #252

Merged
merged 12 commits into from
Nov 10, 2024
Merged

Conversation

avik-pal
Copy link
Collaborator

@avik-pal avik-pal commented Nov 9, 2024

ideally we should be using stablehlo.gather, but this at least lets the currently fallback work

Features

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 041cdc1 Previous: 3ba7c3e Ratio
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 7424670654 ns 5787425685 ns 1.28
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5264966268 ns 5292258390 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5264317193 ns 6086056532 ns 0.86
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7514686090 ns 7587601119 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 33317417750 ns 28087750784 ns 1.19
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1619340991 ns 1563822331 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1598882918 ns 1543677512 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1584562639 ns 1553822136 ns 1.02
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3441713443 ns 3309603029 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2874671246 ns 3236551447 ns 0.89
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2215854691 ns 2198150190 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2204071574 ns 2155687426 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2219471987 ns 2192886728 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4048913838 ns 3908194881 ns 1.04
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5896617175 ns 5993416352 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1474593639 ns 1406808783.5 ns 1.05
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1490931566 ns 1407299141 ns 1.06
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1472645239 ns 1410969730 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3300872585 ns 3156311368 ns 1.05
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1075581601 ns 1099155376.5 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1781738074 ns 1727787162 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1775692962 ns 1727804980 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1774448644 ns 1711663111 ns 1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3633624848 ns 3460051766 ns 1.05
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3023174554 ns 3010659432 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2239670715 ns 2148427239 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2223280066 ns 2170426380 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2223870589 ns 2187259107 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4122664875 ns 3958804601 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5808293868 ns 6647100753 ns 0.87
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3080854019 ns 3146044029 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3097572316 ns 3146912971 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3060124717 ns 3047329260 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 5012791931 ns 4862728550 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 17226984911 ns 12794226734 ns 1.35
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3256938221 ns 3132478421 ns 1.04
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3540998740 ns 3179953038 ns 1.11
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3525641867 ns 3185074336 ns 1.11
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5203636130 ns 5092564084 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 13915624571 ns 12253319305 ns 1.14
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 2033777204 ns 1855345054 ns 1.10
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 2005229554 ns 1849809131 ns 1.08
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1874882089 ns 1855337197 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3710154672 ns 3604644289 ns 1.03
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3394500738 ns 5868629461.5 ns 0.58

This comment was automatically generated by workflow using github-action-benchmark.

Base automatically changed from ap/causal_mask to main November 9, 2024 18:39
function NNlib.gather!(
dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray
) where {T1,T2,N}
dims = NNlib.scatter_dims(src, dst, idxs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we emit a warning here at least in the interim?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think this is the right way to go. Even for a small testcase (nanoGPT) it takes forever to compile. Let me try to understand the stablehlo gather and get it fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 , if you'd like feel free to separate the dynamic_slice stuff into a different PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimized the common cases and printing a warning for the other cases.

src/TracedRArray.jl Outdated Show resolved Hide resolved
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
indices = map(enumerate(indices)) do (idx, i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think this should be doable with gather. That part I'm less confident we have all the optimization rules to lower into dynamic slice

@avik-pal avik-pal marked this pull request as ready for review November 10, 2024 04:20
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@avik-pal avik-pal changed the title feat: temporarily unbreak NNlib.gather feat: partial NNlib.gather support + better indexing support Nov 10, 2024
@avik-pal avik-pal mentioned this pull request Nov 9, 2024
34 tasks
@avik-pal avik-pal requested a review from wsmoses November 10, 2024 04:48
ext/ReactantNNlibExt.jl Outdated Show resolved Hide resolved
# instead of unrolling the loop (the case for AbstractArray can just use
# `stablehlo.gather`). See above for the special case implementation that is optimized.
function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
@warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you'd like you can put this behind a global to make sure it's only printed once (I think other indexing does that).

Though also I'm fine with having it always warn

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that's what maxlog does!

src/TracedRArray.jl Outdated Show resolved Hide resolved
src/TracedRArray.jl Outdated Show resolved Hide resolved
@wsmoses
Copy link
Member

wsmoses commented Nov 10, 2024

yeah it looks like we never setup a 32bit jll anyways, but we should still try to make things match when possible

@avik-pal avik-pal merged commit 9d666f8 into main Nov 10, 2024
21 of 33 checks passed
@avik-pal avik-pal deleted the ap/gather branch November 10, 2024 17:16
Pangoraw pushed a commit to Pangoraw/Reactant.jl that referenced this pull request Nov 11, 2024
…D#252)

* feat: unbreak NNlib.gather

* feat: use dynamic slicing

* chore: apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* feat: add an overload of Base.Tuple

* fix: ambiguity error

* feat: special case `gather!` for the most common cases

* feat: optimize the special case of indexing with unitranges

* test: dynamic slice test

* test: port NNlib gather tests over

* chore: apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: mark length as Int64

* fix: use the C API for dimension numbers
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

getindex assumes static indexing
2 participants