Merge pull request #1558 from mcabbott/issue1556
Minimal fix of #1556, remove eltype checks
DhairyaLGandhi authored Mar 31, 2021
Commit 28f34d1, merging parents 8100707 and e5a9864
Showing 4 changed files with 7 additions and 11 deletions.
src/utils.jl (8 changes: 2 additions & 6 deletions)
@@ -388,14 +388,10 @@ to the constructor's keyword `bias=bias`.
 function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
   bias ? fill!(similar(weights, dims...), 0) : Zeros()
 end

 function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
-  if eltype(bias) == eltype(weights)
-    return bias
-  else
-    @warn "converting bias to match element type of weights" typeof(weights) typeof(bias) maxlog=3 _id=hash(dims)
-    return broadcast(eltype(weights), bias)
-  end
+  bias
 end

 """
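
For context (not part of the diff): a minimal sketch of the constructor behaviour after this change, assuming Flux at this commit. A bias whose element type differs from the weights is now kept as given, with no warning and no conversion.

using Flux

W = rand(Float16, 100, 10)
b = rand(100)                  # Float64, deliberately not matching eltype(W)

layer = Dense(W, b)            # previously this warned and converted the bias to Float16
eltype(layer.bias) == Float64  # the bias now keeps its original element type
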
test/layers/basic.jl (6 changes: 3 additions & 3 deletions)
@@ -38,10 +38,10 @@ import Flux: activations
 @test Dense(rand(100,10), false, tanh).σ == tanh
 @test Dense(rand(100,10), rand(100)).σ == identity
 @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type
-@test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match
+@test_skip Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match

 @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64}
-@test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}
+@test_skip Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64}

 @test_throws MethodError Dense(10, 10.5)
 @test_throws MethodError Dense(10, 10.5, tanh)
@@ -167,7 +167,7 @@ import Flux: activations
 @test size(b3(rand(4), rand(5))) == (3,)

 b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
-@test b4.bias isa Vector{Float32}
+@test_skip b4.bias isa Vector{Float32}

 @test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array
 @test_throws ArgumentError Flux.Bilinear(rand(3,4), false, tanh)
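
If matching element types is still wanted, the conversion these skipped tests used to exercise can be done explicitly at the call site. A hedged sketch, assuming the same Flux version:

using Flux

W = rand(Float16, 100, 10)
b = rand(100)

layer = Dense(W, Float16.(b))    # convert the bias yourself before constructing the layer
layer.bias isa Vector{Float16}   # matches the weights' element type again
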
test/layers/conv.jl (2 changes: 1 addition & 1 deletion)
@@ -194,7 +194,7 @@ end
 @test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
 if fun == Conv
   @test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
-  @test fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
+  @test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}
 elseif fun == DepthwiseConv
   @test fun(rand(2,3,4,5,6), rand(30)).bias isa Vector{Float64}
 end
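
Likewise for the skipped Conv case: a range such as 1:6 is no longer materialised into a Vector{Float64} by the constructor. Collecting it explicitly restores the concrete type the old test expected; a sketch under that assumption:

using Flux

w = rand(2, 3, 4, 5, 6)             # same 6-output-channel weight shape as in the test above
c = Conv(w, collect(Float64, 1:6))  # materialise the range bias yourself
c.bias isa Vector{Float64}          # holds again with the explicit conversion
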
test/utils.jl (2 changes: 1 addition & 1 deletion)
@@ -342,7 +342,7 @@ end
 testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
   @test l1.W == l2.W
   @test l1.b == l2.b
-  @test typeof(l1.b) === typeof(l2.b)
+  @test_skip typeof(l1.b) === typeof(l2.b)
 end

 @testset "loadparams!" begin
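
The strict type comparison is skipped because, without the eltype conversion, a reconstructed layer can legitimately carry a bias of a different type while holding equal values. A minimal sketch of the value-level behaviour, assuming Flux at this commit:

using Flux

d = Dense(rand(4, 3), [1, 2, 3, 4])  # an integer bias is no longer converted to Float64
d.bias isa Vector{Int}               # the given array is stored untouched
d.bias == [1, 2, 3, 4]               # values are preserved either way
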
