Skip to content

Commit

Permalink
test: more tests fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 31, 2024
1 parent 9826a2d commit 4dc08d1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
EnzymeCore = "0.8.8"
FastClosures = "0.3.2"
Flux = "0.15, 0.16"
ForwardDiff = "0.10.36"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Compat = "4.16"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
EnzymeCore = "0.8.8"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
EnzymeCore = "0.8.8"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
26 changes: 8 additions & 18 deletions test/layers/normalize_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ end
@jet __f(z)
end

broken_backends = VERSION v"1.11-" ? Any[AutoEnzyme()] : []

@testset "Conv" begin
c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32)

Expand All @@ -165,35 +163,31 @@ end
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(c, (:weight,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(c, (:weight, :bias), (2, 2))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(c, (:weight,), (2,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 3, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end

@testset "Dense" begin
Expand All @@ -205,35 +199,31 @@ end
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(d, (:weight,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(d, (:weight, :bias), (2, 2))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)

wn = WeightNorm(d, (:weight,), (2,))
display(wn)
ps, st = Lux.setup(rng, wn) |> dev
x = randn(rng, Float32, 3, 1) |> aType

@jet wn(x, ps, st)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)
@test_gradients(sumabs2first, wn, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
end

# See https://github.com/LuxDL/Lux.jl/issues/95
Expand Down

0 comments on commit 4dc08d1

Please sign in to comment.