Skip to content

Commit

Permalink
oWe can now support non flat vector parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 25, 2023
1 parent 71b0901 commit dda196c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
12 changes: 8 additions & 4 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
return res, setfield_pullback
end

# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
# needed.
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z, ps, x)
prob = _construct_problem(deq, dudt, z, ps, x)
function ∇_construct_problem(Δ)
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), Δ.u0,
(; model = Δ.p.ps), Δ.p.x)
end
return prob, ∇_construct_problem
end
13 changes: 6 additions & 7 deletions src/layers/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ end

@inline _postprocess_output(_, z_star) = z_star

@inline function _construct_problem(::AbstractDEQs, dudt, z, ps)
return SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model)
@inline function _construct_problem(::AbstractDEQs, dudt, z, ps, x)
return SteadyStateProblem(ODEFunction{false}(dudt), z,
NamedTuple{(:ps, :x)}((ps.model, x)))
end

@inline _fix_solution_output(_, x) = x
Expand Down Expand Up @@ -42,14 +43,12 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})

function dudt(u, p, t)
nfe += 1
u_ = model((u, x), p)
return u_ .- u
return model((u, p.x), p.ps) .- u
end

prob = _construct_problem(deq, dudt, z, ps)
prob = _construct_problem(deq, dudt, z, ps, x)
sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)

z_star = model((_fix_solution_output(deq, sol.u), x), ps.model)
z_star = sol.u

if _jacobian_regularization(deq)
rng = Lux.replicate(st.rng)
Expand Down
2 changes: 1 addition & 1 deletion src/layers/mdeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ function _get_initial_condition(deq::MultiScaleNeuralODE, x, ps, st)
return _get_zeros_initial_condition_mdeq(deq.scales, x, st)
end

@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps)
@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps, x)
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), ps.model)
end

Expand Down

0 comments on commit dda196c

Please sign in to comment.