Speedup and fix of multiplication by OneHotMatrix #1756
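For context, a minimal sketch of the operation this PR optimizes (the constructor matches the ones used in the tests below; everything else is illustrative):

    using Flux

    A = [1 3 5; 2 4 6]                    # 2×3 dense matrix
    B = Flux.OneHotMatrix([3, 1, 2], 3)   # 3×3 matrix whose columns are one-hot

    # Multiplying by a one-hot matrix only selects columns of A,
    # so the generic matrix product can be bypassed entirely.
    @assert A * B == A[:, [3, 1, 2]]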

Merged: 22 commits, Oct 28, 2021

Commits (all changes):
da9fbba - adding commented out sparse optimizations (racinmat, Oct 25, 2021)
9427363 - Speedup of multiplication by OneHotMatrix. (racinmat, Oct 25, 2021)
a1f35a2 - return benchs (racinmat, Oct 25, 2021)
6d0468e - removing benchmark which should be in different repo (racinmat, Oct 25, 2021)
5c31069 - removed unrelated code, fixed typo, fixed NEWS entry. (racinmat, Oct 25, 2021)
60d2516 - defining special method in order to keep reshaped arrays untouched (racinmat, Oct 25, 2021)
6e79d64 - Update test/onehot.jl (racinmat, Oct 25, 2021)
7268c89 - Update test/cuda/cuda.jl (racinmat, Oct 26, 2021)
5f7ce6b - Update src/onehot.jl (racinmat, Oct 26, 2021)
931d409 - Update test/cuda/cuda.jl (racinmat, Oct 26, 2021)
2507126 - adding optimization of multiplication by adjoint (racinmat, Oct 26, 2021)
644bd1a - Merge branch 'master' of https://github.com/racinmat/Flux.jl (racinmat, Oct 26, 2021)
5ac7a3d - fixed dimension check, added tests to check different dimensionality (racinmat, Oct 26, 2021)
75f0e9c - Update src/onehot.jl (racinmat, Oct 26, 2021)
7ce132f - using gather for OneHotLike (racinmat, Oct 26, 2021)
f221ee9 - Update src/onehot.jl (racinmat, Oct 26, 2021)
0a67472 - using different method for onehot vector and onehot matrix (racinmat, Oct 26, 2021)
fcda965 - Merge branch 'master' of https://github.com/racinmat/Flux.jl (racinmat, Oct 26, 2021)
6e2da25 - returning to default implementation for onehot because I don't want t… (racinmat, Oct 26, 2021)
775efee - dispatching on OneHotLike of dimension 2 (racinmat, Oct 26, 2021)
51b7c00 - added many tests for reshaped matrices (racinmat, Oct 26, 2021)
35ab120 - fixed tests, using only onehot for adjoint (racinmat, Oct 27, 2021)
NEWS.md (3 additions, 0 deletions)

@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.12.8
* Optimized inference and gradient calculation of `OneHotMatrix` ([PR 1756](https://github.com/FluxML/Flux.jl/pull/1756))

## v0.12.7
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675)
* The layers within `Chain` and `Parallel` may now [have names](https://github.com/FluxML/Flux.jl/issues/1680).
src/onehot.jl (14 additions, 0 deletions)

@@ -1,5 +1,6 @@
import Adapt
import .CUDA
using LinearAlgebra, NNlib

"""
OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
@@ -224,6 +225,19 @@
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
return A[:, onecold(B)]
end

function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
return NNlib.gather(A, _indices(B))
end

function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
B_dim = length(_indices(parent(B)))
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
end

for wrapper in [:Adjoint, :Transpose]
@eval begin
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}
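The two methods added above follow one pattern; here is a hedged sketch of the equivalences they rely on, assuming NNlib's gather/scatter semantics (illustrative, not part of the diff):

    using Flux, NNlib

    A = reshape(collect(1.0:15.0), 5, 3)    # 5×3 dense matrix
    B = Flux.OneHotMatrix([2, 1, 3, 3], 3)  # 3×4 one-hot matrix storing indices [2, 1, 3, 3]

    # Forward path: A * B gathers the stored column indices out of A.
    @assert A * B == NNlib.gather(A, [2, 1, 3, 3]) == A[:, [2, 1, 3, 3]]

    # Adjoint path: C * B' scatter-adds columns of C into the output
    # columns given by B's stored indices, as in the new Adjoint method.
    C = rand(5, 4)
    @assert C * B' ≈ NNlib.scatter(+, C, [2, 1, 3, 3], dstsize=(5, 3))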
test/cuda/cuda.jl (3 additions, 0 deletions)

@@ -42,6 +42,9 @@
end
@testset "onehot gpu" begin
y = Flux.onehotbatch(ones(3), 1:2) |> gpu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> gpu;
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
end

@testset "onecold gpu" begin
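A hedged usage sketch of what the new GPU test checks (assumes a functional CUDA device; mirrors the test above):

    using Flux, CUDA

    y  = Flux.onehotbatch(ones(3), 1:2) |> gpu   # 2×3 one-hot matrix on the GPU
    gA = rand(3, 2) |> gpu

    # The optimized product and its gradient both stay on the GPU.
    gradient(A -> sum(A * y), gA)[1] isa CuArray  # true on a working device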
test/onehot.jl (30 additions, 1 deletion)

@@ -32,7 +32,7 @@
end
b1 = Flux.OneHotVector(1, 3)
b2 = Flux.OneHotVector(3, 5)

- @test A*b1 == A[:,1]
+ @test A * b1 == A[:,1]
@test b1' * A == Array(b1') * A
@test A' * b1 == A' * Array(b1)
@test v' * b2 == v' * Array(b2)
@@ -41,6 +41,35 @@
@test_throws DimensionMismatch A*b2
end

@testset "AbstractMatrix-OneHotMatrix multiplication" begin
A = [1 3 5; 2 4 6; 3 6 9]
v = [1, 2, 3, 4, 5]
X = reshape(v, (5, 1))
b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3)
b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
b3 = Flux.OneHotMatrix([1, 1, 2], 4)
b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
b5 = reshape(b4, 6, :)
b6 = reshape(Flux.OneHotMatrix([1 2 2; 2 2 1], 2), 3, :)
b7 = reshape(Flux.OneHotMatrix([1 2 3; 1 2 3], 3), 6, :)

@test A * b1 == A[:,[1, 1, 2, 2]]
@test b1' * A == Array(b1') * A
@test A' * b1 == A' * Array(b1)
@test A * b3' == A * Array(b3')
@test transpose(X) * b2 == transpose(X) * Array(b2)
@test A * b4 == A[:,[1, 2, 2, 2, 3, 1]]
@test A * b5' == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
@test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2])
@test A * b7' == A[:,[1, 2, 3, 1, 2, 3]]

@test_throws DimensionMismatch A*b1'
@test_throws DimensionMismatch A*b2
@test_throws DimensionMismatch A*b2'
@test_throws DimensionMismatch A*b6'
@test_throws DimensionMismatch A*b7
end
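For intuition on the reshaped cases above, a sketch reusing the same b4 and b5 as the tests: reshaping can break the one-hot property, which is why the new methods guard with _isonehot and fall back to the generic AbstractMatrix method via invoke.

    A  = [1 3 5; 2 4 6; 3 6 9]
    b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)  # 3×6, columns still one-hot
    b5 = reshape(b4, 6, :)                                    # 6×3, columns no longer one-hot

    A * b4   # fast path: gathers columns [1, 2, 2, 2, 3, 1] of A
    A * b5'  # generic fallback: one result column is A[:,1] + A[:,2], another is all zeros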

@testset "OneHotArray" begin
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike
