Match layer output to weights #2156

Merged: 6 commits, Feb 8, 2023
12 changes: 7 additions & 5 deletions docs/src/tutorials/linear_regression.md
@@ -272,6 +272,8 @@ Let's start by initializing our dataset. We will be using the [`BostonHousing`](
julia> dataset = BostonHousing();

julia> x, y = BostonHousing(as_df=false)[:];

julia> x, y = Float32.(x), Float32.(y);
```

We can now split the obtained data into training and testing data -
@@ -287,7 +289,7 @@ This data contains a diverse number of features, which means that the features h

```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> std(x_train)
134.06784844377117
134.06786f0
```

The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) function to normalise the training data.
@@ -296,7 +298,7 @@ The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) funct
julia> x_train_n = Flux.normalise(x_train);

julia> std(x_train_n)
1.0000843694328236
1.0000844f0
```

The standard deviation is now close to one! Our data is ready!
@@ -318,7 +320,7 @@ julia> function loss(model, x, y)
end;

julia> loss(model, x_train_n, y_train)
676.165591625047
676.1656f0
```

We can now proceed to the training phase!
@@ -363,7 +365,7 @@ Let's have a look at the loss -

```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(model, x_train_n, y_train)
27.127200028562164
27.1272f0
```

The loss went down significantly! It can be minimized further by choosing an even smaller `δ`.
@@ -376,7 +378,7 @@ The last step of this tutorial would be to test our model using the testing data
julia> x_test_n = Flux.normalise(x_test);

julia> loss(model, x_test_n, y_test)
66.91014769713368
66.91015f0
```

The loss is not as small as the loss of the training data, but it looks good! This also shows that our model is not overfitting!
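A minimal sketch of why the doctest values above change to `Float32` form: once the features are converted with `Float32.(x)`, the whole pipeline (normalisation, model, loss) stays in 32-bit precision. The 13×506 shape and the `Flux.mse` call are stand-ins for the tutorial's dataset and hand-written `loss`, not the tutorial's exact code.

```julia
using Flux, Statistics

x = rand(Float64, 13, 506)          # stand-in for the Boston Housing features
y = rand(Float64, 1, 506)           # stand-in targets

x, y = Float32.(x), Float32.(y)     # the conversion added by this PR

x_n = Flux.normalise(x)             # std(x_n) ≈ 1, now printed as a Float32
model = Dense(13 => 1)              # Float32 parameters by default

Flux.mse(model(x_n), y) isa Float32 # true: no accidental Float64 promotion
```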
17 changes: 9 additions & 8 deletions src/layers/basic.jl
@@ -16,7 +16,7 @@ true

julia> m = Chain(Dense(10 => 5, tanh), Dense(5 => 2));

julia> x = rand(10, 32);
julia> x = rand32(10, 32);

julia> m(x) == m[2](m[1](x))
true
@@ -132,11 +132,11 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided
julia> d = Dense(5 => 2)
Dense(5 => 2) # 12 parameters

julia> d(rand(Float32, 5, 64)) |> size
julia> d(rand32(5, 64)) |> size
(2, 64)

julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions
(2, 1, 1, 64)
julia> d(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions
(2, 6, 4, 64)

julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
Dense(5 => 2, tanh; bias=false) # 10 parameters
@@ -169,7 +169,8 @@ end

function (a::Dense)(x::AbstractVecOrMat)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ.(a.weight * x .+ a.bias)
xT = _match_eltype(a, x) # fixes Float64 input, etc.
return σ.(a.weight * xT .+ a.bias)
end

(a::Dense)(x::AbstractArray) =
@@ -475,7 +476,7 @@ julia> model = Chain(Dense(3 => 5),
Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
Dense(8 => 17));

julia> model(rand(3)) |> size
julia> model(rand32(3)) |> size
(17,)

julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2))
@@ -485,10 +486,10 @@ Parallel(
β = Dense(5 => 2), # 12 parameters
) # Total: 4 arrays, 34 parameters, 392 bytes.

julia> model2(rand(10), rand(5)) |> size
julia> model2(rand32(10), rand32(5)) |> size
(2,)

julia> model2[:α](rand(10)) |> size
julia> model2[:α](rand32(10)) |> size
(2,)

julia> model2[:β] == model2[2]
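A hedged sketch of what the `_match_eltype` call in `Dense` changes for users (the warning is the one added in `src/layers/stateless.jl` below; shapes here are arbitrary):

```julia
using Flux

d = Dense(3 => 2)                # Float32 weights and bias by default

y = d(rand(Float64, 3, 8))       # Float64 input: warns once, then converts
eltype(y)                        # Float32 -- the output now matches the weights

eltype(d(rand(1:9, 3, 8)))       # Float32 -- integer input is converted silently
```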
15 changes: 9 additions & 6 deletions src/layers/conv.jl
@@ -22,7 +22,7 @@ See also [`Conv`](@ref), [`MaxPool`](@ref).

# Examples
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images
julia> xs = rand32(100, 100, 3, 50); # a batch of images

julia> layer = Conv((2,2), 3 => 7, pad=SamePad())
Conv((2, 2), 3 => 7, pad=(1, 0, 1, 0)) # 91 parameters
@@ -96,7 +96,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).

# Examples
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images
julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images

julia> layer = Conv((5,5), 3 => 7, relu; bias = false)
Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters
@@ -197,7 +197,8 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
function (c::Conv)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, x)
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -237,7 +238,7 @@ See also [`Conv`](@ref) for more detailed description of keywords.

# Examples
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images

julia> layer = ConvTranspose((5,5), 3 => 7, relu)
ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters
@@ -330,7 +331,8 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
function (c::ConvTranspose)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, x)
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
@@ -468,7 +470,8 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
function (c::CrossCor)(x::AbstractArray)
σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
xT = _match_eltype(c, x)
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
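The same conversion now happens inside `Conv`, `ConvTranspose` and `CrossCor`; a small sketch (shapes arbitrary, behaviour as described by the diff above):

```julia
using Flux

layer = Conv((3, 3), 3 => 7, relu)     # Float32 parameters by default

imgs = rand(Float64, 32, 32, 3, 16)    # an accidentally-Float64 image batch
eltype(layer(imgs))                    # Float32, after a one-time warning
```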
10 changes: 5 additions & 5 deletions src/layers/normalise.jl
@@ -36,7 +36,7 @@ julia> m(ones(2, 7)) # test mode, no effect
2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0

julia> Flux.trainmode!(m); # would happen within gradient
julia> Flux.trainmode!(m); # equivalent to use within gradient

julia> m(ones(2, 7))
3×7 Matrix{Float64}:
@@ -48,11 +48,11 @@ julia> y = m(ones(2, 10_000));

julia> using Statistics

julia> mean(y) # is about 2.0, as for test mode
1.9892222222222182
julia> mean(y) # is about 2.0, same as in test mode
1.9989999999999961

julia> mean(iszero, y) # is about 0.4
0.40323333333333333
0.4003
```
"""
mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
@@ -96,7 +96,7 @@ Does nothing to the input once [`testmode!`](@ref) is true.
```jldoctest
julia> using Statistics

julia> x = randn(1000,1);
julia> x = randn32(1000,1);

julia> m = Chain(Dense(1000 => 1000, selu), AlphaDropout(0.2));

20 changes: 12 additions & 8 deletions src/layers/recurrent.jl
@@ -200,10 +200,11 @@ end
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
Review thread on this change:

Member: Is there any risk of us having to handle Duals or such here?

Member Author: Since I forgot: the intention here is that an `RNNCell{F,I,H,V,<:AbstractMatrix{Float32}}` should also accept `Matrix{Float64}` etc., which we are sure `_match_eltype` will convert to `Float32`, and hence not promote the type of the hidden state. It does not accept `Matrix{Nil}`, nor dual numbers.

For the old code to accept `Dual`, the weight matrix would have to also be Duals with the same tag parameter. It's hard to imagine how that would happen, so I don't think this will break anything.

Member: Correct me if I'm wrong, but wouldn't the weights be Duals when doing nested diff with Zygote + ForwardDiff? Or would the tags differ there too?

Member Author: Maybe the tags would in fact match. Say `destructure` builds the network (all weights have the same Dual type) and then data passes through an earlier layer; it would then arrive here with the same tag.

I'm not sure this would work at all. Chunked-mode ForwardDiff is going to call the model several times, which sounds likely to screw up anything recurrent.

Member: Would it be easy to just revert this part of the PR if we run into issues? If so, we can at least try it.

Member Author: Sure. We could have precisely the old and new signatures, calling the same internal function.

Wi, Wh, b = m.Wi, m.Wh, m.b
σ = NNlib.fast_act(m.σ, x)
h = σ.(Wi*x .+ Wh*h .+ b)
xT = _match_eltype(m, T, x)
h = σ.(Wi*xT .+ Wh*h .+ b)
return h, reshape_cell_output(h, x)
end

@@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair;
return cell
end

function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
b, o = m.b, size(h, 1)
g = muladd(m.Wi, x, muladd(m.Wh, h, b))
xT = _match_eltype(m, T, x)
g = muladd(m.Wi, xT, muladd(m.Wh, h, b))
input, forget, cell, output = multigate(g, o, Val(4))
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
@@ -376,9 +378,10 @@ end
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
xT = _match_eltype(m, T, x)
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) * h̃ + z * h
@@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T}
function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
xT = _match_eltype(m, T, x)
gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) * h̃ + z * h
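For the recurrent cells, the relaxed input signature plus `_match_eltype(m, T, x)` means `Float64` (or integer) input is converted to the weight eltype instead of promoting the hidden state. A sketch, assuming the default `Float32` initialisation:

```julia
using Flux

rnn = RNN(2 => 3)                 # Recur{RNNCell} with Float32 weights and state

y = rnn(rand(Float64, 2, 5))      # Float64 batch: converted, with a one-time warning
eltype(y)                         # Float32

eltype(rnn.state)                 # Float32 -- the hidden state is not promoted
```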
44 changes: 44 additions & 0 deletions src/layers/stateless.jl
@@ -57,3 +57,47 @@ true
σ = std(x, dims=dims, mean=μ, corrected=false)
return @. (x - μ) / (σ + ϵ)
end

"""
_match_eltype(layer, ::Type{T}, x)
_match_eltype(layer, x)

This internal function corrects most layer input to match the type of the weights.
The second method uses `T = eltype(layer.weight)`.

It solves a common performance bug: Before, accidentally supplying `Float64` input,
or an activation function which produces `Float64`, would silently run the
entire forward pass in this precision.
"""
_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x

# A common mistake, print a friendly warning, and fix it:
function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64})
# This warning is the only reason this needs to take the layer.
@warn "Layer with Float32 parameters got Float64 input.
The input will be converted, but any earlier layers may be very slow." layer summary(x) maxlog=1
convert(AbstractArray{Float32}, x)
end

# Allow OneHot to reach specialisation of * etc:
_match_eltype(layer, ::Type, x::OneHotLike) = x

# Other floats, and integers, silently fix.
function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
convert(AbstractArray{T}, x)
end
Review thread on lines +85 to +88:

Member: This is not correct for all types, for example TrackedArray and Array{TrackedReal}. This is leading to a good amount of downstream breakage. Is there a reason this wasn't made part of a breaking update?

Member: Changing `T<:Union{AbstractFloat,Integer}` as well would likely be sufficient.

> This is leading to a good amount of downstream breakage

Can you provide some links? We made a point of double-checking the reverse CI results for DiffEqFlux, NeuralPDE and OperatorLearning, so if they have insufficient coverage then that's very concerning.

Member: SciMLSensitivity is the one that tests the adjoints in detail; it found this in two tests. SciML/SciMLSensitivity.jl#784 is really just a test of master. An MWE:

```julia
using Flux, Tracker

x = Float32[0.8; 0.8]
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
p2, re = Flux.destructure(ann)
u = Tracker.TrackedArray(x)
p = Tracker.TrackedArray(p2)
z = re(p)(u)[1]
@show typeof(z) # Tracker.TrackedReal{Tracker.TrackedReal{Float32}}
```

The issue seems to be that `eltype(TrackedArray)` is `TrackedReal`, so this makes a `TrackedArray{TrackedReal}`.

Member: Is there a reason the relevant test groups for SciMLSensitivity weren't added alongside https://github.com/FluxML/Flux.jl/blob/master/.github/workflows/Downstream.yml#L28-L30? I don't think anyone actively searches for new downstream tests unless they're in the same org.

Member Author:

> Is there a reason this wasn't made part of a breaking update?

Yes: there is no plan to stop supporting code which used to work.

This is a bug, and (once again) it turns out that people rely on Flux in ways which aren't tested. Besides adding more downstream tests, it would be extremely helpful if someone who needs Tracker could write up a small, self-contained set of Tracker tests to include here.

Member: Geez, I didn't know that was the sentiment. Just tell people to go away and not come back next time...

Member Author: I mean, bug reports are welcome, and an MWE doesn't have to test all cases.

"Would be extremely helpful if" is obviously suggesting useful contributions beyond a bug report, which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

Member (ChrisRackauckas, Feb 23, 2023):

> Which nobody is obliged to work on. But flippant suggestions that the MWE is somehow equivalent to upgraded tests get flippant responses.

Nobody is obligated to work on it, which is why, when I am going to work on it, it is pretty demoralizing to get such a flippant reply. I was just asking whether that MWE was a good start for a test, or whether you were seeing the issue as something different. All you had to say was that it would be good if it checked a few more layer types (and remove the restructure? Is that the Optimisers part you were mentioning?), and I will include that in the PR, along with adding SciMLSensitivity core 1-5 as downstream tests and a package extension to fix the Tracker case (though this really needs a new array primitive, so I was testing the promote-eltype stuff in ArrayInterface and patching all the odd array types). But at least I understand now why I get so many DMs from people scared of contributing to Flux...

Member:

> Yes: there is no plan to stop supporting code which used to work.

The purpose of the PR is to change the output of every call that was doing such mixed precision. It changes the element type and the result of a lot of things, and explains some other tests which broke (SciMLDocs). I think that would at least have been discussed as a breaking change, which is why I was confused that I didn't find such a discussion. Changing something that was 32-bit precision to 64-bit is borderline, but changing a function call from 64-bit to 32-bit is going to give a different output. It's fine if the answer is that the results are not guaranteed to that level and therefore it's not breaking, or "sorry, we didn't catch it in any test, let's improve downstream testing", but I just didn't know where to find the conversation.

Member Author: Yes, it's a bug; we thought about and tested dual numbers, but apparently failed to think about Tracker. It won't be hard to fix, but it won't happen today. Tracker lies about element types, so a bit of care will be needed to get it right and to test that it cannot go wrong again.

Accidental promotion to Float64 was the most common Flux performance bug, and the one the docs spend pages trying to teach you to avoid. I was trying to rewrite those docs to be clearer, but really it seems crazy to spend so much text teaching about a footgun we can simply remove. The other half of this footgun was accidental promotion of gradients, which is harder to see, and was similarly solved some time ago.

When we discussed this, nobody could picture a real reason to run a model with low-precision weights and high-precision data. For high-precision results, it is of course no problem to run a model with high-precision weights.

> changes the element type and the result of a lot of things

What does "a lot of things" mean?

# Weird types like Nil, Dual, etc, we allow through:
_match_eltype(layer, ::Type, x::AbstractArray) = x

# 2-arg method, for common layers with layer.weight
_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)

# Trivial rule:
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
_match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
end
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, x::AbstractArray)
_match_eltype(layer, x), dx -> (ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
end
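A quick sketch of the dispatch rules above, calling the internal helper directly (it is not exported and not part of the public API):

```julia
using Flux
using Flux: _match_eltype               # internal helper introduced by this PR

d = Dense(3 => 2)                       # eltype(d.weight) == Float32

eltype(_match_eltype(d, rand(Float64, 3)))   # Float32, with a one-time warning
eltype(_match_eltype(d, rand(1:9, 3)))       # Float32, converted silently

xh = Flux.onehotbatch([1, 2, 3], 1:3)
_match_eltype(d, xh) === xh                  # true: one-hot input passes through untouched
```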

11 changes: 10 additions & 1 deletion src/outputsize.jl
@@ -173,6 +173,16 @@ for (fn, Dims) in ((:conv, DenseConvDims),)
end
end

# Recurrent layers: just convert to the type they like & convert back.

for Cell in [:RNNCell, :LSTMCell, :GRUCell, :GRUv3Cell]
@eval function (m::Recur{<:$Cell})(x::AbstractArray{Nil})
xT = fill!(similar(m.cell.Wi, size(x)), 0)
_, y = m.cell(m.state, xT) # discard the new state
return similar(x, size(y))
end
end


"""
@autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...)
@@ -229,7 +239,6 @@ Limitations:
* While `@autosize (5, 32) Flux.Bilinear(_ => 7)` is OK, something like `Bilinear((_, _) => 7)` will fail.
* While `Scale(_)` and `LayerNorm(_)` are fine (and use the first dimension), `Scale(_,_)` and `LayerNorm(_,_)`
will fail if `size(x,1) != size(x,2)`.
* RNNs won't work: `@autosize (7, 11) LSTM(_ => 5)` fails, because `outputsize(RNN(3=>7), (3,))` also fails, a known issue.
"""
macro autosize(size, model)
Meta.isexpr(size, :tuple) || error("@autosize's first argument must be a tuple, the size of the input")
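With the `Recur` methods above, `Flux.outputsize` (and therefore `@autosize`) should now handle recurrent layers, which is why the limitation bullet is removed. A hedged sketch of the expected behaviour:

```julia
using Flux

Flux.outputsize(RNN(3 => 7), (3, 8))     # expected (7, 8); previously errored on Nil input

# expected to build LSTM(7 => 5) feeding Dense(5 => 2)
Flux.@autosize (7, 11) Chain(LSTM(_ => 5), Dense(_ => 2))
```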
10 changes: 5 additions & 5 deletions src/train.jl
@@ -27,10 +27,10 @@ It differs from `Optimisers.setup` in that it:

# Example
```jldoctest
julia> model = Dense(2=>1, leakyrelu; init=ones32);
julia> model = Dense(2=>1, leakyrelu; init=ones);

julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())

julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:

@@ -39,11 +39,11 @@ julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
end

julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float32}:
10.190001
1-element Vector{Float64}:
10.19

julia> opt_state # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
4 changes: 2 additions & 2 deletions src/utils.jl
@@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`.
* `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero.
* `bias == false` returns `false`, which is understood by AD to be non-differentiable.
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
It does not at present correct the `eltype` to match that of `weights`.
It will also correct the `eltype` to match that of `weights`.
"""
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : false
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
bias
convert(AbstractArray{eltype(weights)}, bias)
end


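A small sketch of the `create_bias` change (an internal helper used by layer constructors): a user-supplied bias array is now converted to the weight eltype rather than kept as-is.

```julia
using Flux

w = rand(Float32, 2, 5)
b = zeros(Float64, 2)               # bias supplied in the "wrong" precision

eltype(Flux.create_bias(w, b, 2))   # Float32: converted to match the weights

d = Dense(w, b)                     # same path via the Dense constructor
eltype(d.bias)                      # Float32
```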