From 55d81af15d274d094a4c50f3e6ae82e6321d5d13 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Jun 2024 17:00:33 +0530 Subject: [PATCH] fix: fix incorrect dimensionality of `ODESolution` in `build_function` and `@set` --- ext/SciMLBaseZygoteExt.jl | 10 +--------- src/solutions/ode_solutions.jl | 12 +++++++++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 2f7399f11..58e7bb309 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -58,15 +58,7 @@ end du, dprob end T = eltype(eltype(VA.u)) - if dprob.u0 === nothing - N = 2 - elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)}) - __u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ? - dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan)) - N = length((size(__u0)..., length(du))) - else - N = length((size(dprob.u0)..., length(du))) - end + N = ndims(VA) Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 7c22f9b4a..5636c4993 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -129,6 +129,16 @@ function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution ODESolution{T, N} end +function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple) + u = get(patch, :u, sol.u) + N = u === nothing ? 2 : ndims(eltype(u)) + 1 + T = eltype(eltype(u)) + patch = merge(getproperties(sol), patch) + return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k, + patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, + patch.alg_choice, patch.retcode, patch.resid, patch.original) +end + Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") @@ -276,7 +286,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan)) N = length((size(__u0)..., length(u))) else - N = length((size(prob.u0)..., length(u))) + N = ndims(eltype(u)) + 1 end if prob.f isa Tuple