update Embedding layer #1656

Open · wants to merge 9 commits into base: master
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -14,7 +14,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
using ChainRulesCore

-export Chain, Dense, Maxout, SkipConnection, Parallel,
+export Chain, Dense, Maxout, SkipConnection, Parallel, Embedding,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
39 changes: 19 additions & 20 deletions src/layers/basic.jl
@@ -483,7 +483,8 @@ function Base.show(io::IO, m::Parallel)
end

"""
-    Embedding(in => out; init=randn)
+    Embedding(in => out; init=randn32)
+    Embedding(weight::AbstractMatrix)

A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
@@ -493,41 +494,39 @@
The input to the layer can be either a vector of indexes
or the corresponding [onehot encoding](@ref Flux.OneHotArray).

# Examples
-```jldoctest
-julia> vocab_size, embed_size = 1000, 4;
-
-julia> model = Flux.Embedding(vocab_size => embed_size)
-Embedding(1000 => 4)        # 4_000 parameters
-
-julia> vocab_idxs = [1, 722, 53, 220, 3];
+```jldoctest
+julia> m = Embedding(reshape(-6:45, 2, 26) .+ 0.01f0)
[Review comment · Member]
The old example was much clearer.
This constructor (Embed(weight)) is not even part of the docstring, we should add it.

[Review comment · @mcabbott · Member · Feb 15, 2022]
Indeed on the constructor.

The virtue of this example is that it doesn't have random numbers, so it can be a doctest. My hope is that onehotbatch("foo", 'a':'z') might connect with 26 well enough to be easy to follow. Maybe it can be made clearer somehow?
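A minimal sketch of the call mentioned above (assuming Flux's exported `onehotbatch`): each character of "foo" becomes a one-hot column over the 26 lowercase letters, which is why the doctest's vocabulary size is 26.

```julia
using Flux

# 'f' is the 6th letter and 'o' the 15th, so the columns of x are
# one-hot at rows 6, 15, 15 -- the same indices as m([6, 15, 15]) below.
x = Flux.onehotbatch("foo", 'a':'z')
size(x)  # (26, 3)
```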

+Embedding(26 => 2)

-julia> x = Flux.OneHotMatrix(vocab_idxs, vocab_size); summary(x)
-"1000×5 OneHotMatrix(::Vector{Int64}) with eltype Bool"
+julia> m(5)  # embedding vector for 5th element
+2-element Vector{Float32}:
+ 2.01
+ 3.01

-julia> model(x) |> summary
-"4×5 Matrix{Float32}"
+julia> m([6, 15, 15])  # applied to a batch
+2×3 Matrix{Float32}:
+  4.01  22.01  22.01
+  5.01  23.01  23.01

-julia> model(vocab_idxs) == model(x)
+julia> ans == m(Flux.onehotbatch("foo", 'a':'z'))
true
```
"""
-struct Embedding{W}
+struct Embedding{W <: AbstractMatrix}
weight::W
end

@functor Embedding

-Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
+Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))
[Review comment · Member]
The old constructor should be deprecated.
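For reference, a hypothetical sketch of such a deprecation (not part of this PR; the message wording is an assumption), in the style of Flux's deprecations.jl:

```julia
# hypothetical: keep the old two-argument form working, with a warning
function Embedding(in::Integer, out::Integer; kw...)
  Base.depwarn("Embedding(in, out) is deprecated, use Embedding(in => out)", :Embedding)
  Embedding(in => out; kw...)
end
```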


(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
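These three methods compose: an integer selects one weight column, a vector goes through `NNlib.gather`, and any higher-dimensional index array is flattened, gathered, and reshaped so the embedding dimension leads. A small sketch of the resulting shapes:

```julia
using Flux

m = Flux.Embedding(10 => 4)
x = rand(1:10, 3, 5)       # a 3×5 matrix of vocabulary indices
y = m(x)                   # vec(x) is gathered, then reshaped back
size(y)                    # (4, 3, 5): embedding dimension first
y[:, 2, 3] == m(x[2, 3])   # true: each index picks one weight column
```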

-function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
-  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
-  return m(onecold(x))
-end
+(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
+(m::Embedding)(x::AbstractVecOrMat{Bool}) = m.weight * x  # handles OneHotLikeVector, OneHotLikeMatrix

function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
print(io, "Embedding($(size(m.weight, 2)) => $(size(m.weight, 1)))")
end
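The new `AbstractVecOrMat{Bool}` method relies on the fact that multiplying by a one-hot matrix selects columns, so it matches the integer-index path while staying a plain matmul (differentiable and GPU-friendly). A quick check, assuming Flux's exported `onehotbatch`:

```julia
using Flux

m = Flux.Embedding(5 => 3)
x = Flux.onehotbatch([2, 4], 1:5)   # 5×2 one-hot matrix

# m.weight * x picks out columns 2 and 4 of the weight matrix,
# the same lookup NNlib.gather performs for integer indices:
m(x) ≈ m.weight[:, [2, 4]]   # true
m(x) ≈ m([2, 4])             # true
```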
3 changes: 3 additions & 0 deletions src/onehot.jl
@@ -33,6 +33,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

+const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
+const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)

4 changes: 4 additions & 0 deletions src/outputsize.jl
@@ -168,3 +168,7 @@ for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
end
end
end

+function NNlib.gather(src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{<:Nil}) where {Tsrc, Nsrc}
+  fill(nil, (size(src)[1:Nsrc-1]..., size(idx)...))
+end
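This `gather` overload is what lets `Flux.outputsize` trace shapes through `Embedding` without real data: on arrays of `Nil` it simply returns a `nil`-filled array of the shape the real call would produce. A usage sketch:

```julia
using Flux

m = Flux.Embedding(26 => 2)
# outputsize runs the layer on nil-valued arrays, so no arithmetic happens
Flux.outputsize(m, (3, 4))   # (2, 3, 4)
```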
14 changes: 7 additions & 7 deletions test/cuda/layers.jl
@@ -127,13 +127,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

embedding = [Flux.Embedding]
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
gpu_gradtest("Embedding repeated indices", embedding, rand(1:10, 10^3), 10 => 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:10, 10^3), 10), 10 => 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
Expand Down
11 changes: 7 additions & 4 deletions test/layers/basic.jl
@@ -276,17 +276,17 @@ import Flux: activations

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
-  m = Flux.Embedding(vocab_size, embed_size)
+  m = Flux.Embedding(vocab_size => embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
y = m(x)
@test y isa Matrix{Float32}
@test y ≈ m.weight[:,x]
x2 = OneHotMatrix(x, vocab_size)
-  y2 = m(x2)
-  @test y2 isa Matrix{Float32}
-  @test y2 ≈ y
+  @test m(x2) isa Matrix{Float32}
+  @test m(x2) ≈ y
+  @test m(collect(x2)) ≈ y
+  @test_throws DimensionMismatch m(OneHotMatrix(x, 1000))

x = rand(1:vocab_size, 3, 4)
@@ -297,6 +297,9 @@ import Flux: activations
@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))

+  x = onehotbatch(rand(1:vocab_size, 4, 3, 4, 5), 1:vocab_size)
+  @test m(x) ≈ m(onecold(x))
end
end

7 changes: 7 additions & 0 deletions test/outputsize.jl
@@ -155,3 +155,10 @@ end
@test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
@test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
end

@testset "embedding" begin
m = Embedding(3=>5)
@test outputsize(m, (2,)) == (5, 2)
@test outputsize(m, (2, 3)) == (5, 2, 3)
@test outputsize(m, (2, 3, 4)) == (5, 2, 3, 4)
end