Skip to content

Commit

Permalink
Relax dropout types
Browse files · Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 22, 2022
1 parent 757fe76 commit 733f74c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.30"
version = "0.4.31"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 2 additions & 2 deletions src/layers/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function Dropout(p; dims=:)
return Dropout(p, 1 / (1 - p), dims)
end

function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
# Functor call for `Dropout`: applies dropout to `x` using the RNG and the
# training flag carried in the layer state `st`. `d.q == 1 / (1 - d.p)` is the
# inverse keep-probability passed as `invp` so kept activations are rescaled.
# Returns `(y, st′)` where `st′` contains the advanced RNG. The commit relaxes
# the signature from `(d::Dropout{T})(x::AbstractArray{T}, ...) where {T}` so
# the layer's parameter eltype no longer has to match the input eltype
# (e.g. `Dropout(0.5)` on a `Float32` array now dispatches).
function (d::Dropout)(x, ps, st::NamedTuple)
y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training; invp=d.q, d.dims)
return y, merge(st, (; rng))
end
Expand Down Expand Up @@ -113,7 +113,7 @@ function VariationalHiddenDropout(p; dims=:)
return VariationalHiddenDropout(p, 1 / (1 - p), dims)
end

function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple)
_mask = st.mask === nothing ? x : st.mask
y, mask, rng = LuxLib.dropout(st.rng, x, _mask, d.p, st.training, st.update_mask;
invp=d.q, d.dims)
Expand Down
12 changes: 6 additions & 6 deletions test/layers/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ include("../test_utils.jl")
rng = Random.default_rng()
Random.seed!(rng, 0)

@testset "Dropout" begin
layer = Dropout(0.5f0)
@testset "Dropout" begin for p in (0.5f0, 0.5)
layer = Dropout(p)
display(layer)
ps, st = Lux.setup(rng, layer)
x = randn(Float32, 5, 2)
Expand All @@ -27,10 +27,10 @@ Random.seed!(rng, 0)
st = Lux.testmode(st)

@test first(layer(x, ps, st)) == x
end
end end

@testset "VariationalHiddenDropout" begin
layer = VariationalHiddenDropout(0.5f0)
@testset "VariationalHiddenDropout" begin for p in (0.5f0, 0.5)
layer = VariationalHiddenDropout(p)
display(layer)
ps, st = Lux.setup(rng, layer)
x = randn(Float32, 5, 2)
Expand Down Expand Up @@ -61,4 +61,4 @@ end
run_JET_tests(layer, x, ps, st__)
test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol=1.0f-3,
rtol=1.0f-3)
end
end end

0 comments on commit 733f74c

Please sign in to comment.