From cb43150cde7c9e8ab2a3dc1db1c3133508f1f8f9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 20 Dec 2022 13:39:00 -0500 Subject: [PATCH 1/5] some fast paths forward --- src/layers/basic.jl | 6 ++++++ src/layers/conv.jl | 12 ++++++++++++ src/layers/normalise.jl | 18 +++++++++--------- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9524c0c284..64362f68f7 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -172,6 +172,9 @@ function (a::Dense)(x::AbstractVecOrMat) return σ.(a.weight * x .+ a.bias) end +(a::Dense{typeof(identity), <:AbstractMatrix, Bool})(x::AbstractVecOrMat) = + a.weight * x # fast path, no broadcast + (a::Dense)(x::AbstractArray) = reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...) @@ -246,6 +249,9 @@ function (a::Scale)(x::AbstractArray) σ.(a.scale .* x .+ a.bias) end +(a::Scale{typeof(identity), <:AbstractArray, Bool})(x::AbstractArray) = + a.scale .* x + function Base.show(io::IO, l::Scale) print(io, "Scale(", join(size(l.scale), ", ")) l.σ == identity || print(io, ", ", l.σ) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 003395c15d..b5d0871fd4 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -199,6 +199,10 @@ function (c::Conv)(x::AbstractArray) cdims = conv_dims(c, x) σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c)) end +function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) + cdims = conv_dims(c, x) + conv(x, c.weight, cdims) # fast path, no broadcast +end _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight)) @@ -332,6 +336,10 @@ function (c::ConvTranspose)(x::AbstractArray) cdims = conv_transpose_dims(c, x) σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c)) end +function (c::ConvTranspose{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) + cdims = conv_transpose_dims(c, x) + ∇conv_data(x, c.weight, cdims) # fast path, no broadcast +end function Base.show(io::IO, l::ConvTranspose) print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2]) @@ -470,6 +478,10 @@ function (c::CrossCor)(x::AbstractArray) cdims = crosscor_dims(c, x) σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c)) end +function (c::CrossCor{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) + cdims = crosscor_dims(c, x) + crosscor(x, c.weight, cdims) # fast path, no broadcast +end function Base.show(io::IO, l::CrossCor) print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 89eee976ee..7a47a9c7a2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -210,7 +210,7 @@ true ``` """ struct LayerNorm{F,D,T,N} - λ::F + λ::F # this field is not used diag::D ϵ::T size::NTuple{N,Int} @@ -254,16 +254,16 @@ function _norm_layer_forward( end end - o = _norm_layer_forward(x, μ, σ², l.ϵ) - hasaffine(l) || return l.λ.(o) - - γ = reshape(l.γ, affine_shape) - β = reshape(l.β, affine_shape) - return l.λ.(γ .* o .+ β) + s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)? + if hasaffine(l) + γ = reshape(l.γ, affine_shape) # ideally reshape on construction, store Scale? 
+ β = reshape(l.β, affine_shape) + return l.λ.(γ .* s .* (x .- μ) .+ β) + else + return l.λ.(s .* (x .- μ)) + end end -@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ) - function _track_stats!( bn, x::AbstractArray{T, N}, μ, σ², reduce_dims, ) where {T, N} From 7bf759473aeb0f78a79ab9ba6b153c9d05607c6d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 20 Dec 2022 21:05:06 -0500 Subject: [PATCH 2/5] fixup --- src/layers/normalise.jl | 11 +++++------ test/layers/normalisation.jl | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 7a47a9c7a2..d49eb2c233 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -254,9 +254,9 @@ function _norm_layer_forward( end end - s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)? + s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, fewer inv∘sqrt calls if hasaffine(l) - γ = reshape(l.γ, affine_shape) # ideally reshape on construction, store Scale? + γ = reshape(l.γ, affine_shape) # ideally reshape on construction? β = reshape(l.β, affine_shape) return l.λ.(γ .* s .* (x .- μ) .+ β) else @@ -356,10 +356,9 @@ end @functor BatchNorm trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) -function (BN::BatchNorm)(x) - @assert size(x, ndims(x)-1) == BN.chs - N = ndims(x) - reduce_dims = [1:N-2; N] +function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} + size(x, N-1) == BN.chs || error("BatchNorm expected an input with $(BN.chs) channels, got size(x) == $(size(x))") + reduce_dims = ntuple(d -> d + (d==N-1), N-1) # i.e. 1:N with N-1 removed affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) return _norm_layer_forward(BN, x; reduce_dims, affine_shape) end diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 859d703368..a0dc79a762 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -166,11 +166,11 @@ end end # with activation function - let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0; - 2.0 4.0 6.0] + let m = BatchNorm(2, sigmoid) + x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] y = m(x) @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) - @inferred m(x) + @inferred m(x) # fails when x::Matrix{Float64}, do we care? end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) From 8e9a5cceebf49e8d44b762e44ccdcd8d96ad9978 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 23 Dec 2022 09:44:57 -0500 Subject: [PATCH 3/5] also add conversion to Float32, with a warning --- src/layers/basic.jl | 16 ++++++++++++++++ src/layers/conv.jl | 9 +++++++++ test/layers/basic.jl | 10 ++++++++++ 3 files changed, 35 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 64362f68f7..0e1684e004 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -178,6 +178,22 @@ end (a::Dense)(x::AbstractArray) = reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...) 
+function (a::Dense{<:Any,<:AbstractMatrix{Float32}})(x::AbstractVecOrMat{<:Union{Float64,Integer}}) + _warn_32_64(a, x) + a(convert(AbstractArray{Float32}, x)) +end +function (a::Dense{typeof(identity),<:AbstractMatrix{Float32},Bool})(x::AbstractVecOrMat{<:Union{Float64,Integer}}) # solve ambiguity + _warn_32_64(a, x) + a(convert(AbstractArray{Float32}, x)) +end + +function _warn_32_64(layer, x::AbstractArray{Float64}) + @warn "Layer with Float32 parameters got Float64 input. + The input will be converted, but any earlier layers may be very slow" layer summary(x) maxlog=1 +end +_warn_32_64(layer, x::AbstractArray) = nothing # silently fix integer input? +ChainRulesCore.@non_differentiable _warn_32_64(::Any, ::Any) + function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1)) l.σ == identity || print(io, ", ", l.σ) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index b5d0871fd4..3eee2ae133 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -204,6 +204,15 @@ function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::Abstrac conv(x, c.weight, cdims) # fast path, no broadcast end +function (c::Conv{<:Any,<:Any,<:Any,<:AbstractArray{Float32}})(x::AbstractArray{<:Union{Float64,Integer}}) + _warn_32_64(c, x) # warning about a slow path + c(convert(AbstractArray{Float32}, x)) +end +function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray{Float32},Bool})(x::AbstractArray{<:Union{Float64,Integer}}) + _warn_32_64(c, x) + c(convert(AbstractArray{Float32}, x)) +end + _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight)) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 1f9d30dec5..ad91075fe5 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -89,6 +89,16 @@ import Flux: activations @test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] @test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end + @testset "fast paths, type fixes, ambiguities" begin + d1 = Dense(2 => 3) + d2 = Dense(d1.weight, false) + x1 = randn(Float32, 2, 4) + @test d1(x1) ≈ d2(x1) ≈ d1.weight * x1 + x2 = Float64.(x1) + @test d1(x2) ≈ d2(x2) ≈ d1.weight * x2 + x3 = rand(-5:5, 2, 4) + @test d1(x3) ≈ d2(x3) ≈ d1.weight * x3 + end end @testset "Scale" begin From fcbc7b4c000e2500276041699c643bae2e44c0fe Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 23 Dec 2022 12:12:38 -0500 Subject: [PATCH 4/5] better implementation of input type fixup, apply to RNN too --- src/layers/basic.jl | 40 +++++++++++++++++++++------------------- src/layers/conv.jl | 24 ++++++++++-------------- src/layers/recurrent.jl | 20 ++++++++++++-------- test/layers/basic.jl | 6 ++++++ test/layers/recurrent.jl | 5 +++-- 5 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0e1684e004..3eb6c00136 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -169,31 +169,16 @@ end function (a::Dense)(x::AbstractVecOrMat) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc - return σ.(a.weight * x .+ a.bias) + xT = _match_eltype(a, eltype(a.weight), x) # fixes Float64 input, etc. 
+ return σ.(a.weight * xT .+ a.bias) end (a::Dense{typeof(identity), <:AbstractMatrix, Bool})(x::AbstractVecOrMat) = - a.weight * x # fast path, no broadcast + a.weight * _match_eltype(a, eltype(a.weight), x) # fast path, no broadcast (a::Dense)(x::AbstractArray) = reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...) -function (a::Dense{<:Any,<:AbstractMatrix{Float32}})(x::AbstractVecOrMat{<:Union{Float64,Integer}}) - _warn_32_64(a, x) - a(convert(AbstractArray{Float32}, x)) -end -function (a::Dense{typeof(identity),<:AbstractMatrix{Float32},Bool})(x::AbstractVecOrMat{<:Union{Float64,Integer}}) # solve ambiguity - _warn_32_64(a, x) - a(convert(AbstractArray{Float32}, x)) -end - -function _warn_32_64(layer, x::AbstractArray{Float64}) - @warn "Layer with Float32 parameters got Float64 input. - The input will be converted, but any earlier layers may be very slow" layer summary(x) maxlog=1 -end -_warn_32_64(layer, x::AbstractArray) = nothing # silently fix integer input? -ChainRulesCore.@non_differentiable _warn_32_64(::Any, ::Any) - function Base.show(io::IO, l::Dense) print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1)) l.σ == identity || print(io, ", ", l.σ) @@ -204,6 +189,22 @@ end Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) = Scale(W.diag, bias, σ) +_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x # best case +function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64}) # common mistake + @warn "Layer with Float32 parameters got Float64 input. + The input will be converted, but any earlier layers may be very slow" layer summary(x) maxlog=1 + convert(AbstractArray{Float32}, x) +end +function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T} + convert(AbstractArray{T}, x) +end +_match_eltype(layer, ::Type, x::OneHotLike) = x +_match_eltype(layer, ::Type, x::AbstractArray) = x # weird types + +function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T} + _match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx +end + """ Scale(size::Integer..., σ=identity; bias=true, init=ones32) Scale(scale::AbstractArray, [bias, σ]) @@ -443,6 +444,7 @@ Bilinear((in12, out)::Pair{<:Integer, <:Integer}, σ = identity; kw...) 
= Biline function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) W, b, σ = a.weight, a.bias, a.σ + xT = _match_eltype(a, eltype(a.weight), x) d_z, d_x, d_y = size(W) d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W")) @@ -452,7 +454,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :)) # @einsum Z[o,s] := Wy[o,i,s] * x[i,s] - Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :))) + Wyx = batched_mul(Wy, reshape(xT, (d_x, 1, :))) Z = reshape(Wyx, (d_z, :)) # @einsum out[o,s] := σ(Z[o,i] + b[o]) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 3eee2ae133..d8f8c3fa07 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -197,22 +197,14 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any) function (c::Conv)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = conv_dims(c, x) - σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, eltype(c.weight), x) + σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) cdims = conv_dims(c, x) conv(x, c.weight, cdims) # fast path, no broadcast end -function (c::Conv{<:Any,<:Any,<:Any,<:AbstractArray{Float32}})(x::AbstractArray{<:Union{Float64,Integer}}) - _warn_32_64(c, x) # warning about a slow path - c(convert(AbstractArray{Float32}, x)) -end -function (c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray{Float32},Bool})(x::AbstractArray{<:Union{Float64,Integer}}) - _warn_32_64(c, x) - c(convert(AbstractArray{Float32}, x)) -end - _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight)) @@ -343,11 +335,13 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any) function (c::ConvTranspose)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) - σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, eltype(c.weight), x) + σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end function (c::ConvTranspose{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) cdims = conv_transpose_dims(c, x) - ∇conv_data(x, c.weight, cdims) # fast path, no broadcast + xT = _match_eltype(c, eltype(c.weight), x) + ∇conv_data(xT, c.weight, cdims) # fast path, no broadcast end function Base.show(io::IO, l::ConvTranspose) @@ -485,11 +479,13 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any) function (c::CrossCor)(x::AbstractArray) σ = NNlib.fast_act(c.σ, x) cdims = crosscor_dims(c, x) - σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c)) + xT = _match_eltype(c, eltype(c.weight), x) + σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c)) end function (c::CrossCor{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) cdims = crosscor_dims(c, x) - crosscor(x, c.weight, cdims) # fast path, no broadcast + xT = _match_eltype(c, eltype(c.weight), x) + crosscor(xT, c.weight, cdims) # fast path, no broadcast end function Base.show(io::IO, l::CrossCor) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 7cabc9d5b6..fc758fb592 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -200,10 +200,11 @@ end RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) -function 
(m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T} +function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T} Wi, Wh, b = m.Wi, m.Wh, m.b σ = NNlib.fast_act(m.σ, x) - h = σ.(Wi*x .+ Wh*h .+ b) + xT = _match_eltype(m, T, x)::AbstractArray{T} # any AbstractFloat is so converted + h = σ.(Wi*xT .+ Wh*h .+ b) return h, reshape_cell_output(h, x) end @@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair; return cell end -function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T} +function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T} b, o = m.b, size(h, 1) - g = muladd(m.Wi, x, muladd(m.Wh, h, b)) + xT = _match_eltype(m, T, x)::AbstractArray{T} + g = muladd(m.Wi, xT, muladd(m.Wh, h, b)) input, forget, cell, output = multigate(g, o, Val(4)) c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell) h′ = @. sigmoid_fast(output) * tanh_fast(c′) @@ -376,9 +378,10 @@ end GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) = GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) -function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T} +function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T} Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1) - gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) + xT = _match_eltype(m, T, x)::AbstractArray{T} + gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3]) h′ = @. (1 - z) * h̃ + z * h @@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), init(out, out), init_state(out,1)) -function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T} +function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T} Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1) - gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) + xT = _match_eltype(m, T, x)::AbstractArray{T} + gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) h′ = @. 
(1 - z) * h̃ + z * h diff --git a/test/layers/basic.jl b/test/layers/basic.jl index ad91075fe5..b3da86bcfd 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -96,8 +96,14 @@ import Flux: activations @test d1(x1) ≈ d2(x1) ≈ d1.weight * x1 x2 = Float64.(x1) @test d1(x2) ≈ d2(x2) ≈ d1.weight * x2 + @test d1(x2) isa Array{Float32} + @test d2(x2) isa Array{Float32} x3 = rand(-5:5, 2, 4) @test d1(x3) ≈ d2(x3) ≈ d1.weight * x3 + x4 = rand(Bool, 2, 4) + @test d1(x4) ≈ d2(x4) ≈ d1.weight * x4 + x5 = Flux.onehotbatch(rand(Bool, 5), (true, false)) + @test d1(x5) ≈ d2(x5) ≈ d1.weight * x5 end end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index facab8466b..5691a49575 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -93,9 +93,10 @@ end @testset "RNN-input-state-eltypes" begin @testset for R in [RNN, GRU, LSTM, GRUv3] m = R(3 => 5) - x = rand(Float64, 3, 1) + x = rand(Float64, 3, 1) # Float64 input is now converted Flux.reset!(m) - @test_throws MethodError m(x) + @test m(x) isa Array{Float32} + @test m.state isa Array{Float32} end end From ce1cf881a711447f6c80851c5397510e51d713c1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 26 Dec 2022 17:58:26 -0500 Subject: [PATCH 5/5] error on construction for a mistake I just made --- src/layers/basic.jl | 1 + test/layers/basic.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 3eb6c00136..41f8e9cca6 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -162,6 +162,7 @@ end function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity; init = glorot_uniform, bias = true) + σ isa Number && throw(ArgumentError("can't use $σ as an activation function!")) Dense(init(out, in), bias, σ) end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index b3da86bcfd..8cab83e00d 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -65,6 +65,7 @@ import Flux: activations @test_throws MethodError Dense(10, 10.5) @test_throws MethodError Dense(10, 10.5, tanh) + @test_throws ArgumentError Dense(3 => 4, false) @test_throws DimensionMismatch Dense(3,4; bias=rand(5)) @test_throws DimensionMismatch Dense(rand(4,3), rand(5)) @test_throws MethodError Dense(rand(5))
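
A minimal usage sketch of the behaviour this series introduces, mirroring the new
tests; the layer sizes, inputs, and variable names below are illustrative assumptions,
not part of the patches:

    using Flux

    d = Dense(2 => 3)                        # Float32 parameters by default
    x64 = randn(Float64, 2, 4)
    y = d(x64)                               # warns once (maxlog=1) and converts the input
    @assert y isa Array{Float32}
    @assert y ≈ d(Float32.(x64))

    xi = rand(-5:5, 2, 4)                    # integer input is converted silently, no warning
    @assert d(xi) ≈ d(Float32.(xi))

    d_id = Dense(d.weight, false)            # identity activation, no bias
    @assert d_id(Float32.(x64)) ≈ d.weight * Float32.(x64)   # fast path: plain *, no broadcast

    m = RNN(3 => 5)
    Flux.reset!(m)
    @assert m(rand(Float64, 3, 1)) isa Array{Float32}        # recurrent cells convert too

    # After PATCH 5/5, passing a Number as the activation errors at construction:
    # julia> Dense(3 => 4, false)
    # ERROR: ArgumentError: can't use false as an activation function!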