chore: update to newer versions
avik-pal committed Sep 1, 2024
1 parent de2c730 commit 421e3bb
Showing 9 changed files with 102 additions and 82 deletions.
4 changes: 1 addition & 3 deletions .buildkite/pipeline.yml
@@ -14,7 +14,7 @@ steps:
queue: "juliagpu"
cuda: "*"
env:
GROUP: "CUDA"
BACKEND_GROUP: "CUDA"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
@@ -54,7 +54,5 @@ steps:
timeout_in_minutes: 240

env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+
sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"
11 changes: 4 additions & 7 deletions .github/workflows/CI.yml
@@ -4,12 +4,12 @@ on:
branches:
- main
paths-ignore:
- 'docs/**'
- "docs/**"
push:
branches:
- main
paths-ignore:
- 'docs/**'
- "docs/**"
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
@@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1'
- "1"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
@@ -41,10 +41,7 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: "CPU"
JULIA_NUM_THREADS: 12
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
BACKEND_GROUP: "CPU"
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
2 changes: 0 additions & 2 deletions LocalPreferences.toml

This file was deleted.

14 changes: 8 additions & 6 deletions Project.toml
@@ -14,7 +14,6 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -40,16 +39,16 @@ ExplicitImports = "1.6.0"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LinearSolve = "2.21.2"
Lux = "0.5.56"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
LuxTestUtils = "1"
NLsolve = "4.5.1"
NNlib = "0.9.17"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
Random = "1.10"
ReTestItems = "1.23.1"
SciMLBase = "2"
@@ -67,8 +66,11 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
@@ -79,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "GPUArraysCore", "Hwloc", "InteractiveUtils", "LuxTestUtils", "MLDataDevices", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl
@@ -15,6 +15,6 @@ using DeepEquilibriumNetworks: DEQs
linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
end
@inline DEQs.__default_sensealg(prob::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
@inline DEQs.__default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())

end
3 changes: 1 addition & 2 deletions src/utils.jl
@@ -98,8 +98,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)
@inline __tupleify(x) = @closure(u->(u, x))

# Jacobian Stabilization
## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
function __estimate_jacobian_trace(ad::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng)
function __estimate_jacobian_trace(::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng)
__f = @closure u -> model((u, x))
res = zero(eltype(x))
ϵ = cbrt(eps(typeof(res)))
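The `AutoFiniteDiff` method above (only its opening lines appear in this hunk) sets up a stochastic, finite-difference estimate of the Jacobian trace used for regularization. A generic sketch of that technique, with a hypothetical helper name and a single-probe default, not the package's exact implementation:

using Random

# Hutchinson estimator: tr(J) ≈ E[vᵀ J v] for random v with E[v vᵀ] = I.
# Jv is formed with central finite differences using step ϵ = cbrt(eps),
# the same step size the hunk above sets up.
function estimate_jacobian_trace_fd(f, z::AbstractArray, rng::AbstractRNG; nsamples::Int=1)
    ϵ = cbrt(eps(eltype(z)))
    res = zero(eltype(z))
    for _ in 1:nsamples
        v = randn(rng, eltype(z), size(z)...)
        Jv = (f(z .+ ϵ .* v) .- f(z .- ϵ .* v)) ./ (2ϵ)   # central-difference JVP
        res += sum(v .* Jv)                               # vᵀ J v for this sample
    end
    return res / nsamples
end

g(u) = 2 .* u                                             # toy map whose Jacobian is 2I
estimate_jacobian_trace_fd(g, randn(Float64, 4), Random.default_rng(); nsamples=256)  # ≈ 8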
50 changes: 25 additions & 25 deletions test/layers_tests.jl
@@ -17,14 +17,14 @@ export loss_function, SOLVERS

end

@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin
@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] begin
using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote

rng = __get_prng(0)
rng = StableRNG(0)

base_models = [Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)),
Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))]
init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)]
base_models = [Parallel(+, dense_layer(2 => 2), dense_layer(2 => 2)),
Parallel(+, conv_layer((1, 1), 1 => 1), conv_layer((1, 1), 1 => 1))]
init_models = [dense_layer(2 => 2), conv_layer((1, 1), 1 => 1)]
x_sizes = [(2, 14), (3, 3, 1, 3)]

model_type = (:deq, :skipdeq, :skipregdeq)
@@ -34,7 +34,7 @@ end
jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
_jacobian_regularizations

@testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
@testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations

@@ -65,8 +65,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
@@ -82,28 +82,28 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)
end
end
end
end

@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin
@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] begin
using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote

rng = __get_prng(0)
rng = StableRNG(0)

main_layers = [(Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)),
__get_dense_layer(3 => 3), __get_dense_layer(2 => 2), __get_dense_layer(1 => 1))]
main_layers = [(Parallel(+, dense_layer(4 => 4), dense_layer(4 => 4)),
dense_layer(3 => 3), dense_layer(2 => 2), dense_layer(1 => 1))]

mapping_layers = [[NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1);
__get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
__get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
__get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]]
mapping_layers = [[NoOpLayer() dense_layer(4 => 3) dense_layer(4 => 2) dense_layer(4 => 1);
dense_layer(3 => 4) NoOpLayer() dense_layer(3 => 2) dense_layer(3 => 1);
dense_layer(2 => 4) dense_layer(2 => 3) NoOpLayer() dense_layer(2 => 1);
dense_layer(1 => 4) dense_layer(1 => 3) dense_layer(1 => 2) NoOpLayer()]]

init_layers = [(__get_dense_layer(4 => 4), __get_dense_layer(4 => 3),
__get_dense_layer(4 => 2), __get_dense_layer(4 => 1))]
init_layers = [(dense_layer(4 => 4), dense_layer(4 => 3),
dense_layer(4 => 2), dense_layer(4 => 1))]

x_sizes = [(4, 3)]
scales = [((4,), (3,), (2,), (1,))]
@@ -112,7 +112,7 @@ end
jacobian_regularizations = (nothing,)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "Solver: $(__nameof(solver))" for solver in SOLVERS,
@testset "Solver: $(nameof(typeof(solver)))" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations

@@ -153,8 +153,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
@@ -172,8 +172,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)
end
end
end
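The testset labels in this file now use `nameof(typeof(solver))` from Base in place of the removed `__nameof` helper; a tiny sketch with a hypothetical stand-in type shows the two agree:

struct DemoSolver end                        # hypothetical stand-in for a solver type
__nameof(::X) where {X} = nameof(X)          # the old helper, as defined in shared_testsetup.jl
solver = DemoSolver()
__nameof(solver) == nameof(typeof(solver))   # true: both give :DemoSolver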
29 changes: 27 additions & 2 deletions test/runtests.jl
@@ -1,3 +1,28 @@
using ReTestItems
using ReTestItems, Pkg, InteractiveUtils, Hwloc

ReTestItems.runtests(@__DIR__)
@info sprint(versioninfo)

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
const EXTRA_PKGS = String[]

(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU")

if !isempty(EXTRA_PKGS)
@info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
Pkg.add(EXTRA_PKGS)
Pkg.update()
Base.retry_load_extensions()
Pkg.instantiate()
end

using DeepEquilibriumNetworks

const RETESTITEMS_NWORKERS = parse(
Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16))))
const RETESTITEMS_NWORKER_THREADS = parse(Int,
get(ENV, "RETESTITEMS_NWORKER_THREADS",
string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1))))

ReTestItems.runtests(DeepEquilibriumNetworks; nworkers=RETESTITEMS_NWORKERS,
nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=12000)
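For local runs, the new environment variables can be set before driving the suite; a minimal sketch, assuming a standard Pkg.test invocation (the specific values are illustrative):

using Pkg

# BACKEND_GROUP selects the device group ("cpu", "cuda", "amdgpu", or "all");
# the RETESTITEMS_* variables override the Hwloc-derived worker defaults above.
withenv("BACKEND_GROUP" => "cpu",
        "RETESTITEMS_NWORKERS" => "4",
        "RETESTITEMS_NWORKER_THREADS" => "2") do
    Pkg.test("DeepEquilibriumNetworks")
end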
69 changes: 35 additions & 34 deletions test/shared_testsetup.jl
@@ -1,55 +1,56 @@
@testsetup module SharedTestSetup

using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote, ForwardDiff
import LuxTestUtils: @jet
using LuxCUDA
using LuxTestUtils
using MLDataDevices, GPUArraysCore

CUDA.allowscalar(false)
LuxTestUtils.jet_target_modules!(["Boltz", "Lux", "LuxLib"])

__nameof(::X) where {X} = nameof(X)
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))

__get_prng(seed::Int) = StableRNG(seed)
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
using LuxCUDA
end

if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
using AMDGPU
end

__is_finite_gradient(x::AbstractArray) = all(isfinite, x)
GPUArraysCore.allowscalar(false)

cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
function cuda_testing()
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
MLDataDevices.functional(CUDADevice)
end
function amdgpu_testing()
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
MLDataDevices.functional(AMDGPUDevice)
end

function __is_finite_gradient(gs::NamedTuple)
gradient_is_finite = Ref(true)
function __is_gradient_finite(x)
!isnothing(x) && !all(isfinite, x) && (gradient_is_finite[] = false)
return x
end
fmap(__is_gradient_finite, gs)
return gradient_is_finite[]
const MODES = begin
modes = []
cpu_testing() && push!(modes, ("cpu", Array, CPUDevice(), false))
cuda_testing() && push!(modes, ("cuda", CuArray, CUDADevice(), true))
amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, AMDGPUDevice(), true))
modes
end

function __get_dense_layer(args...; kwargs...)
is_finite_gradient(x::AbstractArray) = all(isfinite, x)
is_finite_gradient(::Nothing) = true
is_finite_gradient(gs) = all(is_finite_gradient, fleaves(gs))

function dense_layer(args...; kwargs...)
init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
return Dense(args...; init_weight, use_bias=false, kwargs...)
end

function __get_conv_layer(args...; kwargs...)
function conv_layer(args...; kwargs...)
init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
return Conv(args...; init_weight, use_bias=false, kwargs...)
end

const GROUP = get(ENV, "GROUP", "All")

cpu_testing() = GROUP == "All" || GROUP == "CPU"
cuda_testing() = LuxCUDA.functional() && (GROUP == "All" || GROUP == "CUDA")

const MODES = begin
cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)

modes = []
cpu_testing() && push!(modes, cpu_mode)
cuda_testing() && push!(modes, cuda_mode)

modes
end

export Lux, LuxCore, LuxLib
export MODES, __get_dense_layer, __get_conv_layer, __is_finite_gradient, __get_prng,
__nameof, @jet
export MODES, dense_layer, conv_layer, is_finite_gradient, StableRNG, @jet, test_gradients

end
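The refactored gradient check replaces the old `fmap`-based `__is_finite_gradient` with three methods built on `Functors.fleaves`; a small self-contained sketch of the behavior (the example gradient tree is made up):

using Functors

is_finite_gradient(x::AbstractArray) = all(isfinite, x)
is_finite_gradient(::Nothing) = true                          # absent gradients count as finite
is_finite_gradient(gs) = all(is_finite_gradient, fleaves(gs))

gs = (weight = randn(Float32, 2, 2), bias = nothing, inner = (w = Float32[1.0, Inf],))
is_finite_gradient(gs)   # false: the Inf leaf inside `inner` is detected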
