Skip to content

Commit

Permalink
Merge pull request #891 from AayushSabharwal/as/fix-mtk-tests
Browse files Browse the repository at this point in the history
feat: add new `remake(::AbstractSciMLFunction)`, fix some `remake` bugs.
  • Loading branch information
ChrisRackauckas authored Dec 14, 2024
2 parents 97a79f7 + d455a5a commit b0dc015
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 115 deletions.
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset, @delete
import Accessors: @set, @reset, @delete, @insert
using Expronicon.ADT: @match

using Reexport
Expand Down
6 changes: 4 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ function evaluate_f(
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
end

function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
function evaluate_f(
integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t)
function evaluate_f(integrator::AbstractSDDEIntegrator,
prob::AbstractSDDEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

Expand Down
224 changes: 112 additions & 112 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,105 @@ function remake(
_remake_internal(prob; kwargs..., p)
end

"""
$(TYPEDSIGNATURES)
A utility function which merges two `NamedTuple`s `a` and `b`, assuming that the
keys of `a` are a subset of those of `b`. Values in `b` take priority over those
in `a`, except if they are `nothing`. Keys not present in `a` are assumed to have
a value of `nothing`.
"""
function _similar_namedtuple_merge_ignore_nothing(a::NamedTuple, b::NamedTuple)
ks = fieldnames(typeof(b))
return NamedTuple{ks}(ntuple(Val(length(ks))) do i
something(get(b, ks[i], nothing), get(a, ks[i], nothing), Some(nothing))
end)
end

"""
remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
`remake` the given `func`. Return an `AbstractSciMLFunction` of the same kind, `isinplace` and
`specialization` as `func`. Retain the properties of `func`, except those that are overridden
by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument
is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword
argument is used to override `func.f2`, and `f` is used for `func.f1`. If
`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will
override those of `func` (but not ones provided via keyword arguments). Properties of `f` that
are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If
`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be
of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and
`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace`
and `specialization` as `func`.
"""
function remake(
func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...)
# retain iip and spec of original function
iip = isinplace(func)
spec = specialization(func)
# retain properties of original function
props = getproperties(func)

if f === missing || is_split_function(func)
# if no `f` is provided, create the same type of SciMLFunction
T = parameterless_type(func)
f = isdefined(func, :f) ? func.f : func.f1
elseif f isa AbstractSciMLFunction
# if `f` is a SciMLFunction, create that type
T = parameterless_type(f)
# properties of `f` take priority over those in the existing `func`
# ignore properties of `f` which are `nothing` but present in `func`
props = _similar_namedtuple_merge_ignore_nothing(props, getproperties(f))
f = isdefined(f, :f) ? f.f : f.f1
else
# if `f` is provided but not a SciMLFunction, create the same type
T = parameterless_type(func)
end

# minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data`
# it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap
# in this case and forget about properties.
while !is_split_function(T) && f isa AbstractSciMLFunction
f = isdefined(f, :f) ? f.f : f.f1
end

props = @delete props.f
props = @delete props.f1

args = (f,)
if is_split_function(T)
# for DynamicalSDEFunction and SplitFunction
if isdefined(props, :cache)
props = @insert props._func_cache = props.cache
props = @delete props.cache
end

# `f1` and `f2` are wrapped in another SciMLFunction, unless they're
# already wrapped in the appropriate type or are an `AbstractSciMLOperator`
if !(f isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
f = split_function_f_wrapper(T){iip, spec}(f)
end
# For SplitFunction
# we don't do the same thing as `g`, because for SDEs `g` is
# stored in the problem as well, whereas for Split ODEs etc
# f2 is a part of the function. Thus, if the user provides
# a SciMLFunction for `f` which contains `f2` we use that.
f2 = coalesce(f2, get(props, :f2, missing), func.f2)
if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)})
f2 = split_function_f_wrapper(T){iip, spec}(f2)
end
props = @delete props.f2
args = (args..., f2)
end
if isdefined(func, :g)
# For SDEs/SDDEs where `g` is not a keyword
g = coalesce(g, func.g)
props = @delete props.g
args = (args..., g)
end
T{iip, spec}(args...; props..., kwargs...)
end

"""
remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, _kwargs...)
Expand Down Expand Up @@ -135,53 +234,26 @@ function remake(prob::ODEProblem; f = missing,
initialization_data = nothing
end

if f === missing
if specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_iip(
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1])); initialization_data)
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1])); initialization_data)
end
else
_f = prob.f
if __has_initialization_data(_f)
props = getproperties(_f)
@reset props.initialization_data = initialization_data
props = values(props)
_f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...)
end
end
elseif f isa AbstractODEFunction
_f = f
elseif specialization(prob.f) === FunctionWrapperSpecialize
f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

if specialization(f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(newu0, newu0, newp,
ptspan[1])))
f = remake(
f; f = wrapfun_iip(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(newu0, newp, ptspan[1])))
f = remake(
f; f = wrapfun_oop(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1])))
end
else
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

prob = if kwargs === missing
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
ODEProblem{iip}(
f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if lazy_initialization === nothing
Expand Down Expand Up @@ -395,42 +467,6 @@ function remake(prob::SDEProblem;
return prob
end

"""
remake(func::SDEFunction; f = missing, g = missing,
mass_matrix = missing, analytic = missing, kwargs...)
Remake the given `SDEFunction`.
"""
function remake(func::Union{SDEFunction, SDDEFunction};
f = missing,
g = missing,
mass_matrix = missing,
analytic = missing,
sys = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f
props = @delete props.g
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
@reset props.analytic = coalesce(analytic, func.analytic)
@reset props.sys = coalesce(sys, func.sys)

if f === missing
f = func.f
end

if g === missing
g = func.g
end

if f isa AbstractSciMLFunction
f = f.f
end

T = func isa SDEFunction ? SDEFunction : SDDEFunction
return T{isinplace(func)}(f, g; props..., kwargs...)
end

function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
tspan = missing, p = missing, constant_lags = missing,
dependent_lags = missing, order_discontinuity_t0 = missing,
Expand Down Expand Up @@ -497,28 +533,6 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
return prob
end

function remake(func::DDEFunction;
f = missing,
mass_matrix = missing,
analytic = missing,
sys = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f
@reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix)
@reset props.analytic = coalesce(analytic, func.analytic)
@reset props.sys = coalesce(sys, func.sys)

if f === missing
f = func.f
end
if f isa AbstractSciMLFunction
f = f.f
end

return DDEFunction{isinplace(func)}(f; props..., kwargs...)
end

function remake(prob::SDDEProblem;
f = missing,
g = missing,
Expand Down Expand Up @@ -706,6 +720,7 @@ function remake(prob::NonlinearProblem;
initialization_data = nothing
end

f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

if problem_type === missing
Expand Down Expand Up @@ -737,22 +752,6 @@ function remake(prob::NonlinearProblem;
return prob
end

function remake(func::NonlinearFunction;
f = missing,
kwargs...)
props = getproperties(func)
props = @delete props.f

if f === missing
f = func.f
end
if f isa AbstractSciMLFunction
f = f.f
end

return NonlinearFunction{isinplace(func)}(f; props..., kwargs...)
end

"""
remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing,
kwargs = missing, _kwargs...)
Expand All @@ -775,6 +774,7 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
initialization_data = nothing
end

f = coalesce(f, prob.f)
f = remake(prob.f; f, initialization_data)

prob = if kwargs === missing
Expand Down
14 changes: 14 additions & 0 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4601,6 +4601,20 @@ has_Wfact_t(f::JacobianWrapper) = has_Wfact_t(f.f)
has_paramjac(f::JacobianWrapper) = has_paramjac(f.f)
has_colorvec(f::JacobianWrapper) = has_colorvec(f.f)

is_split_function(x) = is_split_function(typeof(x))
is_split_function(::Type) = false
function is_split_function(::Type{T}) where {T <: Union{
SplitFunction, SplitSDEFunction, DynamicalODEFunction,
DynamicalDDEFunction, DynamicalSDEFunction}}
true
end

split_function_f_wrapper(::Type{<:SplitFunction}) = ODEFunction
split_function_f_wrapper(::Type{<:SplitSDEFunction}) = SDEFunction
split_function_f_wrapper(::Type{<:DynamicalODEFunction}) = ODEFunction
split_function_f_wrapper(::Type{<:DynamicalDDEFunction}) = DDEFunction
split_function_f_wrapper(::Type{<:DynamicalSDEFunction}) = DDEFunction

######### Additional traits

islinear(::AbstractDiffEqFunction) = false
Expand Down
11 changes: 11 additions & 0 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,14 @@ end
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
@test_nowarn remake(prob; u0 = [:x => nothing], p = [:a => nothing])
end

@testset "retain properties of `SciMLFunction` passed to `remake`" begin
u0 = [1.0; 2.0; 3.0]
p = [10.0, 20.0, 30.0]
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
fn = NonlinearFunction(nllorenz!; sys, resid_prototype = zeros(Float64, 3))
prob = NonlinearProblem(fn, u0, p)
fn2 = NonlinearFunction(nllorenz!; resid_prototype = zeros(Float32, 3))
prob2 = remake(prob; f = fn2)
@test prob2.f.resid_prototype isa Vector{Float32}
end

0 comments on commit b0dc015

Please sign in to comment.