From 6689316a82c37963edb13d6e3bc81d91e6e5b3bb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 7 Jan 2023 22:23:31 -0500 Subject: [PATCH 1/6] _match_eltype --- src/layers/basic.jl | 3 ++- src/layers/conv.jl | 9 +++++--- src/layers/recurrent.jl | 20 ++++++++++-------- src/layers/stateless.jl | 44 ++++++++++++++++++++++++++++++++++++++++ src/outputsize.jl | 11 +++++++++- test/layers/basic.jl | 17 ++++++++++++++++ test/layers/conv.jl | 14 +++++++++++++ test/layers/recurrent.jl | 24 ++++++++++++++++++++++ test/outputsize.jl | 7 +++++++ 9 files changed, 136 insertions(+), 13 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9524c0c284..7253fcf8c2 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -169,7 +169,8 @@ end function (a::Dense)(x::AbstractVecOrMat) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc - return σ.(a.weight * x .+ a.bias) + xT = _match_eltype(a, x) # fixes Float64 input, etc. + return σ.(a.weight * xT .+ a.bias) end (a::Dense)(x::AbstractArray) = diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 003395c15d..72d7f000d9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -197,7 +197,8 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any) function (c::Conv)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = conv_dims(c, x) - σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, x) + σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups @@ -330,7 +331,8 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any) function (c::ConvTranspose)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) - σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, x) + σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end function Base.show(io::IO, l::ConvTranspose) @@ -468,7 +470,8 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any) function (c::CrossCor)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = crosscor_dims(c, x) - σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, x) + σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end function Base.show(io::IO, l::CrossCor) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 7cabc9d5b6..193f53c3dc 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -200,10 +200,11 @@ end RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) -function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T} +function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T} Wi, Wh, b = m.Wi, m.Wh, m.b σ = NNlib.fast_act(m.σ, x) - h = σ.(Wi*x .+ Wh*h .+ b) + xT = _match_eltype(m, T, x) + h = σ.(Wi*xT .+ Wh*h .+ b) return h, reshape_cell_output(h, x) end @@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair; return cell end -function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T} +function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T} b, o = m.b, size(h, 1) - g = muladd(m.Wi, x, muladd(m.Wh, h, b)) + xT = 
_match_eltype(m, T, x) + g = muladd(m.Wi, xT, muladd(m.Wh, h, b)) input, forget, cell, output = multigate(g, o, Val(4)) c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell) h′ = @. sigmoid_fast(output) * tanh_fast(c′) @@ -376,9 +378,10 @@ end GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) -function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T} +function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T} Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1) - gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) + xT = _match_eltype(m, T, x) + gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3]) h′ = @. (1 - z) * h̃ + z * h @@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), init(out, out), init_state(out,1)) -function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T} +function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T} Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1) - gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) + xT = _match_eltype(m, T, x) + gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) h′ = @. (1 - z) * h̃ + z * h diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 06c8b6a4a9..9c34e6ca6c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -57,3 +57,47 @@ true σ = std(x, dims=dims, mean=μ, corrected=false) return @. (x - μ) / (σ + ϵ) end + +""" + _match_eltype(layer, ::Type{T}, x) + _match_eltype(layer, x) + +This internal function corrects most layer input to match the type of the weights. +The second method uses `T = eltype(layer.weight)`. + +It solves a common performance bug: Before, accidentally supplying `Float64` input, +or an activation function which produces `Float64`, would silently run the +entire forward pass in this precision. +""" +_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x + +# A common mistake, print a friendly warning, and fix it: +function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64}) + # This warning is the only reason this needs to take the layer. + @warn "Layer with Float32 parameters got Float64 input. + The input will be converted, but any earlier layers may be very slow." layer summary(x) maxlog=1 + convert(AbstractArray{Float32}, x) +end + +# Allow OneHot to reach specialisation of * etc: +_match_eltype(layer, ::Type, x::OneHotLike) = x + +# Other floats, and integers, silently fix. 
+function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T} + convert(AbstractArray{T}, x) +end + +# Weird types like Nil, Dual, etc, we allow through: +_match_eltype(layer, ::Type, x::AbstractArray) = x + +# 2-arg method, for common layers with layer.weight +_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x) + +# Trivial rule: +function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T} + _match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx +end +function ChainRulesCore.rrule(::typeof(_match_eltype), layer, x::AbstractArray) + _match_eltype(layer, x), dx -> (ZeroTangent(), NoTangent(), dx) # does not un-thunk dx +end + diff --git a/src/outputsize.jl b/src/outputsize.jl index 9fd9545b5f..64faa6bdab 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -173,6 +173,16 @@ for (fn, Dims) in ((:conv, DenseConvDims),) end end +# Recurrent layers: just convert to the type they like & convert back. + +for Cell in [:RNNCell, :LSTMCell, :GRUCell, :GRUv3Cell] + @eval function (m::Recur{<:$Cell})(x::AbstractArray{Nil}) + xT = fill!(similar(m.cell.Wi, size(x)), 0) + _, y = m.cell(m.state, xT) # discard the new state + return similar(x, size(y)) + end +end + """ @autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...) @@ -229,7 +239,6 @@ Limitations: * While `@autosize (5, 32) Flux.Bilinear(_ => 7)` is OK, something like `Bilinear((_, _) => 7)` will fail. * While `Scale(_)` and `LayerNorm(_)` are fine (and use the first dimension), `Scale(_,_)` and `LayerNorm(_,_)` will fail if `size(x,1) != size(x,2)`. -* RNNs won't work: `@autosize (7, 11) LSTM(_ => 5)` fails, because `outputsize(RNN(3=>7), (3,))` also fails, a known issue. 
""" macro autosize(size, model) Meta.isexpr(size, :tuple) || error("@autosize's first argument must be a tuple, the size of the input") diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 1f9d30dec5..896bf15e0d 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -89,6 +89,23 @@ import Flux: activations @test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] @test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end + @testset "type matching" begin + d1 = Dense(2 => 3) + d2 = Dense(d1.weight, false) + x1 = randn(Float32, 2, 4) + @test d1(x1) ≈ d2(x1) ≈ d1.weight * x1 + x2 = Float64.(x1) + @test d1(x2) ≈ d2(x2) ≈ d1.weight * x2 + @test d1(x2) isa Array{Float32} # tests _match_eltype, will print a warning + @test d2(x2) isa Array{Float32} + + x3 = rand(-5:5, 2, 4) + @test d1(x3) ≈ d2(x3) ≈ d1.weight * x3 + x4 = rand(Bool, 2, 4) + @test d1(x4) ≈ d2(x4) ≈ d1.weight * x4 + x5 = Flux.onehotbatch(rand(Bool, 5), (true, false)) + @test d1(x5) ≈ d2(x5) ≈ d1.weight * x5 + end end @testset "Scale" begin diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 019f3fd603..f0eb281e48 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -286,3 +286,17 @@ end end @test_throws DimensionMismatch fun(rand(2,3,4), rand(6)) end + +@testset "type matching" begin + x = rand(Float64, 10,2,5) + xi = rand(-3:3, 10,2,5) + c1 = Conv((3,), 2=>4, relu) + @test @inferred(c1(x)) isa Array{Float32, 3} + @test c1(xi) isa Array{Float32, 3} + + c2 = CrossCor((3,), 2=>1, relu) + @test @inferred(c2(x)) isa Array{Float32, 3} + + c3 = ConvTranspose((3,), 2=>4, relu) + @test @inferred(c3(x)) isa Array{Float32, 3} +end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index facab8466b..915136339b 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -169,3 +169,27 @@ end @test size(m(x3)) == (5, 1, 2) end end + +@testset "type matching" begin + x = rand(Float64, 2, 4) + m1 = RNN(2=>3) + @test m1(x) isa Matrix{Float32} # uses _match_eltype, may print a warning + @test m1.state isa Matrix{Float32} + @test (@inferred m1(x); true) + @test Flux.outputsize(m1, size(x)) == size(m1(x)) + + m2 = LSTM(2=>3) + @test m2(x) isa Matrix{Float32} + @test (@inferred m2(x); true) + @test Flux.outputsize(m2, size(x)) == size(m2(x)) + + m3 = GRU(2=>3) + @test m3(x) isa Matrix{Float32} + @test (@inferred m3(x); true) + @test Flux.outputsize(m3, size(x)) == size(m3(x)) + + m4 = GRUv3(2=>3) + @test m4(x) isa Matrix{Float32} + @test (@inferred m4(x); true) + @test Flux.outputsize(m4, size(x)) == size(m4(x)) +end diff --git a/test/outputsize.jl b/test/outputsize.jl index 0e5b807a60..c1b77f5998 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -257,3 +257,10 @@ end # Can't let |> gpu act before the arrays are materialized... 
so it's an error: @test_throws ErrorException @eval @autosize (1,2,3) Dense(_=>2) |> f64 end + +@testset "type matching" begin + # Check that _match_eltype doesn't replace this with an array of Float32: + @test Flux._match_eltype(Dense(2=>3), fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil} + # For RNN etc there's a special path: + @test RNN(2=>3)(fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil} +end From 9aed157d020aa958e4273cd49e12c86359428681 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 7 Jan 2023 23:20:27 -0500 Subject: [PATCH 2/6] fix tests --- test/layers/recurrent.jl | 9 --------- test/utils.jl | 4 ++-- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 915136339b..f402f873c9 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -90,15 +90,6 @@ end end end -@testset "RNN-input-state-eltypes" begin - @testset for R in [RNN, GRU, LSTM, GRUv3] - m = R(3 => 5) - x = rand(Float64, 3, 1) - Flux.reset!(m) - @test_throws MethodError m(x) - end -end - @testset "multigate" begin x = rand(6, 5) res, (dx,) = Flux.withgradient(x) do x diff --git a/test/utils.jl b/test/utils.jl index fbb7f7d9d1..45b5d1f3d7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -290,8 +290,8 @@ end x32 = rand(Float32, 10) @test eltype(m[1].weight) == Float32 @test eltype(m(x32)) == Float32 - @test eltype(m(x64)) == Float64 - @test eltype(f64(m)(x32)) == Float64 + @test eltype(m(x64)) == Float32 # fixed by _match_eltype + @test eltype(f64(m)(x32)) == Float64 # _match_eltype promotes, Julia would too @test eltype(f64(m)(x64)) == Float64 @test eltype(f64(m)[1].weight) == Float64 @test eltype(f32(f64(m))[1].weight) == Float32 From 1bb4392e6852412edb336a6d7e97c5c668923c26 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 8 Jan 2023 00:20:33 -0500 Subject: [PATCH 3/6] skip a test on 1.6 --- test/layers/conv.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index f0eb281e48..8a8f28eefd 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -298,5 +298,8 @@ end @test @inferred(c2(x)) isa Array{Float32, 3} c3 = ConvTranspose((3,), 2=>4, relu) - @test @inferred(c3(x)) isa Array{Float32, 3} + @test c3(x) isa Array{Float32, 3} + if VERSION >= "v1.8" + @test (@inferred c3(x); true) # fails on 1.6 + end end From 3f8084ed6c3b72f0a7e6897d5eb01db48a4fd689 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 8 Jan 2023 00:21:14 -0500 Subject: [PATCH 4/6] also, don't allow bias to be wider type --- src/utils.jl | 4 ++-- test/layers/basic.jl | 4 ++-- test/layers/conv.jl | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 884fcd7465..3e634cc17a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`. * `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero. * `bias == false` returns `false`, which is understood by AD to be non-differentiable. * `bias::AbstractArray` uses the array provided, provided it has the correct size. - It does not at present correct the `eltype` to match that of `weights`. + It will also correct the `eltype` to match that of `weights`. """ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) bias ? 
fill!(similar(weights, dims...), 0) : false end function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) - bias + convert(AbstractArray{eltype(weights)}, bias) end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 896bf15e0d..45e4750a6f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -58,10 +58,10 @@ import Flux: activations @test Dense(rand(100,10), false, tanh).σ == tanh @test Dense(rand(100,10), rand(100)).σ == identity @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type - @test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match + @test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64} - @test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64} + @test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64} @test_throws MethodError Dense(10, 10.5) @test_throws MethodError Dense(10, 10.5, tanh) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 8a8f28eefd..c83b2c18d3 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -299,7 +299,7 @@ end c3 = ConvTranspose((3,), 2=>4, relu) @test c3(x) isa Array{Float32, 3} - if VERSION >= "v1.8" + if VERSION >= v"1.8" @test (@inferred c3(x); true) # fails on 1.6 end end From 8708a0e30028db6d1879a25e52ab61d0de03e66a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 29 Jan 2023 15:25:27 -0500 Subject: [PATCH 5/6] use rand32 etc --- docs/src/tutorials/linear_regression.md | 12 +++++++----- src/layers/basic.jl | 14 +++++++------- src/layers/conv.jl | 6 +++--- src/layers/normalise.jl | 2 +- src/train.jl | 10 +++++----- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/docs/src/tutorials/linear_regression.md b/docs/src/tutorials/linear_regression.md index 993c66e494..f32ec30e78 100644 --- a/docs/src/tutorials/linear_regression.md +++ b/docs/src/tutorials/linear_regression.md @@ -272,6 +272,8 @@ Let's start by initializing our dataset. We will be using the [`BostonHousing`]( julia> dataset = BostonHousing(); julia> x, y = BostonHousing(as_df=false)[:]; + +julia> x, y = Float32.(x), Float32.(y) ``` We can now split the obtained data into training and testing data - @@ -287,7 +289,7 @@ This data contains a diverse number of features, which means that the features h ```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> std(x_train) -134.06784844377117 +134.06786f0 ``` The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) function to normalise the training data. @@ -296,7 +298,7 @@ The data is indeed not normalised. We can use the [`Flux.normalise`](@ref) funct julia> x_train_n = Flux.normalise(x_train); julia> std(x_train_n) -1.0000843694328236 +1.0000844f0 ``` The standard deviation is now close to one! Our data is ready! @@ -318,7 +320,7 @@ julia> function loss(model, x, y) end; julia> loss(model, x_train_n, y_train) -676.165591625047 +676.1656f0 ``` We can now proceed to the training phase! @@ -363,7 +365,7 @@ Let's have a look at the loss - ```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> loss(model, x_train_n, y_train) -27.127200028562164 +27.1272f0 ``` The loss went down significantly! 
It can be minimized further by choosing an even smaller `δ`. @@ -376,7 +378,7 @@ The last step of this tutorial would be to test our model using the testing data julia> x_test_n = Flux.normalise(x_test); julia> loss(model, x_test_n, y_test) -66.91014769713368 +66.91015f0 ``` The loss is not as small as the loss of the training data, but it looks good! This also shows that our model is not overfitting! diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7253fcf8c2..1e55076766 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -16,7 +16,7 @@ true julia> m = Chain(Dense(10 => 5, tanh), Dense(5 => 2)); -julia> x = rand(10, 32); +julia> x = rand32(10, 32); julia> m(x) == m[2](m[1](x)) true @@ -132,11 +132,11 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided julia> d = Dense(5 => 2) Dense(5 => 2) # 12 parameters -julia> d(rand(Float32, 5, 64)) |> size +julia> d(rand32(5, 64)) |> size (2, 64) -julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions -(2, 1, 1, 64) +julia> d(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions +(2, 6, 4, 64) julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix Dense(5 => 2, tanh; bias=false) # 10 parameters @@ -476,7 +476,7 @@ julia> model = Chain(Dense(3 => 5), Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))), Dense(8 => 17)); -julia> model(rand(3)) |> size +julia> model(rand32(3)) |> size (17,) julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2)) @@ -486,10 +486,10 @@ Parallel( β = Dense(5 => 2), # 12 parameters ) # Total: 4 arrays, 34 parameters, 392 bytes. -julia> model2(rand(10), rand(5)) |> size +julia> model2(rand32(10), rand32(5)) |> size (2,) -julia> model2[:α](rand(10)) |> size +julia> model2[:α](rand32(10)) |> size (2,) julia> model2[:β] == model2[2] diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 72d7f000d9..5851620f9e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -22,7 +22,7 @@ See also [`Conv`](@ref), [`MaxPool`](@ref). # Examples ```jldoctest -julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images +julia> xs = rand32(100, 100, 3, 50); # a batch of images julia> layer = Conv((2,2), 3 => 7, pad=SamePad()) Conv((2, 2), 3 => 7, pad=(1, 0, 1, 0)) # 91 parameters @@ -96,7 +96,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref). # Examples ```jldoctest -julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images +julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images julia> layer = Conv((5,5), 3 => 7, relu; bias = false) Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters @@ -238,7 +238,7 @@ See also [`Conv`](@ref) for more detailed description of keywords. # Examples ```jldoctest -julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images +julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images julia> layer = ConvTranspose((5,5), 3 => 7, relu) ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index e9362313ab..92238bfa3b 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -96,7 +96,7 @@ Does nothing to the input once [`testmode!`](@ref) is true. 
```jldoctest julia> using Statistics -julia> x = randn(1000,1); +julia> x = randn32(1000,1); julia> m = Chain(Dense(1000 => 1000, selu), AlphaDropout(0.2)); diff --git a/src/train.jl b/src/train.jl index 90bec2534b..1eb860b9d3 100644 --- a/src/train.jl +++ b/src/train.jl @@ -27,10 +27,10 @@ It differs from `Optimisers.setup` in that it: # Example ```jldoctest -julia> model = Dense(2=>1, leakyrelu; init=ones32); +julia> model = Dense(2=>1, leakyrelu; init=ones); julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state -(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ()) +(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ()) julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: @@ -39,11 +39,11 @@ julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y end julia> model.bias # was zero, mutated by Flux.train! -1-element Vector{Float32}: - 10.190001 +1-element Vector{Float64}: + 10.19 julia> opt_state # mutated by Flux.train! -(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) +(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ()) ``` """ function setup(rule::Optimisers.AbstractRule, model) From dc5821f9200e25207f5f187f0a3a63e4a6e61edf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 29 Jan 2023 17:20:31 -0500 Subject: [PATCH 6/6] fixup --- docs/src/tutorials/linear_regression.md | 2 +- src/layers/normalise.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/tutorials/linear_regression.md b/docs/src/tutorials/linear_regression.md index f32ec30e78..a8155b1fa2 100644 --- a/docs/src/tutorials/linear_regression.md +++ b/docs/src/tutorials/linear_regression.md @@ -273,7 +273,7 @@ julia> dataset = BostonHousing(); julia> x, y = BostonHousing(as_df=false)[:]; -julia> x, y = Float32.(x), Float32.(y) +julia> x, y = Float32.(x), Float32.(y); ``` We can now split the obtained data into training and testing data - diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 92238bfa3b..d2e891cf87 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -36,7 +36,7 @@ julia> m(ones(2, 7)) # test mode, no effect 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 -julia> Flux.trainmode!(m); # would happen within gradient +julia> Flux.trainmode!(m); # equivalent to use within gradient julia> m(ones(2, 7)) 3×7 Matrix{Float64}: @@ -48,11 +48,11 @@ julia> y = m(ones(2, 10_000)); julia> using Statistics -julia> mean(y) # is about 2.0, as for test mode -1.9892222222222182 +julia> mean(y) # is about 2.0, same as in test mode +1.9989999999999961 julia> mean(iszero, y) # is about 0.4 -0.40323333333333333 +0.4003 ``` """ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
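For reviewers, a minimal sketch of the user-visible behaviour this series aims for, reconstructed from the new tests in test/layers/basic.jl and test/utils.jl; it assumes a Flux build with these patches applied and is illustrative only, not part of any commit above.

```julia
using Flux

d = Dense(2 => 3)            # Float32 parameters by default
x64 = rand(Float64, 2, 4)    # mismatched eltype input

y = d(x64)                   # _match_eltype converts the input, warning once (maxlog=1)
@assert y isa Matrix{Float32}            # forward pass stays in Float32

xi = rand(-5:5, 2, 4)
@assert d(xi) isa Matrix{Float32}        # integer input is converted silently, no warning

@assert f64(d)(x64) isa Matrix{Float64}  # Float64 layers still produce Float64 output
```

OneHot and `Nil` inputs are deliberately passed through unchanged, as exercised in test/layers/basic.jl and test/outputsize.jl.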