Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests for single distributions #621

Merged
merged 15 commits into from
Jan 27, 2019
3 changes: 2 additions & 1 deletion src/inference/adapt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ function adapt!(tp::ThreePhaseAdapter, stats::Real, θ; adapt_ϵ=false, adapt_M=
tp.state.n += 1
if tp.state.n == tp.n_adapts
if adapt_ϵ
tp.ssa.state.ϵ = exp(tp.ssa.state.x_bar)
ϵ = exp(tp.ssa.state.x_bar)
tp.ssa.state.ϵ = min(one(ϵ), ϵ)
end
@info " Adapted ϵ = $(getss(tp)), std = $(string(tp.pc)); $(tp.state.n) iterations is used for adaption."
else
Expand Down
3 changes: 2 additions & 1 deletion src/inference/adapt/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ function adapt_stepsize!(da::DualAveraging, stats::Real)
if isnan(ϵ) || isinf(ϵ)
@warn "Incorrect ϵ = $ϵ; ϵ_previous = $(da.state.ϵ) is used instead."
else
da.state.ϵ = ϵ
ϵ < one(ϵ) && @warn "$ϵ exceeds 1.0; capped to 1.0 for numerical stability"
da.state.ϵ = min(one(ϵ), ϵ)
end
da.state.x_bar = x_bar
da.state.H_bar = H_bar
Expand Down
105 changes: 105 additions & 0 deletions test/models.jl/single_dist_correctness.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using Test, Turing
turnprogress(false)

n_samples = 20_000
mean_atol = 0.25
var_atol = 1.0
multi_dim = 10

# 1. UnivariateDistribution
# NOTE: Noncentral distributions are commented out because of
# AD imcompatibility of their logpdf functions
dist_uni = [Arcsine(1, 3),
Beta(2, 1),
# NoncentralBeta(2, 1, 1),
BetaPrime(1, 1),
Biweight(0, 1),
Chi(7),
Chisq(7),
# NoncentralChisq(7, 1),
Cosine(0, 1),
Epanechnikov(0, 1),
Erlang(2, 3),
Exponential(0.1),
FDist(7, 7),
# NoncentralF(7, 7, 1),
Frechet(2, 0.5),
Normal(0, 1),
GeneralizedExtremeValue(0, 1, 0.5),
GeneralizedPareto(0, 1, 0.5),
Gumbel(0, 0.5),
InverseGaussian(1, 1),
Kolmogorov(),
# KSDist(2), # no pdf function defined
# KSOneSided(2), # no pdf function defined
Laplace(0, 0.5),
Levy(0, 1),
Logistic(0, 1),
LogNormal(0, 1),
Gamma(2, 3),
InverseGamma(3, 1),
NormalCanon(0, 1),
NormalInverseGaussian(0, 2, 1, 1),
Pareto(1, 1),
Rayleigh(1),
SymTriangularDist(0, 1),
TDist(2.5),
# NoncentralT(2.5, 1),
TriangularDist(1, 3, 2),
Triweight(0, 1),
Uniform(0, 1),
# VonMises(0, 1), WARNING: this is commented are because the test is broken
Weibull(2, 1),
# Cauchy(0, 1), # mean and variance are undefined for Cauchy
]

# 2. MultivariateDistribution
dist_multi = [MvNormal(zeros(multi_dim), ones(multi_dim)),
MvNormal(zeros(2), [2 1; 1 4]),
Dirichlet(multi_dim, 2.0),
]

# 3. MatrixDistribution
dist_matrix = [Wishart(7, [1 0.5; 0.5 1]),
InverseWishart(7, [1 0.5; 0.5 1]),
]

@testset "Correctness test for single distributions" begin
for (dist_set, dist_list) ∈ [("UnivariateDistribution", dist_uni),
("MultivariateDistribution", dist_multi),
("MatrixDistribution", dist_matrix)
]
@testset "$(string(dist_set))" begin
for dist in dist_list
@testset "$(string(typeof(dist)))" begin
@info "Distribution(params)" dist

@model m() = begin
x ~ dist
end
chn = sample(m(), NUTS(n_samples, 0.8))

chn_xs = chn[:x][1:2:end] # thining by halving

# Mean
dist_mean = mean(dist)
if !all(isnan.(dist_mean)) && !all(isinf.(dist_mean))
chn_mean = mean(chn_xs)
@test chn_mean ≈ dist_mean atol=(mean_atol * length(chn_mean))
end

# var() for Distributions.MatrixDistribution is not defined
if !(dist isa Distributions.MatrixDistribution)
# Variance
dist_var = var(dist)
if !all(isnan.(dist_var)) && !all(isinf.(dist_var))
chn_var = var(chn_xs)
@test chn_var ≈ dist_var atol=(var_atol * length(chn_var))
end
end
end
end
end
end
end
# Wishart(7, [1 0.5; 0.5 1])
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ cd(path); include("utility.jl")
if get(ENV, "TRAVIS", "false") == "true"
# If Travis is testing, separate the tests.
numerical_tests = [joinpath("hmc.jl", "matrix_support.jl"),
joinpath("mh.jl", "mh_cons.jl")]
joinpath("mh.jl", "mh_cons.jl"),
joinpath("models.jl", "single_dist_correctness.jl")]

if ENV["STAGE"] == "test"
runtests(exclude = numerical_tests)
Expand Down