Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Steady State Adjoint #93

Merged
merged 11 commits into from
Oct 20, 2023
Merged

Better Steady State Adjoint #93

merged 11 commits into from
Oct 20, 2023

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Aug 21, 2023

Sister PR to SciML/SciMLSensitivity.jl#877

- [ ] Compare the improved SSAdjoint Methods

Fast DEQ Solves using JNFK

using DeepEquilibriumNetworks, LuxAMDGPU, Lux, Random, OrdinaryDiffEq, Zygote,
    SciMLSensitivity, LinearAlgebra, NonlinearSolve, LegibleLambdas, Statistics, Optimisers,
    LinearSolve

import Lux.Experimental: @compact

function NonlinearSolve.SparseDiffTools.get_tag(::AbstractArray{
    NonlinearSolve.ForwardDiff.Dual{T, V, N}}) where {T, V, N}
    return T
end

dudt = @compact(; c1=CrossCor((3, 3), 1 => 4, swish; pad=SamePad()),
    c2=CrossCor((3, 3), 1 => 4, swish; pad=SamePad()),
    c3=Chain(;
        l1=CrossCor((3, 3), 8 => 32, swish; pad=SamePad()),
        l2=CrossCor((3, 3), 32 => 128, swish; pad=SamePad()),
        l3=CrossCor((3, 3), 128 => 32, swish; pad=SamePad()),
        l4=CrossCor((3, 3), 32 => 8, swish; pad=SamePad()),
        l5=CrossCor((3, 3), 8 => 1; pad=SamePad()))) do (u, x)
    return c3(cat(c1(u), c2(x); dims=Val(3)))
end

dev = gpu_device()

model = DeepEquilibriumNetwork(dudt,
    DiscreteDEQSolver(NewtonRaphson(; linsolve=SimpleGMRES()));
    abstol=1.0f-6, reltol=1.0f-6, maxiters=20)

x = randn(Float32, 32, 32, 1, 16) |> dev;

ps, st = Lux.setup(Xoshiro(0), model) |> dev;

model(x, ps, st)

ai-maintainer[bot]

This comment was marked as outdated.

@avik-pal avik-pal changed the title Sister PR to https://github.com/SciML/SciMLSensitivity.jl/pull/877 [WIP] Better Steady State Adjoint Aug 21, 2023
@codecov
Copy link

codecov bot commented Aug 21, 2023

Codecov Report

Merging #93 (34932dc) into main (30cdb93) will decrease coverage by 2.28%.
Report is 4 commits behind head on main.
The diff coverage is 95.00%.

@@            Coverage Diff             @@
##             main      #93      +/-   ##
==========================================
- Coverage   97.92%   95.65%   -2.28%     
==========================================
  Files           9        9              
  Lines         241      207      -34     
==========================================
- Hits          236      198      -38     
- Misses          5        9       +4     
Files Coverage Δ
src/chainrules.jl 100.00% <100.00%> (ø)
src/layers/core.jl 100.00% <100.00%> (ø)
src/layers/deq.jl 100.00% <100.00%> (ø)
src/layers/evaluate.jl 97.22% <100.00%> (-2.78%) ⬇️
src/layers/jacobian_stabilization.jl 100.00% <100.00%> (+4.54%) ⬆️
src/layers/mdeq.jl 95.23% <100.00%> (-0.37%) ⬇️
src/solve.jl 100.00% <100.00%> (ø)
src/DeepEquilibriumNetworks.jl 33.33% <0.00%> (-66.67%) ⬇️
src/utils.jl 83.33% <50.00%> (-16.67%) ⬇️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@avik-pal
Copy link
Member Author

With SciML/SciMLSensitivity.jl#918 this should be done

@avik-pal avik-pal changed the title [WIP] Better Steady State Adjoint Better Steady State Adjoint Oct 19, 2023
@ChrisRackauckas
Copy link
Member

Why is a special adjoint still needed?

@avik-pal avik-pal force-pushed the ap/ssadjoint_better branch from 09f11bf to 4d6c1ed Compare October 19, 2023 21:49
@avik-pal
Copy link
Member Author

There is no special adjoint. I just migrated to using the non-component arrays version

@avik-pal
Copy link
Member Author

The name is from the SciMLSensitivity change 😅

@avik-pal avik-pal closed this Oct 20, 2023
@avik-pal avik-pal reopened this Oct 20, 2023
@avik-pal avik-pal force-pushed the ap/ssadjoint_better branch from 80a2aad to 66ecf1b Compare October 20, 2023 16:13
@avik-pal
Copy link
Member Author

Needs to wait till LuxDL/Lux.jl#425 is merged

@avik-pal avik-pal force-pushed the ap/ssadjoint_better branch from 66ecf1b to 34932dc Compare October 20, 2023 16:49
@avik-pal avik-pal closed this Oct 20, 2023
@avik-pal avik-pal reopened this Oct 20, 2023
@avik-pal avik-pal merged commit ddc5efb into main Oct 20, 2023
7 of 14 checks passed
@avik-pal avik-pal deleted the ap/ssadjoint_better branch October 20, 2023 22:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants