Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1556 #1557

Merged
merged 18 commits on
Mar 31, 2021
Merged
48 changes: 22 additions & 26 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ extraChain(::Tuple{}, x) = ()


"""
Dense(in, out, σ=identity; bias=true, init=glorot_uniform)
Dense(in, out, σ = identity; bias = true, init = glorot_uniform)
Dense(W::AbstractMatrix, [bias, σ])

Create a traditional `Dense` layer, whose forward pass is given by:
Expand All @@ -81,7 +81,7 @@ as an `in × N` matrix, or any array with `size(x,1) == in`.
The output `y` will be a vector of length `out`, or a batch with
`size(y) == (out, size(x)[2:end]...)`

Keyword `bias=false` will switch off trainable bias for the layer.
Keyword `bias = false` will switch off trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
Expand Down Expand Up @@ -109,46 +109,42 @@ julia> Flux.params(d1) # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
struct Dense{F, M<:AbstractMatrix, B}
weight::M
bias::B
struct Dense{F,S<:AbstractArray,T}
weight::S
bias::T
σ::F
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
b = create_bias(W, bias, size(W,1))
new{F,M,typeof(b)}(W, b, σ)
end
end

function Dense(in::Integer, out::Integer, σ = identity;
initW = nothing, initb = nothing,
init = glorot_uniform, bias=true)
Dense(W, b) = Dense(W, b, identity)

W = if initW !== nothing
Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense)
initW(out, in)
else
init(out, in)
Dense(W, b::Bool, σ) =
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved
Dense(W, create_bias(W, b, size(W,1)), σ)

function Dense(in::Integer, out::Integer, σ = identity; initW = nothing,
init = glorot_uniform, initb = nothing, bias::Bool = true)
if initW !== nothing
depwarn("initW is deprecated, please use the `init` keyword instead")
init = initW
end

b = if bias === true && initb !== nothing
Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
initb(out)
if initb !== nothing
depwarn("initb is deprecated, please use the array based constructors instead")
initb = initb
else
bias
initb = zeros
end

return Dense(W, b, σ)
Dense(init(out, in), bias ? initb(out) : Zeros(), σ)
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved
end

@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
W, b, σ = a.weight, a.bias, a.σ
return σ.(W*x .+ b)
σ.(W * x .+ b)
end

(a::Dense)(x::AbstractArray) =
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
(a::Dense)(x) =
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
DhairyaLGandhi marked this conversation as resolved.
Show resolved Hide resolved

function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
Expand Down
7 changes: 1 addition & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,7 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
if eltype(bias) == eltype(weights)
return bias
else
@warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
return broadcast(eltype(weights), bias)
end
bias
end

"""
Expand Down