onehotbatch performance #1844

Closed
ChristianMichelsen opened this issue Jan 22, 2022 · 1 comment · Fixed by #1861

Following a Discourse thread, I was advised to open a GitHub issue about the performance of onehotbatch. Let's use the following MWE:

using Flux
using BenchmarkTools

const bases_dna = ['A', 'C', 'G', 'T']

# Hand-rolled one-hot encoding via broadcasting: rows are positions, columns are bases.
function ohe_custom(sequence)
    return collect(sequence) .== permutedims(bases_dna)
end

# Flux's built-in one-hot encoding: rows are bases, columns are positions.
function ohe_flux(sequence)
    return Flux.onehotbatch(collect(sequence), bases_dna)
end

sequence = "CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA"

Right now the dimensions of the two outputs differ (they are transposes of one another), but that is easily fixed by applying permutedims to either one; otherwise they return the same one-hot encoded matrix. So far, so good. However, when benchmarking them, we find the following:

@btime ohe_custom(sequence);
# output: 550.514 ns (5 allocations: 464 bytes)

and

@btime ohe_flux(sequence);
# output: 69.274 μs (374 allocations: 17.30 KiB)

As we can see, ohe_flux is more than 100 times slower than ohe_custom, with roughly 70 times as many allocations.
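
As a quick sanity check that the two functions really do agree up to a transpose (a minimal check, using the MWE above):

ohe_flux(sequence) == permutedims(ohe_custom(sequence))
# expected output: true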

Another minor detail is the size of the different outputs:

Base.summarysize(sequence)
# output: 54
Base.summarysize(ohe_custom(sequence))
# output: 96
Base.summarysize(ohe_flux(sequence))
# output: 232
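
A rough back-of-the-envelope accounting for those numbers (a sketch; the reported sizes also include per-object overhead that varies by Julia version): the OneHotMatrix stores one UInt32 index per position, while the BitMatrix packs one bit per entry into 64-bit chunks.

length(sequence) * sizeof(UInt32)                   # 46 * 4 = 184 bytes of UInt32 indices
cld(length(sequence) * length(bases_dna), 64) * 8   # 3 chunks of 64 bits = 24 bytes of packed bits
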
mcabbott (Member) commented Jan 22, 2022

That's pretty slow. The hope would be that this is still faster at *, which it is. But you'd have to do many such multiplications to compensate for one construction.
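
For context on why * has a special path: multiplying a dense matrix by a one-hot matrix amounts to selecting columns, so no floating-point work is needed. A standalone sketch of the idea (not Flux's actual method; onehot_mul and inds are made-up names):

# A * onehot_matrix picks out columns of A according to the stored indices.
onehot_mul(A::AbstractMatrix, inds::AbstractVector{<:Integer}) = A[:, inds]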

# Construction:

julia> @btime ohe_custom(sequence);
  min 989.583 ns, mean 1.031 μs (5 allocations, 464 bytes)
46×4 BitMatrix:

julia> @btime ohe_flux(sequence);
  min 64.750 μs, mean 68.083 μs (374 allocations, 17.30 KiB)
4×46 OneHotMatrix(::Vector{UInt32}) with eltype Bool:

# Special path:

julia> @btime $(rand(10,4)) * $(ohe_flux(sequence));
  min 678.286 ns, mean 1.362 μs (1 allocation, 3.75 KiB)

julia> @btime $(rand(10,4)) * $(ohe_custom(sequence))';
  min 3.599 μs, mean 6.996 μs (4 allocations, 25.58 KiB)

# Not optimised:

julia> @btime Flux.logitcrossentropy($(rand(4, 46)), $(ohe_flux(sequence)));
  min 4.363 μs, mean 4.969 μs (9 allocations, 5.44 KiB)

julia> @btime Flux.logitcrossentropy($(rand(46, 4)), $(ohe_custom(sequence)));
  min 3.922 μs, mean 4.517 μs (9 allocations, 3.72 KiB)

# Alternatives:

julia> @btime viabool(collect($sequence), bases_dna);
  min 1.942 μs, mean 2.030 μs (7 allocations, 720 bytes)

julia> @btime unrolled(collect($sequence), bases_dna);
  min 913.395 ns, mean 988.052 ns (4 allocations, 528 bytes)

julia> @btime unrolled($sequence, $(Tuple(bases_dna)));
  min 610.000 ns, mean 647.950 ns (2 allocations, 256 bytes)

It's definitely possible to make a faster constructor for this. Two quick attempts are:

function viabool(data::AbstractVector, labels::AbstractVector)
    bool = data .== permutedims(labels)
    all(x -> count(x)==1, eachrow(bool)) || throw("some data out of labels")
    inds = bool * (UInt32(1):UInt32(length(labels)))
    Flux.OneHotArray(inds, length(labels))
end

unrolled(data, labels) = unrolled(data, Tuple(labels))
function unrolled(data, labels::Tuple)
    inds = map(x -> UInt32(_find(x, labels)), data)
    any(i -> i>length(labels), inds) && throw("some data out of labels")
    Flux.OneHotArray(inds, length(labels))
end
# Variant using a comprehension and no bounds check; it has the same signature
# as the method above, so defining both keeps only this one.
function unrolled(data, labels::Tuple)
    inds = [UInt32(_find(x, labels)) for x in data]
    # any(i -> i>length(labels), inds) && throw("some data out of labels")
    Flux.OneHotArray(inds, length(labels))
end

# _find(val, labels::Tuple, i::Integer=1) = val == first(labels) ? i : _find(val, Base.tail(labels), i+1)
_find(val, labels::Tuple, i::Integer=1) = ifelse(val == first(labels), i, _find(val, Base.tail(labels), i+1))
_find(val, labels::Tuple{}, i::Integer) = i+1

@btime findfirst(==('G'), bases_dna)    # 5ns
@btime _find('G', $(Tuple(bases_dna)))  # 2ns
@btime searchsortedfirst($bases_dna, 'G')  # 3ns

Making the labels into a tuple makes some sense in that the OneHotArray has their number in its type. But unrolling like this is likely to be awful with 100 labels.
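
For label sets where unrolling stops being attractive, a hash-based lookup keeps the per-element cost constant. A rough sketch in the same spirit as the attempts above (viadict is a hypothetical name, not benchmarked here):

function viadict(data, labels::AbstractVector)
    # Build a label -> index table once, then look each element up in O(1).
    lookup = Dict(l => UInt32(i) for (i, l) in enumerate(labels))
    inds = [get(lookup, x, UInt32(0)) for x in data]
    any(iszero, inds) && throw(ArgumentError("some data out of labels"))
    Flux.OneHotArray(inds, length(labels))
end

For four DNA bases this is unlikely to beat the tuple-unrolled version, but it avoids compiling an unrolled comparison chain over, say, 100 labels.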
