Normalizing flow example does not produce accurate estimates #786

Closed · itsdfish opened this issue Dec 29, 2022 · 6 comments · Fixed by #787

@itsdfish

Hi,

As explained in this Discourse thread, the current normalizing flow example does not produce accurate results.

[Figure: scatter plot of the inaccurate density estimates]

Chris Rackauckas suggested increasing the network size. I tried increasing the training set to 1,000 samples and using the following larger network, without success:

# larger network: 1 → 100 → 100 → 1, tanh activations
nn = Flux.Chain(
    Flux.Dense(1, 100, tanh),
    Flux.Dense(100, 100, tanh),
    Flux.Dense(100, 1, tanh),
)

[Figure: scatter plot with the larger network, still inaccurate]

@ChrisRackauckas (Member)

@prbzrg have you looked into this example? I know you did some deeper investigations of the FFJORD stuff since it was written.

@prbzrg (Member) commented Dec 30, 2022

I have done some experiments increasing the data and the neural-network layers; it seems the problem is in the training.

code:

using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux,
      OptimizationOptimJL, Distributions

nn = Flux.Chain(
    Flux.Dense(1, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 1, tanh),
) |> f32
tspan = (0.0f0, 1.0f0)

ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())

# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 1000))

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end

function cb(p, l)::Bool
    vl = loss(p)
    @info "Training" loss = vl
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100, callback=cb)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false, callback=cb)

# Evaluation
using Distances

actual_pdf = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u)[1])
train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)

# Data Generation
ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
new_data = rand(ffjord_dist, 100)

@show train_dis

output:

WARNING: Method definition rrule(Type{Array{T, N} where N where T}, CUDA.CuArray{T, N, B} where B where N where T) in module Flux at C:\Users\Hossein Pourbozorg\.julia\packages\Flux\kq9Et\src\functor.jl:124 overwritten in module Lux at C:\Users\Hossein Pourbozorg\.julia\packages\Lux\6vByk\src\autodiff.jl:57.
WARNING: Method definition (::Type{DiffEqFlux.DeterministicCNF{M, P, RE, D, T, A, K} where K where A where T where D where RE where P where M})(M, P, RE, D, T, A, K) where {M, P, RE, D, T, A, K} in module DiffEqFlux at deprecated.jl:70 overwritten at C:\Users\Hossein Pourbozorg\.julia\packages\DiffEqFlux\2IJEZ\src\ffjord.jl:41.
┌ Warning: Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.
│   caller = #fmap#25 at maps.jl:7 [inlined]
└ @ Core C:\Users\Hossein Pourbozorg\.julia\packages\Functors\orBYx\src\maps.jl:7
WARNING: importing deprecated binding Flux.ADAM into OptimizationFlux.
WARNING: importing deprecated binding Flux.ADAM into DiffEqFlux.
WARNING: Flux.ADAM is deprecated, use Adam instead.
  likely near C:\Users\Hossein Pourbozorg\Code Projects\Mine\tmp-example test\emp.jl:32
┌ Warning: Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.
│   caller = fmap(f::Function, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, ys::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, ), Tuple{Int64, Int64, Tuple{}}}}}}; exclude::Function, walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword) at maps.jl:7
└ @ Functors C:\Users\Hossein Pourbozorg\.julia\packages\Functors\orBYx\src\maps.jl:7
┌ Info: Training
└   loss = 21.64868f0
┌ Info: Training
└   loss = 13.843481f0
┌ Info: Training
└   loss = 13.740033f0
[... roughly 100 further iterations, with the loss plateauing around 13.698 ...]
┌ Info: Training
└   loss = 13.69809f0
train_dis = 0.20204227f0

environment:

(tmp-example test) pkg> st
Status `C:\Users\Hossein Pourbozorg\Code Projects\Mine\tmp-example test\Project.toml`
  [aae7a2af] DiffEqFlux v1.53.0
  [0c46a032] DifferentialEquations v7.6.0
  [b4f34e82] Distances v0.10.7
  [31c24e10] Distributions v0.25.79
  [587475ba] Flux v0.13.10
  [7f7a1694] Optimization v3.10.0
  [253f991c] OptimizationFlux v0.1.2
  [36348300] OptimizationOptimJL v0.1.5
  [91a5bcdd] Plots v1.38.0

@prbzrg (Member) commented Dec 30, 2022

The problem was tspan; changing it to tspan = (0.0f0, 10.0f0) fixed it.
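
For reference, a sketch of the change against the script posted above (the rest of the code stays the same):

tspan = (0.0f0, 10.0f0)   # was (0.0f0, 1.0f0)
ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())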

[Figure: true vs. estimated density after the tspan change]

train_dis = 0.0057737245f0

@prbzrg (Member) commented Dec 30, 2022

Also, learned_pdf must be computed as learned_pdf = exp.(ffjord_mdl(train_data, res2.u, monte_carlo=false)[1]), i.e. with monte_carlo=false, so the density is evaluated with the exact trace rather than the stochastic Hutchinson estimate.
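
In context, the evaluation block from the script above becomes:

actual_pdf  = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u, monte_carlo=false)[1])  # exact trace for evaluation
train_dis   = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)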

@prbzrg (Member) commented Dec 30, 2022

I will make a PR to fix this issue.

@itsdfish (Author)

Sorry to reopen this. Although the plots look better, I'm not certain whether there is still a problem.

The scatter plot looks strange, but perhaps it's expected:

[Figure: scatter plot of true vs. estimated density after the fix]

The pdf overlay does not look as accurate as I would expect:

[Figure: true and estimated pdfs overlaid]

Is this expected?

I have two other questions. First, can you recommend a better way to generate the pdf? This was the only solution I could get to work, but it seems suboptimal:

xs = [4:.05:9;]
est_density = map(x -> exp(ffjord_mdl([x], res2.u, monte_carlo=false)[1]), xs)
est_density = vcat(est_density...)
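
Perhaps a batched call would work here, mirroring how train_data is passed as a 1×N matrix above; I haven't verified this, and it is only a sketch:

xs = collect(4.0f0:0.05f0:9.0f0)
est_density = vec(exp.(ffjord_mdl(reshape(xs, 1, :), res2.u, monte_carlo=false)[1]))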

Lastly, I think a common use case for normalizing flows is approximating the likelihood function (cf. https://elifesciences.org/articles/77220). Extending the example above, this would allow a person to compute the density for arbitrary values of x, mu, and sigma. In addition, there should be a way to store the solution so that it can be reused at a later time, and perhaps used in Turing.jl. To facilitate adoption of the package, it might be helpful to include such an example in the documentation. Presumably this would entail generating training vectors consisting of [mu, sigma, x] instead of only x, but I'm not entirely sure how to modify the example. I would be happy to help where I can, but I would need some guidance.
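
For concreteness, here is roughly what I imagine the training data would look like (only a sketch of the idea; the variable names are made up, and the network's first Dense layer would presumably need 3 inputs instead of 1):

n_train = 1_000
mus     = rand(Float32, 1, n_train) .* 4 .+ 4           # mu drawn uniformly from [4, 8]
sigmas  = rand(Float32, 1, n_train) .* 0.9f0 .+ 0.1f0   # sigma drawn uniformly from [0.1, 1.0]
xs      = mus .+ sigmas .* randn(Float32, 1, n_train)   # x ~ Normal(mu, sigma)
train_data = vcat(mus, sigmas, xs)                      # 3 × n_train; each column is [mu, sigma, x]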

Here is my full code:

###########################################################################################################
#                                           load packages
###########################################################################################################
cd(@__DIR__)
using Pkg 
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random
Random.seed!(3411)
###########################################################################################################
#                                           setup network
###########################################################################################################
nn = Flux.Chain(
    Flux.Dense(1, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 1, tanh),
) |> f32
tspan = (0.0f0, 10.0f0)

ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())

# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 1000))

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end

function cb(p, l)::Bool
    vl = loss(p)
    @info "Training" loss = vl
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100, callback=cb)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false, callback=cb)

###########################################################################################################
#                                           evaluate and plot
###########################################################################################################
using Distances

actual_pdf = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u, monte_carlo=false)[1])
train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)

@show train_dis

# Data Generation
ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
new_data = rand(ffjord_dist, 100)

using Plots
scatter(actual_pdf', learned_pdf', xlabel="true density", ylabel="estimated density",
    leg=false, grid=false)
savefig("scatter1.png")

# density
xs = [4:.05:9;]
true_density = pdf.(data_dist, xs)
# how do I plot the density across xs?
est_density = map(x -> exp(ffjord_mdl([x], res2.u, monte_carlo=false)[1]), xs)
est_density = vcat(est_density...)
plot(xs, true_density)
plot!(xs, est_density)
savefig("pdf_overlay.png")
