Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Embedding layer #1516

Merged
merged 14 commits into from
Jul 13, 2021
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Flux Release Notes

## v0.12.4
* Implemented an [`Embedding layer`](https://github.com/FluxML/Flux.jl/pull/1516)
based on `NNlib.gather` and `NNlib.scatter`.

## v0.12.1 - v0.12.3

* CUDA.jl 3.0 support
* Bug fixes and optimizations.

## v0.12.0

* Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.12.4"
version = "0.12.5"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -37,7 +37,7 @@ Colors = "0.12"
Functors = "0.2.1"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.14"
NNlib = "0.7.24"
NNlibCUDA = "0.1"
Reexport = "0.2, 1.0"
StatsBase = "0.33"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # CuArray
d.weight # CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/models/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ by simply deleting it from `ps`:

```julia
ps = params(m)
delete!(ps, m[2].b)
delete!(ps, m[2].bias)
```

## Custom multiple input or output layer
Expand Down
1 change: 1 addition & 0 deletions docs/src/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ SkipConnection
Parallel
Flux.Bilinear
Flux.Diagonal
Flux.Embedding
```

## Normalisation & Regularisation
Expand Down
7 changes: 7 additions & 0 deletions docs/src/models/nnlib.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ NNlib.batched_mul!
NNlib.batched_adjoint
NNlib.batched_transpose
```

## Gather and Scatter

```@docs
NNlib.gather
NNlib.scatter
```
38 changes: 19 additions & 19 deletions docs/src/models/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s

This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:

```
```julia
julia> using Flux

julia> actual(x) = 4x + 2
Expand All @@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.

Use the `actual` function to build sets of data for training and verification:

```
```julia
julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 … 4 5], [6 7 … 9 10])

Expand All @@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi

Now, build a model to make predictions with `1` input and `1` output:

```
```julia
julia> model = Dense(1, 1)
Dense(1, 1)

julia> model.W
1-element Array{Float64,1}:
-0.99009055
julia> model.weight
1×1 Matrix{Float32}:
-1.4925033
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

julia> model.b
1-element Array{Float64,1}:
julia> model.bias
1-element Vector{Float32}:
0.0
```

Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weights' matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:

```
```julia
julia> predict = Dense(1, 1)
```

`Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.

This model will already make predictions, though not accurate ones yet:

```
```julia
julia> predict(x_train)
1×6 Array{Float32,2}:
-1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
1×6 Matrix{Float32}:
0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
```

In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```
```julia
julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
loss (generic function with 1 method)

Expand All @@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f

Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):

```
```julia
julia> using Flux: train!

julia> opt = Descent()
Expand All @@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]

Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:

```
julia> predict.W
```julia
julia> predict.weight
1-element Array{Float64,1}:
-0.99009055

julia> predict.b
julia> predict.bias
1-element Array{Float64,1}:
0.0
```
Expand All @@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:

```
julia> predict.W in parameters, predict.b in parameters
julia> predict.weight in parameters, predict.bias in parameters
(true, true)

```
Expand Down
4 changes: 2 additions & 2 deletions docs/src/models/regularisation.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ m = Dense(10, 5)
loss(x, y) = logitcrossentropy(m(x), y)
```

We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`.

```julia
penalty() = sum(abs2, m.W) + sum(abs2, m.b)
penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```

Expand Down
58 changes: 58 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on a given input.
`m[1:3](x)` will calculate the output of the first three layers.

# Examples

```jldoctest
julia> m = Chain(x -> x^2, x -> x+1);

Expand Down Expand Up @@ -428,3 +429,60 @@ function Base.show(io::IO, m::Parallel)
join(io, m.layers, ", ")
print(io, ")")
end

"""
Embedding(in, out; init=randn)

A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.

This layers is often used to store word embeddings and retrieve them using indices.
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
The input to the layer can be either a vector of indexes
or the corresponding [onehot encoding](@ref Flux.OneHotArray).
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# Examples

```julia-repl
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
julia> using Flux: Embedding

julia> vocab_size, embed_size = 1000, 4;

julia> model = Embedding(vocab_size, embed_size)
Embedding(1000, 4)

julia> vocab_idxs = [1, 722, 53, 220, 3]

julia> x = OneHotMatrix(vocab_idxs, vocab_size);

julia> model(x)
4×5 Matrix{Float32}:
0.91139 0.670462 0.463217 0.670462 0.110932
0.247225 -0.0823874 0.698694 -0.0823874 0.945958
-0.393626 -0.590136 -0.545422 -0.590136 0.77743
-0.497621 0.87595 -0.870251 0.87595 -0.772696
```

julia> model(vocab_idxs) == model(x)
true
"""
struct Embedding{W}
weight::W
end

@functor Embedding

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved


(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
end
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
```jldoctest
julia> layer = Dense(10, 20);

julia> Flux.nfan(size(layer.W))
julia> Flux.nfan(size(layer.weight))
(10, 20)

julia> layer = Conv((3, 3), 2=>10);
Expand Down Expand Up @@ -368,6 +368,8 @@ identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identit

ones32(dims...) = Base.ones(Float32, dims...)
zeros32(dims...) = Base.zeros(Float32, dims...)
rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, length)
Expand Down
29 changes: 23 additions & 6 deletions test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
# test
if test_cpu
@test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
@test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
if isnothing(xg_cpu)
@test isnothing(xg_gpu)
else
@test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
end
end
@test gs_gpu isa Flux.Zygote.Grads
for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
if isnothing(gs_cpu[p_cpu])
@test isnothing(gs_gpu[p_gpu])
else
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
end
end
end
end
Expand Down Expand Up @@ -114,6 +122,15 @@ pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

embedding = [Flux.Embedding]
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
Expand All @@ -135,12 +152,12 @@ end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b ∉ gs.params
@test l.bias ∉ gs.params
end

@testset "Extended BatchNorm" begin
Expand Down
25 changes: 25 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,29 @@ import Flux: activations
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
end
end

@testset "Embedding" begin
vocab_size, embed_size = 10, 4
m = Flux.Embedding(vocab_size, embed_size)
@test size(m.weight) == (embed_size, vocab_size)

x = rand(1:vocab_size, 3)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
y = m(x)
@test y isa Matrix{Float32}
@test y ≈ m.weight[:,x]
x2 = OneHotMatrix(x, vocab_size)
y2 = m(x2)
@test y2 isa Matrix{Float32}
@test y2 ≈ y
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

x = rand(1:vocab_size, 3, 4)
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@test m(2) ≈ m.weight[:,2]
@test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end
end
Loading