Custom stacking for StaticArrays #564
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

@@           Coverage Diff            @@
##             main     #564    +/-   ##
=========================================
  Coverage   98.00%   98.00%
=========================================
  Files         106      108      +2
  Lines        4808     4812      +4
=========================================
+ Hits         4712     4716      +4
  Misses         96       96
This PR adds an extension in order to have these two paths:

function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
    return hcat(map(vec, t)...)
end
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)

Is this clearly better than just always using stack?

julia> tm = ntuple(i -> fill(i, 2, 3), 10);

julia> @btime stack(vec, $tm; dims=1);
  142.494 ns (12 allocations: 912 bytes)

julia> @btime hcat(map(vec, $tm)...);
  101.484 ns (12 allocations: 912 bytes)
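To illustrate how the two paths above are selected (using the signatures from that snippet), the specialized method only applies when every element of the tuple is an SArray of the same concrete type; anything else hits the generic stack fallback:

julia> using StaticArrays

julia> t_static = ntuple(i -> @SMatrix(ones(2, 3)), 4);

julia> t_static isa NTuple{4,<:SArray}   # matches the specialized hcat method
true

julia> t_mixed = (ones(2, 3), @SMatrix(ones(2, 3)));

julia> t_mixed isa NTuple{2,<:SArray}    # falls back to the generic stack method
false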
It's clearly better in the case of static arrays:

julia> using StaticArrays, BenchmarkTools

julia> ts = ntuple(i -> @SMatrix(ones(2,3)), 10);

julia> @btime stack(vec, $ts; dims=2);
  311.713 ns (1 allocation: 544 bytes)

julia> @btime hcat(map(vec, $ts)...);
  7.246 ns (0 allocations: 0 bytes)
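The allocation counts hint at why: vec of an SMatrix is itself a static vector, and hcat of static vectors produces an SMatrix, so the splatted hcat never leaves the stack, whereas stack materializes a heap-allocated matrix (as the single 544-byte allocation above suggests). A quick sanity check along those lines:

julia> using StaticArrays

julia> s = @SMatrix ones(2, 3);

julia> vec(s) isa SVector                 # vec keeps the static size
true

julia> hcat(vec(s), vec(s)) isa SMatrix   # so the splatted hcat stays fully static
true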
Sorry, I read your comment the wrong way. I did some more thorough benchmarks in this issue and …
Wow, that's quite hard to decode. (Probably I should have used …)

julia> tm = ntuple(i -> rand(100, 100), 10);

julia> res1 = @btime stack(vec, $tm);         # called "bad stack, function" at link
  14.250 μs (13 allocations: 781.64 KiB)

julia> res2 = @btime stack(vec, $tm, dims=2); # version with dims as in PR
  14.250 μs (13 allocations: 781.64 KiB)

julia> res3 = @btime hcat(map(vec, $tm)...);  # called "good stack, function" at link
  64.416 μs (13 allocations: 781.64 KiB)

julia> res1 == res2 == res3
true

Whether that's true in general I don't know; perhaps I'm somewhat surprised. And if it is, whether it's worth the complexity is your call. Note that a lazy, non-copying vec changes the picture:

julia> myvec(x) = Base.ReshapedArray(x, (length(x),), ());  # not sure 3rd argument is optimal!

julia> res4 = @btime hcat(map(myvec, $tm)...);  # faster
  38.792 μs (3 allocations: 781.33 KiB)

julia> @btime stack(myvec, $tm);  # slower
  22.084 μs (3 allocations: 781.33 KiB)

julia> res1 == res4
true
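For what it's worth, even the plain vec used above does not copy the data: for a regular Array it is just a reshape, so the result shares memory with the parent matrix, and the extra allocations in the 13-allocation runs are only the small array wrappers (the ~781 KiB is the output matrix in every case). A quick check:

julia> m = rand(100, 100);

julia> v = vec(m);        # reshape under the hood, shares memory with m

julia> v[1] = 0.0; m[1, 1]
0.0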
In the current state of things, … By the way, any ideas on how to implement …?
Wow, I didn't know that. I naively thought that a Matrix was just a vector in a trench coat, hence this would be free.
Partial answer to #563
Related: …

Versions

DI core:
- new overridable function to stack the tuple t into Jacobian/Hessian blocks (DI.stack_vec_col in the conversation above)

DI extensions:
- StaticArrays (new extension): stack a t::NTuple{_,<:SArray} by using hcat(map(vec, t)...) instead of stack(vec, t). Benchmarks show a 2x speedup on the example from "Add direct Enzyme support" (SciML/NonlinearSolve.jl#476), and now each block is an SMatrix.
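For readers unfamiliar with Julia package extensions, here is a minimal sketch of how such an override can be wired up, assuming the helper is DI.stack_vec_col as quoted in the conversation above (the actual file layout and names in DifferentiationInterface.jl may differ):

# ext/DifferentiationInterfaceStaticArraysExt.jl -- sketch only
module DifferentiationInterfaceStaticArraysExt

import DifferentiationInterface as DI
using StaticArrays

# Specialized path: a splatted hcat of static vecs stays an SMatrix,
# so each Jacobian/Hessian block is stack-allocated (0 heap allocations).
function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
    return hcat(map(vec, t)...)
end

end # module

The generic fallback stack_vec_col(t::NTuple) = stack(vec, t; dims=2) stays in DI core, and the extension is only loaded when StaticArrays is present in the environment (declared under [extensions] in Project.toml).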