onehotbatch performance #1844
That's pretty slow. The hope would be that this is still faster at the operations you actually perform with the result:

```julia
# Construction:
julia> @btime ohe_custom(sequence);
  min 989.583 ns, mean 1.031 μs (5 allocations, 464 bytes)
# returns a 46×4 BitMatrix

julia> @btime ohe_flux(sequence);
  min 64.750 μs, mean 68.083 μs (374 allocations, 17.30 KiB)
# returns a 4×46 OneHotMatrix(::Vector{UInt32}) with eltype Bool

# Special path:
julia> @btime $(rand(10,4)) * $(ohe_flux(sequence));
  min 678.286 ns, mean 1.362 μs (1 allocation, 3.75 KiB)

julia> @btime $(rand(10,4)) * $(ohe_custom(sequence))';
  min 3.599 μs, mean 6.996 μs (4 allocations, 25.58 KiB)

# Not optimised:
julia> @btime Flux.logitcrossentropy($(rand(4, 46)), $(ohe_flux(sequence)));
  min 4.363 μs, mean 4.969 μs (9 allocations, 5.44 KiB)

julia> @btime Flux.logitcrossentropy($(rand(46, 4)), $(ohe_custom(sequence)));
  min 3.922 μs, mean 4.517 μs (9 allocations, 3.72 KiB)

# Alternatives:
julia> @btime viabool(collect($sequence), bases_dna);
  min 1.942 μs, mean 2.030 μs (7 allocations, 720 bytes)

julia> @btime unrolled(collect($sequence), bases_dna);
  min 913.395 ns, mean 988.052 ns (4 allocations, 528 bytes)

julia> @btime unrolled($sequence, $(Tuple(bases_dna)));
  min 610.000 ns, mean 647.950 ns (2 allocations, 256 bytes)
```

It's definitely possible to make a faster constructor for this. Two quick attempts are:

```julia
function viabool(data::AbstractVector, labels::AbstractVector)
    bool = data .== permutedims(labels)
    all(x -> count(x) == 1, eachrow(bool)) || error("some data out of labels")
    inds = bool * (UInt32(1):UInt32(length(labels)))
    Flux.OneHotArray(inds, length(labels))
end
```
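The core of `viabool` is that multiplying the Bool comparison matrix by the index range picks out, for each row, the column index of its single `true`. A minimal Flux-free illustration (the variable names here are for demonstration only):

```julia
data   = collect("GATC")
labels = ['A', 'C', 'G', 'T']

# One row per data element, one column per label; exactly one `true` per row
bool = data .== permutedims(labels)

# Bool matrix * integer vector sums the indices of the `true` entries,
# i.e. recovers the matching label index for each row
inds = bool * collect(1:length(labels))
# inds == [3, 1, 4, 2]
```

This is also why the `count(x) == 1` check matters: a row with zero or two matches would silently produce a wrong index rather than an error.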
```julia
unrolled(data, labels) = unrolled(data, Tuple(labels))

function unrolled(data, labels::Tuple)
    inds = map(x -> UInt32(_find(x, labels)), data)
    any(i -> i > length(labels), inds) && error("some data out of labels")
    Flux.OneHotArray(inds, length(labels))
end

# A second variant, replacing the method above, with the check disabled:
function unrolled(data, labels::Tuple)
    inds = [UInt32(_find(x, labels)) for x in data]
    # any(i -> i > length(labels), inds) && error("some data out of labels")
    Flux.OneHotArray(inds, length(labels))
end

# _find(val, labels::Tuple, i::Integer=1) = val == first(labels) ? i : _find(val, Base.tail(labels), i+1)
_find(val, labels::Tuple, i::Integer=1) = ifelse(val == first(labels), i, _find(val, Base.tail(labels), i+1))
_find(val, labels::Tuple{}, i::Integer) = i + 1  # value absent: an out-of-range index
```
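A self-contained sanity check of `_find` (repeating its definition so the snippet runs on its own). Note that `ifelse` evaluates both branches, so the whole tuple is always traversed; and when the value is absent, the recursion bottoms out in the empty-tuple method with `i == length(labels) + 1` and returns `length(labels) + 2`, which the `i > length(labels)` guard in `unrolled` still catches:

```julia
_find(val, labels::Tuple, i::Integer=1) =
    ifelse(val == first(labels), i, _find(val, Base.tail(labels), i+1))
_find(val, labels::Tuple{}, i::Integer) = i + 1

_find('G', ('A', 'C', 'G', 'T'))  # 3
_find('N', ('A', 'C', 'G', 'T'))  # 6, out of range, caught by the guard
```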
```julia
@btime findfirst(==('G'), bases_dna)       # ~5 ns
@btime _find('G', $(Tuple(bases_dna)))     # ~2 ns
@btime searchsortedfirst($bases_dna, 'G')  # ~3 ns
```

Making the labels into a tuple makes some sense, in that the OneHotArray already carries their number in its type. But unrolling like this is likely to be awful with 100 labels.
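For large label sets, one scalable alternative (a sketch; `onehot_inds` is a name invented here, and this is not what Flux currently does) is to build a `Dict` from label to index once, so each lookup is O(1) regardless of the number of labels, and then feed the indices to `Flux.OneHotArray`:

```julia
function onehot_inds(data::AbstractVector, labels)
    # One-time O(length(labels)) setup, then O(1) per element
    lookup = Dict(l => UInt32(i) for (i, l) in enumerate(labels))
    map(data) do x
        get(lookup, x) do
            error("value $x not found in labels")
        end
    end
    # then: Flux.OneHotArray(inds, length(labels))
end
```

Usage: `onehot_inds(collect("GATC"), ['A', 'C', 'G', 'T'])` gives `UInt32[3, 1, 4, 2]`. The Dict costs a hash per element, so for very small label sets (like 4 DNA bases) the unrolled tuple search above is still likely faster.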
After a Discourse thread, I was recommended to open a GitHub issue about the performance of `onehotbatch`. Let's use an MWE comparing a hand-rolled encoder, `ohe_custom`, against an `ohe_flux` built on `Flux.onehotbatch`.

Right now the dimensions of the two functions' outputs are transposed relative to each other, but that can easily be changed with a `permutedims` applied to either one; otherwise they return the same one-hot encoded matrix. So far, so good. However, when benchmarking them, we find that `ohe_flux` is more than 100 times slower than `ohe_custom`, with 70 times more allocations. Another minor detail is the difference in size of the two outputs.
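The MWE code block did not survive in this copy of the issue. A guessed reconstruction, consistent with the output shapes quoted above (a 46×4 `BitMatrix` from `ohe_custom`, a 4×46 `OneHotMatrix` from `ohe_flux`); the label set, the sequence, and both function bodies are assumptions, not the original code:

```julia
using Random: randstring

bases_dna = ['A', 'C', 'G', 'T']
sequence  = randstring("ACGT", 46)   # stand-in for the original 46-character sequence

# Hand-rolled encoder: broadcast comparison, one row per character (46×4 BitMatrix)
ohe_custom(seq) = collect(seq) .== permutedims(bases_dna)

# Flux's encoder, one column per character (4×46 OneHotMatrix); needs `using Flux`
# ohe_flux(seq) = Flux.onehotbatch(collect(seq), bases_dna)
```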