diff --git a/Project.toml b/Project.toml index f5a8fb6a5..ccf2fe03a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.4.30" +version = "0.4.31" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index 660147323..deb15a5ad 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -48,7 +48,7 @@ function Dropout(p; dims=:) return Dropout(p, 1 / (1 - p), dims) end -function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} +function (d::Dropout)(x, ps, st::NamedTuple) y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training; invp=d.q, d.dims) return y, merge(st, (; rng)) end @@ -113,7 +113,7 @@ function VariationalHiddenDropout(p; dims=:) return VariationalHiddenDropout(p, 1 / (1 - p), dims) end -function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} +function (d::VariationalHiddenDropout)(x, ps, st::NamedTuple) _mask = st.mask === nothing ? x : st.mask y, mask, rng = LuxLib.dropout(st.rng, x, _mask, d.p, st.training, st.update_mask; invp=d.q, d.dims) diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl index dc3d9f504..6a45d0f21 100644 --- a/test/layers/dropout.jl +++ b/test/layers/dropout.jl @@ -5,8 +5,8 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "Dropout" begin - layer = Dropout(0.5f0) +@testset "Dropout" begin for p in (0.5f0, 0.5) + layer = Dropout(p) display(layer) ps, st = Lux.setup(rng, layer) x = randn(Float32, 5, 2) @@ -27,10 +27,10 @@ Random.seed!(rng, 0) st = Lux.testmode(st) @test first(layer(x, ps, st)) == x -end +end end -@testset "VariationalHiddenDropout" begin - layer = VariationalHiddenDropout(0.5f0) +@testset "VariationalHiddenDropout" begin for p in (0.5f0, 0.5) + layer = VariationalHiddenDropout(p) display(layer) ps, st = Lux.setup(rng, layer) x = randn(Float32, 5, 2) @@ -61,4 +61,4 @@ end run_JET_tests(layer, x, ps, st__) test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol=1.0f-3, rtol=1.0f-3) -end +end end