Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1556 #1557

Merged
merged 18 commits into from
Mar 31, 2021
Merged
20 changes: 10 additions & 10 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
@deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
@deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)

function Base.getproperty(a::Dense, s::Symbol)
if s === :W
Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense)
return getfield(a, :weight)
elseif s === :b
Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense)
return getfield(a, :bias)
end
return getfield(a, s)
end
# function Base.getproperty(a::Dense, s::Symbol)
# if s === :W
# Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense)
# return getfield(a, :weight)
# elseif s === :b
# Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense)
# return getfield(a, :bias)
# end
# return getfield(a, s)
# end
46 changes: 13 additions & 33 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,51 +109,31 @@ julia> Flux.params(d1) # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
struct Dense{F, M<:AbstractMatrix, B}
weight::M
bias::B
struct Dense{F,S<:AbstractArray,T}
W::S
b::T
σ::F
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
b = create_bias(W, bias, size(W,1))
new{F,M,typeof(b)}(W, b, σ)
end
end

function Dense(in::Integer, out::Integer, σ = identity;
initW = nothing, initb = nothing,
init = glorot_uniform, bias=true)

W = if initW !== nothing
Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense)
initW(out, in)
else
init(out, in)
end
Dense(W, b) = Dense(W, b, identity)

b = if bias === true && initb !== nothing
Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
initb(out)
else
bias
end

return Dense(W, b, σ)
function Dense(in::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros, bias::Bool = true)
Dense(initW(out, in), bias ? initb(out) : Zeros(), σ)
end

@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
W, b, σ = a.weight, a.bias, a.σ
return σ.(W*x .+ b)
function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
x_reshaped = reshape(x, size(x, 1), :)
x_out = σ.(W * x_reshaped .+ b)
reshape(x_out, :, size(x)[2:end]...)
end

(a::Dense)(x::AbstractArray) =
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
l.σ == identity || print(io, ", ", l.σ)
l.bias == Zeros() && print(io, "; bias=false")
print(io, ")")
end

Expand Down