diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 32f275695b..d3c0bc4d6f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -133,7 +133,10 @@ function Dense(in::Integer, out::Integer, σ = identity; initW = nothing, else initb = zeros end - Dense(init(out, in), bias ? initb(out) : Zeros(), σ) + + W = init(out, in) + b = create_bias(W, bias, size(W, 1)) + Dense(W, b, σ) end @functor Dense diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 72b9dc1d6d..b38a172976 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -40,7 +40,7 @@ import Flux: activations @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type @test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match - @test_skip Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64} + @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64} @test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}