Skip to content

Commit

Permalink
Add more tests for single distributions (#621)
Browse files Browse the repository at this point in the history
* add uni tests

* force string conversion

* add more dists

* fix typo in runtest

* exclude var for MatrixDist

* improve script

* set a upper limit for the stepsize in adaptation

* cap end-of-window step size update

* improve type stability

* resolve Mohamed's comment

* add inv wishart test (suggested by Will)

* add more continous univariate dists

* fix incorrect normal inverse gaussian parameters

* add warning msg for capping ss to 1.0

* comment out VonMises distribution
  • Loading branch information
xukai92 authored Jan 27, 2019
1 parent ec47816 commit c8e40c5
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/inference/adapt/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ function adapt!(tp::ThreePhaseAdapter, stats::Real, θ; adapt_ϵ=false, adapt_M=
tp.state.n += 1
if tp.state.n == tp.n_adapts
if adapt_ϵ
tp.ssa.state.ϵ = exp(tp.ssa.state.x_bar)
ϵ = exp(tp.ssa.state.x_bar)
tp.ssa.state.ϵ = min(one(ϵ), ϵ)
end
@info " Adapted ϵ = $(getss(tp)), std = $(string(tp.pc)); $(tp.state.n) iterations is used for adaption."
else
Expand Down
3 changes: 2 additions & 1 deletion src/inference/adapt/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ function adapt_stepsize!(da::DualAveraging, stats::Real)
if isnan(ϵ) || isinf(ϵ)
@warn "Incorrect ϵ = ; ϵ_previous = $(da.state.ϵ) is used instead."
else
da.state.ϵ = ϵ
ϵ < one(ϵ) && @warn " exceeds 1.0; capped to 1.0 for numerical stability"
da.state.ϵ = min(one(ϵ), ϵ)
end
da.state.x_bar = x_bar
da.state.H_bar = H_bar
Expand Down
105 changes: 105 additions & 0 deletions test/models.jl/single_dist_correctness.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using Test, Turing
turnprogress(false)

n_samples = 20_000
mean_atol = 0.25
var_atol = 1.0
multi_dim = 10

# 1. UnivariateDistribution
# NOTE: Noncentral distributions are commented out because of
# AD imcompatibility of their logpdf functions
dist_uni = [Arcsine(1, 3),
Beta(2, 1),
# NoncentralBeta(2, 1, 1),
BetaPrime(1, 1),
Biweight(0, 1),
Chi(7),
Chisq(7),
# NoncentralChisq(7, 1),
Cosine(0, 1),
Epanechnikov(0, 1),
Erlang(2, 3),
Exponential(0.1),
FDist(7, 7),
# NoncentralF(7, 7, 1),
Frechet(2, 0.5),
Normal(0, 1),
GeneralizedExtremeValue(0, 1, 0.5),
GeneralizedPareto(0, 1, 0.5),
Gumbel(0, 0.5),
InverseGaussian(1, 1),
Kolmogorov(),
# KSDist(2), # no pdf function defined
# KSOneSided(2), # no pdf function defined
Laplace(0, 0.5),
Levy(0, 1),
Logistic(0, 1),
LogNormal(0, 1),
Gamma(2, 3),
InverseGamma(3, 1),
NormalCanon(0, 1),
NormalInverseGaussian(0, 2, 1, 1),
Pareto(1, 1),
Rayleigh(1),
SymTriangularDist(0, 1),
TDist(2.5),
# NoncentralT(2.5, 1),
TriangularDist(1, 3, 2),
Triweight(0, 1),
Uniform(0, 1),
# VonMises(0, 1), WARNING: this is commented are because the test is broken
Weibull(2, 1),
# Cauchy(0, 1), # mean and variance are undefined for Cauchy
]

# 2. MultivariateDistribution
dist_multi = [MvNormal(zeros(multi_dim), ones(multi_dim)),
MvNormal(zeros(2), [2 1; 1 4]),
Dirichlet(multi_dim, 2.0),
]

# 3. MatrixDistribution
dist_matrix = [Wishart(7, [1 0.5; 0.5 1]),
InverseWishart(7, [1 0.5; 0.5 1]),
]

@testset "Correctness test for single distributions" begin
for (dist_set, dist_list) [("UnivariateDistribution", dist_uni),
("MultivariateDistribution", dist_multi),
("MatrixDistribution", dist_matrix)
]
@testset "$(string(dist_set))" begin
for dist in dist_list
@testset "$(string(typeof(dist)))" begin
@info "Distribution(params)" dist

@model m() = begin
x ~ dist
end
chn = sample(m(), NUTS(n_samples, 0.8))

chn_xs = chn[:x][1:2:end] # thining by halving

# Mean
dist_mean = mean(dist)
if !all(isnan.(dist_mean)) && !all(isinf.(dist_mean))
chn_mean = mean(chn_xs)
@test chn_mean dist_mean atol=(mean_atol * length(chn_mean))
end

# var() for Distributions.MatrixDistribution is not defined
if !(dist isa Distributions.MatrixDistribution)
# Variance
dist_var = var(dist)
if !all(isnan.(dist_var)) && !all(isinf.(dist_var))
chn_var = var(chn_xs)
@test chn_var dist_var atol=(var_atol * length(chn_var))
end
end
end
end
end
end
end
# Wishart(7, [1 0.5; 0.5 1])
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ cd(path); include("utility.jl")
if get(ENV, "TRAVIS", "false") == "true"
# If Travis is testing, separate the tests.
numerical_tests = [joinpath("hmc.jl", "matrix_support.jl"),
joinpath("mh.jl", "mh_cons.jl")]
joinpath("mh.jl", "mh_cons.jl"),
joinpath("models.jl", "single_dist_correctness.jl")]

if ENV["STAGE"] == "test"
runtests(exclude = numerical_tests)
Expand Down

0 comments on commit c8e40c5

Please sign in to comment.