Skip to content

Commit

Permalink
Add failing tests for ODEsolver when using AutoDiff - need to work ou…
Browse files Browse the repository at this point in the history
…t how to fix
  • Loading branch information
JoelTrent committed Oct 28, 2024
1 parent 9355cfb commit e631489
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,4 +550,124 @@ using EllipseSampling
end
end
end

# REAL LIKELIHOOD with ODE solver
begin
# function solvedmodel(θ, t)
# return (θ[2] * θ[3]) ./ ((θ[2] - θ[3]) .* (exp.(-θ[1] .* t)) .+ θ[3])
# end

function DE!(dC, C, p, t)
λ,K=p
dC[1]= λ * C[1] * (1.0 - C[1]/K)
end

function odesolver(t, λ, K, C0)
p=(λ,K)
tspan=(0.0, maximum(t))
prob=ODEProblem(DE!, [C0], tspan, p)
sol=solve(prob, saveat=t)
return sol[1,:]
end

function model(θ, t)
y = odesolver(t, θ[1], θ[2], θ[3])
return y
end

function loglhood(θ, data)
y = model(θ, data.t)
e = sum(loglikelihood(data.dist, data.yobs .- y))
return e
end

λ, K, C0 = 0.01, 100.0, 10.0
t = 0:100:1000
σ = 10.0

λmin, λmax = 0.00, 0.05
Kmin, Kmax = 50.0, 150.0
C0min, C0max = 0.001, 50.0

θnames = [, :K, :C0]
θG = [λ, K, C0]
lb = [λmin, Kmin, C0min]
ub = [λmax, Kmax, C0max]
par_magnitudes = [0.005, 10, 10]

ytrue = solvedmodel(θG, t)
Random.seed!(12348)
yobs = ytrue + σ * randn(length(t))
data = (yobs=yobs, σ=σ, t=t, dist=Normal(0, σ))

@testset "BoundaryIsAZeroRealLikelihood_ODEsolver" begin
N = 50

m = initialise_LikelihoodModel(loglhood, data, θnames, θG, lb, ub, par_magnitudes, show_progress=false,
find_zero_atol=0.0)
getMLE_ellipse_approximation!(m)

# UNIVARIATE
univariate_confidenceintervals!(m)

targetll = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, LogLikelihood(), 1)
targetll_standardised = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, EllipseApprox(), 1)

for i in 1:3
lls = [loglhood(m.uni_profiles_dict[i].interval_points.points[:, j], m.core.data) for j in 1:2]
@test isapprox(lls .- targetll, zeros(2), atol=1e-14)

@test isapprox(m.uni_profiles_dict[i].interval_points.ll .- targetll_standardised, zeros(2), atol=1e-14)
end

# BIVARIATE
for method in [IterativeBoundaryMethod(4, 2, 2), RadialRandomMethod(3), RadialMLEMethod(0.0), SimultaneousMethod(), Fix1AxisMethod()]
bivariate_confidenceprofiles!(m, N, method=method)
end

targetll = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, LogLikelihood(), 2)

for i in 1:6
lls = [loglhood(m.biv_profiles_dict[i].confidence_boundary[:, j], m.core.data) for j in 1:N]
@test isapprox(lls .- targetll, zeros(N), atol=1e-12)
end
end

@testset "ValidDimensionalPoints_RealLikelihood_ODEsolver" begin
m = initialise_LikelihoodModel(loglhood, data, θnames, θG, lb, ub, par_magnitudes, show_progress=false)

# UNIVARIATE
dimensional_likelihood_samples!(m, 1, 100, sample_type=UniformGridSamples())
dimensional_likelihood_samples!(m, 1, 100, sample_type=UniformRandomSamples())
dimensional_likelihood_samples!(m, 1, 100, sample_type=LatinHypercubeSamples())

targetll_standardised = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, EllipseApprox(), 1)

for i in 1:9
@test all(m.dim_samples_dict[i].ll .≥ targetll_standardised)
end

# BIVARIATE
dimensional_likelihood_samples!(m, 2, 10, sample_type=UniformGridSamples())
dimensional_likelihood_samples!(m, 2, 100, sample_type=UniformRandomSamples())
dimensional_likelihood_samples!(m, 2, 100, sample_type=LatinHypercubeSamples())

targetll_standardised = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, EllipseApprox(), 2)

for i in 10:18
@test all(m.dim_samples_dict[i].ll .≥ targetll_standardised)
end

# FULL
dimensional_likelihood_samples!(m, 3, 10, sample_type=UniformGridSamples())
dimensional_likelihood_samples!(m, 3, 1000, sample_type=UniformRandomSamples())
dimensional_likelihood_samples!(m, 3, 1000, sample_type=LatinHypercubeSamples())

targetll_standardised = LikelihoodBasedProfileWiseAnalysis.get_target_loglikelihood(m, 0.95, EllipseApprox(), 3)

for i in 19:21
@test all(m.dim_samples_dict[i].ll .≥ targetll_standardised)
end
end
end
end

0 comments on commit e631489

Please sign in to comment.