Normalizing flow example does not produce accurate estimates #786
@prbzrg have you looked into this example? I know you did some deeper investigations of the FFJORD stuff since it was written.
I have done some experiments increasing the amount of data and the number of neural network layers; it seems the problem is in the training itself.

code:

using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux,
    OptimizationOptimJL, Distributions
# 1-D input / 1-D output network used as the FFJORD dynamics
nn = Flux.Chain(
Flux.Dense(1, 10, tanh),
Flux.Dense(10, 10, tanh),
Flux.Dense(10, 1, tanh),
) |> f32
tspan = (0.0f0, 1.0f0)
ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())
# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 1000))
# Negative mean log-likelihood of the data under the FFJORD model
function loss(θ)
logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
-mean(logpx)
end
# Callback: log the current loss; returning false keeps the optimization running
function cb(p, l)::Bool
vl = loss(p)
@info "Training" loss = vl
false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)
res1 = Optimization.solve(optprob,
ADAM(0.1),
maxiters = 100, callback=cb)
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
Optim.LBFGS(),
allow_f_increases=false, callback=cb)
# Evaluation
using Distances
actual_pdf = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u)[1])
train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)
# Data Generation
ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
new_data = rand(ffjord_dist, 100)
@show train_dis
output:

WARNING: Method definition rrule(Type{Array{T, N} where N where T}, CUDA.CuArray{T, N, B} where B where N where T) in module Flux at C:\Users\Hossein Pourbozorg\.julia\packages\Flux\kq9Et\src\functor.jl:124 overwritten in module Lux at C:\Users\Hossein Pourbozorg\.julia\packages\Lux\6vByk\src\autodiff.jl:57.
WARNING: Method definition (::Type{DiffEqFlux.DeterministicCNF{M, P, RE, D, T, A, K} where K where A where T where D where RE where P where M})(M, P, RE, D, T, A, K) where {M, P, RE, D, T, A, K} in module DiffEqFlux at deprecated.jl:70 overwritten at C:\Users\Hossein Pourbozorg\.julia\packages\DiffEqFlux\2IJEZ\src\ffjord.jl:41.
┌ Warning: Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.
│ caller = #fmap#25 at maps.jl:7 [inlined]
└ @ Core C:\Users\Hossein Pourbozorg\.julia\packages\Functors\orBYx\src\maps.jl:7
WARNING: importing deprecated binding Flux.ADAM into OptimizationFlux.
WARNING: importing deprecated binding Flux.ADAM into DiffEqFlux.
WARNING: Flux.ADAM is deprecated, use Adam instead.
likely near C:\Users\Hossein Pourbozorg\Code Projects\Mine\tmp-example test\emp.jl:32
┌ Warning: Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.
│ caller = fmap(f::Function, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, ys::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}; exclude::Function, walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword) at maps.jl:7
└ @ Functors C:\Users\Hossein Pourbozorg\.julia\packages\Functors\orBYx\src\maps.jl:7
┌ Info: Training
└ loss = 21.64868f0
┌ Info: Training
└ loss = 13.843481f0
┌ Info: Training
└ loss = 13.740033f0
[... many further iterations, with the loss decreasing slowly and levelling off around 13.698 ...]
┌ Info: Training
└ loss = 13.698236f0
┌ Info: Training
└ loss = 13.69809f0
train_dis = 0.20204227f0

environment:

(tmp-example test) pkg> st
Status `C:\Users\Hossein Pourbozorg\Code Projects\Mine\tmp-example test\Project.toml`
[aae7a2af] DiffEqFlux v1.53.0
[0c46a032] DifferentialEquations v7.6.0
[b4f34e82] Distances v0.10.7
[31c24e10] Distributions v0.25.79
[587475ba] Flux v0.13.10
[7f7a1694] Optimization v3.10.0
[253f991c] OptimizationFlux v0.1.2
[36348300] OptimizationOptimJL v0.1.5
[91a5bcdd] Plots v1.38.0
I will make a PR to fix this issue.
Sorry to reopen this. Although the plots look better, I'm not certain whether there is still a problem. The scatter plot looks strange, but perhaps it's expected. The pdf overlay does not look as accurate. Is this expected? I have two other questions. First, can you recommend a better way to generate the pdf? This was the only solution I could get to work, but it seems suboptimal:
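For reference, a minimal sketch of evaluating the learned density on a grid and comparing it against the true pdf. This is only an illustration (not the snippet referred to above), and it assumes the trained ffjord_mdl, data_dist, and res2 from the earlier comment:

using Plots

# Illustrative sketch only: tabulate the learned and true densities on a 1-D grid.
# The grid bounds are chosen to cover Normal(6, 0.7).
xs = reshape(Float32.(range(3.0, 9.0; length = 200)), 1, :)   # 1 × 200 grid
logpx, _, _ = ffjord_mdl(xs, res2.u)    # same call pattern as in the training loss above
learned = exp.(vec(logpx))
actual  = pdf.(data_dist, vec(xs))
plot(vec(xs), actual, label = "true pdf")
plot!(vec(xs), learned, label = "learned pdf")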
Lastly, I think a common use case for normalizing flows is approximating the likelihood function, cf. https://elifesciences.org/articles/77220. Extending the example above, this would allow a person to compute the density for an arbitrary value of x, mu, and sigma. In addition, there should be a way to store the solution so that it could be reused at a later time, and perhaps used from Turing.jl. To facilitate adoption of the package, it might be helpful to include an example in the documentation. Presumably this would entail generating training vectors consisting of [mu, sigma, x] instead of only x (roughly as sketched below), but I'm not entirely sure how to modify the example. I would be happy to help where I can, but I would require some guidance. Here is my full code:
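As a rough, hypothetical sketch of the [mu, sigma, x] training vectors mentioned above (this is not the full code referred to; the priors over mu and sigma are arbitrary placeholders, and the FFJORD input dimension would become 3):

using Distributions

# Hypothetical priors over mu and sigma; placeholders only.
mu_prior    = Uniform(4.0f0, 8.0f0)
sigma_prior = Uniform(0.3f0, 1.5f0)
n = 1000
mus    = Float32.(rand(mu_prior, n))
sigmas = Float32.(rand(sigma_prior, n))
xs     = Float32.([rand(Normal(m, s)) for (m, s) in zip(mus, sigmas)])
# Each column is one training vector [mu, sigma, x], so the flow models the joint density.
train_data = permutedims(hcat(mus, sigmas, xs))   # 3 × n matrix
# The first Dense layer of the dynamics network would then need 3 inputs (and the last 3 outputs).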
Hi,
As explained in this Discourse thread, the current normalizing flow example does not produce accurate results.
Chris Rackauckas suggested increasing the network size. I tried increasing the training set to 1,000 samples and using the following larger network, without success:
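For illustration only (this is not necessarily the exact architecture I tried), a larger network of this kind would be along the lines of:

# Illustrative only: wider hidden layers than the documentation example
nn = Flux.Chain(
    Flux.Dense(1, 32, tanh),
    Flux.Dense(32, 32, tanh),
    Flux.Dense(32, 1, tanh),
) |> f32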