Simplify Embedding #2084

Merged: 3 commits, merged on Oct 17, 2022
55 changes: 32 additions & 23 deletions src/layers/basic.jl
@@ -644,35 +644,45 @@ function Base.show(io::IO, m::PairwiseFusion)
end

"""
Embedding(in => out; init=randn)
Embedding(in => out; init=randn32)

A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.
for a vocabulary of size `in`, as a trainable matrix.

This layer is often used to store word embeddings and retrieve them using indices.
The input to the layer can be either a vector of indexes
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).
The input to the layer can be a vocabulary index in `1:in`, an array of indices,
or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).

For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.

# Examples
```jldoctest
julia> vocab_size, embed_size = 1000, 4;

julia> model = Flux.Embedding(vocab_size => embed_size)
Embedding(1000 => 4) # 4_000 parameters

julia> vocab_idxs = [1, 722, 53, 220, 3];

julia> x = Flux.onehotbatch(vocab_idxs, 1:vocab_size); summary(x)
"1000×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool"

julia> model(x) |> summary
"4×5 Matrix{Float32}"

julia> model(vocab_idxs) == model(x)
julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
Embedding(26 => 4) # 104 parameters

julia> emb(2) # one column of emb.weight (here not random!)
4-element Vector{Float32}:
0.0
22.0
0.0
0.0

julia> emb([3, 1, 20, 14, 4, 15, 7]) # vocabulary indices, in 1:26
4×7 Matrix{Float32}:
0.0 22.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0
22.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 22.0 0.0 0.0

julia> ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
true

julia> emb(rand(1:26, (10, 1, 12))) |> size # three batch dimensions
(4, 10, 1, 12)
```
"""
struct Embedding{W}
struct Embedding{W<:AbstractMatrix}
weight::W
end

@@ -684,10 +694,9 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end
(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
Comment on lines +697 to +698

Member Author:
These could instead call Flux.onecold. The result will differ on e.g. [true, true, false], not sure we care too much either way?

Member:
For performance in the one-hot case? If it's onecold-compatible, then folks should use OneHotArray for performance. At least with *, we do the mathematically expected operation.

Member Author:
For OneHotArray these should be identical, right? Result and performance.

For a one-hot BitArray, the results will agree. I would guess that onecold is faster but haven't checked.

For a generic BitArray, I'm not sure which is mathematically expected really. I think you're saying that * is.

Member:
Yes, what you wrote is what I meant re: performance. I was adding that in the one-hot BitArray case, we can direct people to OneHotArray if their concern is performance.

Yeah, whenever I've come across this type of operation in papers, I see it written as *. There's an implicit assumption that x is one-hot, so maybe onecold could be better here if it were made to error for [true, true, false], etc. But I think silently choosing the first "hot" index is wrong.

Member Author:
Yes. Mixing two embedding vectors seems less wrong. But probably nobody ever hits this & it's just a way to decouple from OneHotArray types. I don't think we should document that boolean indexing is an option.

Member:
So I think we are happy with the current implementation in the PR?

Member Author:
Yes, I think so.

I see we had a very similar discussion in #1656 (comment) BTW, I forgot... but same conclusion.
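To make the behaviour discussed in this thread concrete, here is a small illustrative snippet (not part of the PR; the layer size and the names `emb`, `oh`, and `x` are made up) showing where `weight * x` and an `onecold`-style lookup agree, and where they diverge for a boolean vector that is not one-hot:

```julia
using Flux

emb = Flux.Embedding(5 => 3)   # hypothetical small layer; emb.weight is 3×5

# Genuine one-hot input: matrix multiplication and index lookup agree.
oh = Flux.onehot(2, 1:5)                              # OneHotVector with a single `true`
emb.weight * oh ≈ emb.weight[:, Flux.onecold(oh)]     # true

# Boolean vector with two `true`s: the PR's `weight * x` mixes two columns,
# while an onecold-style lookup would silently pick only the first hot index.
x = Bool[1, 1, 0, 0, 0]
emb.weight * x ≈ emb.weight[:, 1] + emb.weight[:, 2]  # true
emb.weight[:, findfirst(x)] == emb.weight[:, 1]       # what onecold-based dispatch would give
```

For true one-hot input the two approaches give the same result, so the choice only matters for arbitrary boolean vectors, which (per the thread above) the docstring deliberately does not advertise as supported input.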

(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
20 changes: 15 additions & 5 deletions test/layers/basic.jl
@@ -289,9 +289,17 @@ import Flux: activations

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
m = Embedding(vocab_size, embed_size)
@test size(m.weight) == (embed_size, vocab_size)

# one index
@test m(1) isa Vector{Float32}
@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
@test m(4) ≈ m((1:vocab_size) .== 4)

# a batch of indices
x = rand(1:vocab_size, 3)
y = m(x)
@test y isa Matrix{Float32}
@@ -301,15 +309,17 @@ import Flux: activations
@test y2 isa Matrix{Float32}
@test y2 ≈ y
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
@test y ≈ m(x' .== (1:vocab_size))

# more dimensions via reshape
x = rand(1:vocab_size, 3, 4)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
x3 = onehotbatch(x, 1:1:vocab_size)
@test size(x3) == (vocab_size, 3, 4)
y3 = m(x3)
@test size(y3) == (embed_size, 3, 4)
end
end
