Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert to immutable for CUDA tests #151

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.10.0"
version = "1.10.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
71 changes: 41 additions & 30 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
module SimpleNonlinearSolve

using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations

@recompile_invalidations begin
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu,
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end
using PrecompileTools: @compile_workload, @setup_workload

using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, mul!,
norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
AbstractNonlinearFunction, StandardNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val, warn_paramtype
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size

const DI = DifferentiationInterface

Expand All @@ -37,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

@inline __is_extension_loaded(::Val) = false

include("immutable_nonlinear_problem.jl")
include("utils.jl")
include("linesearch.jl")

Expand Down Expand Up @@ -69,9 +68,21 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
return solve(prob, ITP(), args...; prob.kwargs..., kwargs...)
end

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
# Bypass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
prob = convert(ImmutableNonlinearProblem, prob)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
Expand All @@ -81,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
p === nothing, alg, args...; prob.kwargs..., kwargs...)
end

function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
p, p_changed, alg, args...; kwargs...)
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
return SciMLBase.__solve(prob, alg, args...; kwargs...)
Expand Down
63 changes: 63 additions & 0 deletions src/immutable_nonlinear_problem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
AbstractNonlinearProblem{uType, isinplace}
f::F
u0::uType
p::P
problem_type::PT
kwargs::K
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
p = NullParameters(),
problem_type = StandardNonlinearProblem();
kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
new{typeof(u0), iip, typeof(p), typeof(f),
typeof(kwargs), typeof(problem_type)}(f,
u0,
p,
problem_type,
kwargs)
end

"""

Define a steady state problem using the given function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
end
end

"""

Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
"""
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""

Define a ImmutableNonlinearProblem problem from SteadyStateProblem
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
end


function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
prob.u0,
prob.p,
prob.problem_type;
prob.kwargs...)
end
2 changes: 1 addition & 1 deletion src/nlsolve/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
termination_condition = nothing, kwargs...) where {M}
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
autodiff = nothing
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
"""
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; termination_condition = nothing, kwargs...)
if prob.u0 isa SArray
if termination_condition === nothing ||
Expand All @@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
end

@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -114,7 +114,7 @@ end
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
# finicky, so we'll implement it separately from the generic version
# Ignore termination_condition. Don't pass things into internal functions
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
args...; abstol = nothing, maxiters = 1000, kwargs...)
x = prob.u0
fx = _get_fx(prob, x)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ scalar and static array problems.
nlsolve_update_rule = Val(false)
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = __maybe_unaliased(prob.u0, alias_u0)
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
return _get_fx(prob.f, x, prob.p)
end
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
@inline function _get_fx(f::NonlinearFunction, x, p)
if isinplace(f)
if f.resid_prototype !== nothing
Expand All @@ -145,7 +145,7 @@ end
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
# is meant for low overhead solvers, users can opt into the other termination modes but the
# default is to use the least overhead version.
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing)
return init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
end
Expand All @@ -155,7 +155,7 @@ function init_termination_cache(
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
end

function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
T = promote_type(eltype(du), eltype(u))
abstol = __get_tolerance(u, abstol, T)
Expand Down