diff --git a/Project.toml b/Project.toml
index 7f1a5c684..f2c7ae722 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Distributions"
 uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
 authors = ["JuliaStats"]
-version = "0.25.89"
+version = "0.25.90"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -15,6 +15,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -26,6 +27,7 @@ FillArrays = "0.9, 0.10, 0.11, 0.12, 0.13, 1"
 PDMats = "0.10, 0.11"
 QuadGK = "2"
 SpecialFunctions = "1.2, 2"
+StatsAPI = "1.6"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.9.15, 1"
 julia = "1.3"
diff --git a/src/Distributions.jl b/src/Distributions.jl
index 0421ddbb0..08c174f0c 100644
--- a/src/Distributions.jl
+++ b/src/Distributions.jl
@@ -17,6 +17,7 @@ using Random
 import Random: default_rng, rand!, SamplerRangeInt
 
 import Statistics: mean, median, quantile, std, var, cov, cor
+import StatsAPI
 import StatsBase: kurtosis, skewness, entropy, mode, modes,
                   fit, kldivergence, loglikelihood, dof, span,
                   params, params!
@@ -306,6 +307,9 @@ include("pdfnorm.jl")
 include("mixtures/mixturemodel.jl")
 include("mixtures/unigmm.jl")
 
+# Interface for StatsAPI
+include("statsapi.jl")
+
 # Extensions: Implementation of DensityInterface and ChainRulesCore API
 if !isdefined(Base, :get_extension)
     include("../ext/DistributionsChainRulesCoreExt/DistributionsChainRulesCoreExt.jl")
diff --git a/src/statsapi.jl b/src/statsapi.jl
new file mode 100644
--- /dev/null
+++ b/src/statsapi.jl
@@ -0,0 +1,56 @@
+# Validate the `tail` keyword shared by the `pvalue` methods below.
+function _check_tail(tail::Symbol)
+    if tail !== :both && tail !== :left && tail !== :right
+        throw(ArgumentError("`tail=$(repr(tail))` is invalid"))
+    end
+    return nothing
+end
+
+"""
+    pvalue(dist::DiscreteUnivariateDistribution, x::Number; tail::Symbol=:both)
+
+Compute the p-value of the observation `x` under the discrete distribution `dist`.
+
+With `tail=:left` the result is `cdf(dist, x)`, with `tail=:right` it is
+`ccdf(dist, x - 1)`, and with `tail=:both` (the default) it is twice the smaller of
+these two values, capped at one.
+
+The right tail `ccdf(dist, x - 1)` equals `P(X >= x)` when `dist` is supported on the
+integers and `x` is an integer.
+
+An `ArgumentError` is thrown if `tail` is not `:both`, `:left`, or `:right`.
+"""
+function StatsAPI.pvalue(dist::DiscreteUnivariateDistribution, x::Number; tail::Symbol=:both)
+    _check_tail(tail)
+    if tail === :both
+        p = 2 * min(ccdf(dist, x-1), cdf(dist, x))
+        return min(p, oneunit(p)) # if P(X = x) > 0, then possibly p > 1
+    elseif tail === :left
+        return cdf(dist, x)
+    else # tail === :right
+        return ccdf(dist, x-1)
+    end
+end
+
+"""
+    pvalue(dist::ContinuousUnivariateDistribution, x::Number; tail::Symbol=:both)
+
+Compute the p-value of the observation `x` under the continuous distribution `dist`.
+
+With `tail=:left` the result is `cdf(dist, x)`, with `tail=:right` it is
+`ccdf(dist, x)`, and with `tail=:both` (the default) it is twice the smaller of these
+two values, capped at one.
+
+An `ArgumentError` is thrown if `tail` is not `:both`, `:left`, or `:right`.
+"""
+function StatsAPI.pvalue(dist::ContinuousUnivariateDistribution, x::Number; tail::Symbol=:both)
+    _check_tail(tail)
+    if tail === :both
+        p = 2 * min(cdf(dist, x), ccdf(dist, x))
+        return min(p, oneunit(p)) # cdf + ccdf = 1, so only rounding can push p above 1
+    elseif tail === :left
+        return cdf(dist, x)
+    else # tail === :right
+        return ccdf(dist, x)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 8d38c0abc..7eb902066 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -90,6 +90,7 @@ const tests = [
     "multivariate/product",
     "eachvariate",
     "univariate/continuous/triangular",
+    "statsapi",
 
     ### missing files compared to /src:
     # "common",
diff --git a/test/statsapi.jl b/test/statsapi.jl
new file mode 100644
--- /dev/null
+++ b/test/statsapi.jl
@@ -0,0 +1,35 @@
+using Distributions
+using StatsAPI: pvalue
+
+using Random
+using Test
+
+@testset "pvalue" begin
+    # Fixed seed so that a failure can be reproduced
+    rng = MersenneTwister(1234)
+    n = 1_000_000
+
+    # Five standard errors of an empirical frequency that estimates `p` from `n` draws
+    mc_atol(p) = 5 * sqrt(max(p * (1 - p), zero(p)) / n)
+
+    # For two discrete and two continuous distributions
+    for dist in (Binomial(10, 0.3), Poisson(0.3), Normal(1.4, 2.1), Gamma(1.9, 0.8))
+        # Draw sample
+        x = rand(rng, dist)
+
+        # Draw 10^6 additional samples
+        ys = rand(rng, dist, n)
+
+        # Check that empirical frequencies match pvalues of left/right tail approximately
+        pleft = pvalue(dist, x; tail=:left)
+        pright = pvalue(dist, x; tail=:right)
+        @test pleft ≈ mean(≤(x), ys) atol=mc_atol(pleft)
+        @test pright ≈ mean(≥(x), ys) atol=mc_atol(pright)
+
+        # Check consistency of pvalues of both tails
+        @test pvalue(dist, x; tail=:both) == min(1, 2 * min(pleft, pright))
+
+        # Incorrect value for keyword argument
+        @test_throws ArgumentError("`tail=:l` is invalid") pvalue(dist, x; tail=:l)
+    end
+end