Skip to content

Commit

Permalink
Move test utilities to an extension (#1791)
Browse files Browse the repository at this point in the history
* Move test utilities to an extension

* Fix signature and docstring

* Also qualify AbstractRNG

* Fix Julia < 1.9

* Fix for 1.3?

* Simplify the TestUtils stub
  • Loading branch information
devmotion authored Nov 3, 2023
1 parent e407fa5 commit e666d74
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 94 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.102"
version = "0.25.103"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -22,10 +22,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Expand Down
101 changes: 101 additions & 0 deletions ext/DistributionsTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module DistributionsTestExt

using Distributions
using Distributions.LinearAlgebra
using Distributions.Random
using Test

__rand(::Nothing, args...) = rand(args...)
__rand(rng::AbstractRNG, args...) = rand(rng, args...)

__rand!(::Nothing, args...) = rand!(args...)
__rand!(rng::AbstractRNG, args...) = rand!(rng, args...)

"""
test_mvnormal(
g::AbstractMvNormal,
n_tsamples::Int=10^6,
rng::Union{Random.AbstractRNG, Nothing}=nothing,
)
Test that `AbstractMvNormal` implements the expected API.
!!! Note
On Julia >= 1.9, you have to load the `Test` standard library to be able to use
this function.
"""
function Distributions.TestUtils.test_mvnormal(
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::Union{AbstractRNG, Nothing}=nothing
)
d = length(g)
μ = mean(g)
Σ = cov(g)
@test length(μ) == d
@test size(Σ) == (d, d)
@test var(g) diag(Σ)
@test entropy(g) 0.5 * logdet(2π ** Σ)
ldcov = logdetcov(g)
@test ldcov logdet(Σ)
vs = diag(Σ)
@test g == typeof(g)(params(g)...)
@test g == deepcopy(g)
@test minimum(g) == fill(-Inf, d)
@test maximum(g) == fill(Inf, d)
@test extrema(g) == (minimum(g), maximum(g))
@test isless(extrema(g)...)

# test sampling for AbstractMatrix (here, a SubArray):
subX = view(__rand(rng, d, 2d), :, 1:d)
@test isa(__rand!(rng, g, subX), SubArray)

# sampling
@test isa(__rand(rng, g), Vector{Float64})
X = __rand(rng, g, n_tsamples)
emp_mu = vec(mean(X, dims=2))
Z = X .- emp_mu
emp_cov = (Z * Z') * inv(n_tsamples)

mean_atols = 8 .* sqrt.(vs ./ n_tsamples)
cov_atols = 10 .* sqrt.(vs .* vs') ./ sqrt.(n_tsamples)
for i = 1:d
@test isapprox(emp_mu[i], μ[i], atol=mean_atols[i])
end
for i = 1:d, j = 1:d
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
end

X = rand(MersenneTwister(14), g, n_tsamples)
Y = rand(MersenneTwister(14), g, n_tsamples)
@test X == Y
emp_mu = vec(mean(X, dims=2))
Z = X .- emp_mu
emp_cov = (Z * Z') * inv(n_tsamples)
for i = 1:d
@test isapprox(emp_mu[i] , μ[i] , atol=mean_atols[i])
end
for i = 1:d, j = 1:d
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
end


# evaluation of sqmahal & logpdf
U = X .- μ
sqm = vec(sum(U .*\ U), dims=1))
for i = 1:min(100, n_tsamples)
@test sqmahal(g, X[:,i]) sqm[i]
end
@test sqmahal(g, X) sqm

lp = -0.5 .* sqm .- 0.5 * (d * log(2.0 * pi) + ldcov)
for i = 1:min(100, n_tsamples)
@test logpdf(g, X[:,i]) lp[i]
end
@test logpdf(g, X) lp

# log likelihood
@test loglikelihood(g, X) sum(i -> Distributions._logpdf(g, X[:,i]), 1:n_tsamples)
@test loglikelihood(g, X[:, 1]) logpdf(g, X[:, 1])
@test loglikelihood(g, [X[:, i] for i in axes(X, 2)]) loglikelihood(g, X)
end

end # module
7 changes: 4 additions & 3 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,16 @@ include("mixtures/unigmm.jl")
# Interface for StatsAPI
include("statsapi.jl")

# Testing utilities for other packages which implement distributions.
include("test_utils.jl")

# Extensions: Implementation of DensityInterface and ChainRulesCore API
if !isdefined(Base, :get_extension)
include("../ext/DistributionsChainRulesCoreExt/DistributionsChainRulesCoreExt.jl")
include("../ext/DistributionsDensityInterfaceExt.jl")
include("../ext/DistributionsTestExt.jl")
end

# Testing utilities for other packages which implement distributions.
include("test_utils.jl")

include("deprecates.jl")

"""
Expand Down
103 changes: 13 additions & 90 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,19 @@
module TestUtils

using Distributions
using LinearAlgebra
using Random
using Test


__rand(::Nothing, args...) = rand(args...)
__rand(rng::AbstractRNG, args...) = rand(rng, args...)

__rand!(::Nothing, args...) = rand!(args...)
__rand!(rng::AbstractRNG, args...) = rand!(rng, args...)

"""
test_mvnormal(
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::AbstractRNG=Random.default_rng()
)
Test that `AbstractMvNormal` implements the expected API.
"""
function test_mvnormal(
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::Union{AbstractRNG, Nothing}=nothing
)
d = length(g)
μ = mean(g)
Σ = cov(g)
@test length(μ) == d
@test size(Σ) == (d, d)
@test var(g) diag(Σ)
@test entropy(g) 0.5 * logdet(2π ** Σ)
ldcov = logdetcov(g)
@test ldcov logdet(Σ)
vs = diag(Σ)
@test g == typeof(g)(params(g)...)
@test g == deepcopy(g)
@test minimum(g) == fill(-Inf, d)
@test maximum(g) == fill(Inf, d)
@test extrema(g) == (minimum(g), maximum(g))
@test isless(extrema(g)...)

# test sampling for AbstractMatrix (here, a SubArray):
subX = view(__rand(rng, d, 2d), :, 1:d)
@test isa(__rand!(rng, g, subX), SubArray)

# sampling
@test isa(__rand(rng, g), Vector{Float64})
X = __rand(rng, g, n_tsamples)
emp_mu = vec(mean(X, dims=2))
Z = X .- emp_mu
emp_cov = (Z * Z') * inv(n_tsamples)

mean_atols = 8 .* sqrt.(vs ./ n_tsamples)
cov_atols = 10 .* sqrt.(vs .* vs') ./ sqrt.(n_tsamples)
for i = 1:d
@test isapprox(emp_mu[i], μ[i], atol=mean_atols[i])
end
for i = 1:d, j = 1:d
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
import ..Distributions

function test_mvnormal end

if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
function __init__()
# Better error message if users forget to load Test
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
if exc.f === test_mvnormal &&
(Base.get_extension(Distributions, :DistributionsTestExt) === nothing)
print(io, "\nDid you forget to load Test?")
end
end
end

X = rand(MersenneTwister(14), g, n_tsamples)
Y = rand(MersenneTwister(14), g, n_tsamples)
@test X == Y
emp_mu = vec(mean(X, dims=2))
Z = X .- emp_mu
emp_cov = (Z * Z') * inv(n_tsamples)
for i = 1:d
@test isapprox(emp_mu[i] , μ[i] , atol=mean_atols[i])
end
for i = 1:d, j = 1:d
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
end


# evaluation of sqmahal & logpdf
U = X .- μ
sqm = vec(sum(U .*\ U), dims=1))
for i = 1:min(100, n_tsamples)
@test sqmahal(g, X[:,i]) sqm[i]
end
@test sqmahal(g, X) sqm

lp = -0.5 .* sqm .- 0.5 * (d * log(2.0 * pi) + ldcov)
for i = 1:min(100, n_tsamples)
@test logpdf(g, X[:,i]) lp[i]
end
@test logpdf(g, X) lp

# log likelihood
@test loglikelihood(g, X) sum(i -> Distributions._logpdf(g, X[:,i]), 1:n_tsamples)
@test loglikelihood(g, X[:, 1]) logpdf(g, X[:, 1])
@test loglikelihood(g, [X[:, i] for i in axes(X, 2)]) loglikelihood(g, X)
end

end

0 comments on commit e666d74

Please sign in to comment.