Skip to content

Commit

Permalink
Merge pull request #34 from AlCap23/master
Browse files Browse the repository at this point in the history
Restructured Symbolic Recovery
  • Loading branch information
ChrisRackauckas authored Feb 22, 2021
2 parents 63854e4 + b85d921 commit 266e1e4
Show file tree
Hide file tree
Showing 45 changed files with 535 additions and 225 deletions.
Binary file removed LotkaVolterra/Hudson_Bay_recovery.jld2
Binary file not shown.
192 changes: 105 additions & 87 deletions LotkaVolterra/Manifest.toml

Large diffs are not rendered by default.

Binary file removed LotkaVolterra/Scenario_1_recovery_0.01.jld2
Binary file not shown.
Binary file removed LotkaVolterra/Scenario_2_full_plot.pdf
Binary file not shown.
Binary file removed LotkaVolterra/Scenario_3_recovery_0.01.jld2
Binary file not shown.
108 changes: 71 additions & 37 deletions LotkaVolterra/hudson_bay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ using DelimitedFiles
using Random
Random.seed!(5443)

svname = "HudsonBay"
## Data Preprocessing
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
# Originally published in
# Originally published in E. P. Odum (1953), Fundamentals of Ecology, Philadelphia, W. B. Saunders
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
# Measurements of prey and predator
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
plot(t, transpose(Xₙ))
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
# Normalize the data; since the data domain is strictly positive
# we just need to divide by the maximum
xscale = maximum(Xₙ, dims =2)
Xₙ .= 1f0 ./ xscale .* Xₙ
# Time from 0 -> n
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
tspan = (t[1], t[end])

# Plot the data
scatter(t, transpose(Xₙ), xlabel = "t [a]", ylabel = "x(t), y(t)")
scatter(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
plot!(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")

## Direct Identification via SINDy + Collocation
Expand All @@ -50,17 +50,17 @@ plot(t, dx̂')
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)
# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
opt = STRRidge()#SR3(Float32(1e-2), Float32(1e-2))
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
g(x) = x[1] < 1 ? Inf : norm(x, 2)
# Test on derivative data
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeed
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ) # Fails
b2 = Basis((u,p,t)->Ψ(u,ones(length(parameters(Ψ))),t),u, linear_independent = true)
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeed
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ) # Fails
parameters(Ψ)
Expand All @@ -69,7 +69,6 @@ parameters(Ψ)
# We assume we have only 5 measurements in y, evenly distributed
ty = collect(t[1]:Float32(t[end]/5):t[end])
# Create datasets for the different measurements
t
XS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
TS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
YS = zeros(Float32, length(ty)-1, 2) # Just two measurements in y
Expand Down Expand Up @@ -160,6 +159,12 @@ println("Training loss after $(length(losses)) iterations: $(losses[end])")
res3 = DiffEqFlux.sciml_train(loss, res2.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")


pl_losses = plot(1:101, losses[1:101], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM (Shooting)", color = :blue)
plot!(102:302, losses[102:302], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS (Shooting)", color = :red)
plot!(302:length(losses), losses[302:end], color = :black, label = "BFGS (L2)")
savefig(pl_losses, joinpath(pwd(), "plots", "$(svname)_losses.pdf"))

# Rename the best candidate
p_trained = res3.minimizer

Expand All @@ -168,14 +173,18 @@ p_trained = res3.minimizer
tsample = t[1]:0.5:t[end]
= predict(p_trained, Xₙ[:,1], tsample)
# Trained on noisy data vs real solution
plot(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
plot!(tsample, transpose(X̂), color = :red, label = ["Interpolation" nothing])
pl_trajectory = scatter(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing], xlabel = "t", ylabel = "x(t), y(t)")
plot!(tsample, transpose(X̂), color = :red, label = ["UDE Approximation" nothing])
savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruction.pdf"))

# Neural network guess
= U(X̂,p_trained[3:end])

scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel ="I1(t), I2(t)", color = :red, label = ["UDE Approximation" nothing])

pl_reconstruction = scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(tsample, transpose(Ŷ), color = :red, lw = 2, style = :dash, label = [nothing nothing])
savefig(pl_reconstruction, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction.pdf"))
pl_missing = plot(pl_trajectory, pl_reconstruction, layout = (2,1))
savefig(pl_missing, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))
## Symbolic regression via sparse regression ( SINDy based )

# Create a Basis
Expand All @@ -187,7 +196,7 @@ b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)

# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
opt = STRRidge()
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
Expand All @@ -196,51 +205,76 @@ g(x) = x[1] < 1 ? Inf : norm(x, 2)
# Test on uode derivative data
println("SINDy on learned, partial, available data")
Ψ = SINDy(X̂, Ŷ, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)

@info "Found equations:"
print_equations(Ψ)
# Extract the parameter
= parameters(Ψ)
println("First parameter guess : $(p̂)")

# Just the equations -> we run SINDy once more here,
# searching for all linearly independent components again
b = Basis((u, p, t)->Ψ(u, ones(length(p̂)), t), u, linear_independent = true)
println(b)
# Retune for better parameters -> we could also use DiffEqFlux or other parameter estimation tools here.
opt = SR3(Float32(1e-2), Float32(1e-2))
Ψf = SINDy(X̂, Ŷ, b, opt, maxiter = 10000, normalize = true, convergence_error = eps()) # Succeed
println(Ψf)
print_equations(Ψf)
= parameters(Ψf)
println("Second parameter guess : $(p̂)")

# Define the recovered, hybrid model with the rescaled dynamics
# In-place right-hand side of the recovered model: fills `du` from the state `u`
# and the parameter vector `p`; `t` is accepted but not used in the body.
# NOTE(review): the next two lines look like the old and the new version of the
# same statement from a diff view, and both have lost their assignment target
# (presumably `û`, which is read below) -- confirm against the real file.
function recovered_dynamics!(du,u, p, t)
= Ψf(u, p[3:4]) # Network prediction
= Ψ(u, p[3:end]) # Network prediction
# First state: linear term scaled by p[1] plus the first recovered component.
du[1] = p[1]*u[1] + û[1]
# Second state: linear term scaled by -p[2] plus the second recovered component.
du[2] = -p[2]*u[2] + û[2]
end

p_model = [p_trained[1:2];p̂]
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], tspan, p_model)
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1)
# Convert for reuse
sys = modelingtoolkitize(estimation_prob);
dudt = ODEFunction(sys);
estimation_prob = ODEProblem(dudt,Xₙ[:, 1], tspan, p_model)
estimate = solve(estimation_prob, Tsit5(), saveat = t)

## Fit the found model
# Loss for post-fitting the recovered model: solves `estimation_prob` with the
# candidate parameters `θ`, saving at the times `t`, and returns the sum of
# squared differences between `X̂` and the measurements `Xₙ`.
# NOTE(review): the assignment target on the solve line appears to have been lost
# in extraction (presumably `X̂`, which is read on the following line) -- confirm.
function loss_fit(θ)
= Array(solve(estimation_prob, Tsit5(), p = θ, saveat = t))
sum(abs2, X̂ .- Xₙ)
end

# Post-fit the model
res_fit = DiffEqFlux.sciml_train(loss_fit, p_model, BFGS(initial_stepnorm = 0.1f0), maxiters = 1000)
p_fitted = res_fit.minimizer

# Estimate
estimate_rough = solve(estimation_prob, Tsit5(), saveat = 0.1*mean(diff(t)), p = p_model)
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1*mean(diff(t)), p = p_fitted)

# Plot
plot(t, transpose(Xₙ))
plot!(estimate)
pl_fitted = plot(t, transpose(Xₙ), style = :dash, lw = 2,color = :black, label = ["Measurements" nothing], xlabel = "t", ylabel = "x(t), y(t)")
plot!(estimate_rough, color = :red, label = ["Recovered" nothing])
plot!(estimate, color = :blue, label = ["Recovered + Fitted" nothing])
savefig(pl_fitted,joinpath(pwd(),"plots","$(svname)recovery_fitting.pdf"))

## Simulation

# Look at long term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], t_long, p_model)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25)
plot(estimate_long)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25f0, tspan = t_long,p = p_fitted)
plot(estimate_long.t, transpose(xscale .* estimate_long[:,:]), xlabel = "t", ylabel = "x(t),y(t)")


## Save the results
save("Hudson_Bay_recovery.jld2",
save(joinpath(pwd(),"results","Hudson_Bay_recovery.jld2"),
"X", Xₙ, "t" , t, "neural_network" , U, "initial_parameters", p, "trained_parameters" , p_trained, # Training
"losses", losses, "result", Ψf, "recovered_parameters", p̂, # Recovery
"model", recovered_dynamics!, "model_parameter", p_model,
"losses", losses, "result", Ψ, "recovered_parameters", p̂, # Recovery
"model", recovered_dynamics!, "model_parameter", p_model, "fitted_parameter", p_fitted,
"long_estimate", estimate_long) # Estimation

## Post Processing and Plots

c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon
c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red
c3 = :blue # RGBA(255/255,90/255,0,1) # Orange
c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple

p3 = scatter(t, transpose(Xₙ), color = [c1 c2], label = ["x data" "y data"],
title = "Recovered Model from Hudson Bay Data",
titlefont = "Helvetica", legendfont = "Helvetica",
markersize = 5)

plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
plot!(p3,[19.99,20.01],[0.0,maximum(Xₙ)*1.25],lw=1,color=:black, label = nothing)
annotate!([(10.0,maximum(Xₙ)*1.25,text("Training \nData",12 , :center, :top, :black, "Helvetica"))])
savefig(p3,joinpath(pwd(),"plots","$(svname)full_plot.pdf"))
Loading

0 comments on commit 266e1e4

Please sign in to comment.