Skip to content

Commit

Permalink
also, don't allow bias to be wider type
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jan 8, 2023
1 parent 438db81 commit db4bb28
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,14 @@ to the constructor's keyword `bias=bias`.
* `bias == true` creates a trainable array of the given size, of the same type as `weights`, initialised to zero.
* `bias == false` returns `false`, which is understood by AD to be non-differentiable.
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
It does not at present correct the `eltype` to match that of `weights`.
It will also correct the `eltype` to match that of `weights`.
"""
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
bias ? fill!(similar(weights, dims...), 0) : false
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
bias
convert(AbstractArray{eltype(weights)}, bias)
end


Expand Down
4 changes: 2 additions & 2 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ import Flux: activations
@test Dense(rand(100,10), false, tanh).σ == tanh
@test Dense(rand(100,10), rand(100)).σ == identity
@test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match

@test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}

@test_throws MethodError Dense(10, 10.5)
@test_throws MethodError Dense(10, 10.5, tanh)
Expand Down

0 comments on commit db4bb28

Please sign in to comment.