From 18f1743aa324027a8e8b93f1d689f4112a846eb5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 8 May 2023 16:03:05 +0200 Subject: [PATCH 1/4] Implement `StatsAPI.pvalue` --- Project.toml | 4 +++- src/Distributions.jl | 4 ++++ src/statsapi.jl | 29 +++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 src/statsapi.jl diff --git a/Project.toml b/Project.toml index 7f1a5c6845..f2c7ae7228 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.89" +version = "0.25.90" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -15,6 +15,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -26,6 +27,7 @@ FillArrays = "0.9, 0.10, 0.11, 0.12, 0.13, 1" PDMats = "0.10, 0.11" QuadGK = "2" SpecialFunctions = "1.2, 2" +StatsAPI = "1.6" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.9.15, 1" julia = "1.3" diff --git a/src/Distributions.jl b/src/Distributions.jl index 0421ddbb0f..08c174f0ce 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -17,6 +17,7 @@ using Random import Random: default_rng, rand!, SamplerRangeInt import Statistics: mean, median, quantile, std, var, cov, cor +import StatsAPI import StatsBase: kurtosis, skewness, entropy, mode, modes, fit, kldivergence, loglikelihood, dof, span, params, params! @@ -306,6 +307,9 @@ include("pdfnorm.jl") include("mixtures/mixturemodel.jl") include("mixtures/unigmm.jl") +# Interface for StatsAPI +include("statsapi.jl") + # Extensions: Implementation of DensityInterface and ChainRulesCore API if !isdefined(Base, :get_extension) include("../ext/DistributionsChainRulesCoreExt/DistributionsChainRulesCoreExt.jl") diff --git a/src/statsapi.jl b/src/statsapi.jl new file mode 100644 index 0000000000..8a0e8bc0cb --- /dev/null +++ b/src/statsapi.jl @@ -0,0 +1,29 @@ +function _check_tail(tail::Symbol) + if tail !== :both && tail !== :left && tail !== :right + throw(ArgumentError("tail=$(tail) is invalid")) + end +end + +function StatsAPI.pvalue(dist::DiscreteUnivariateDistribution, x::Number; tail::Symbol=:both) + _check_tail(tail) + if tail === :both + p = 2 * min(ccdf(dist, x-1), cdf(dist, x)) + min(p, oneunit(p)) # if P(X = x) > 0, then possibly p > 1 + elseif tail === :left + cdf(dist, x) + else # tail === :right + ccdf(dist, x-1) + end +end + +function StatsAPI.pvalue(dist::ContinuousUnivariateDistribution, x::Number; tail::Symbol=:both) + _check_tail(tail) + if tail === :both + p = 2 * min(cdf(dist, x), ccdf(dist, x)) + min(p, oneunit(p)) # if P(X = x) > 0, then possibly p > 1 + elseif tail === :left + cdf(dist, x) + else # tail === :right + ccdf(dist, x) + end +end \ No newline at end of file From 3a9dc604aa54d1033c4b11dc3a9b54dcfe015daf Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 May 2023 13:11:40 +0200 Subject: [PATCH 2/4] Add tests --- src/statsapi.jl | 4 ++-- test/runtests.jl | 1 + test/statsapi.jl | 26 ++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 test/statsapi.jl diff --git a/src/statsapi.jl b/src/statsapi.jl index 8a0e8bc0cb..18112b3dc8 100644 --- a/src/statsapi.jl +++ b/src/statsapi.jl @@ -1,6 +1,6 @@ function _check_tail(tail::Symbol) if tail !== :both && tail !== :left && tail !== :right - throw(ArgumentError("tail=$(tail) is invalid")) + throw(ArgumentError("`tail=$(repr(tail))` is invalid")) end end @@ -26,4 +26,4 @@ function StatsAPI.pvalue(dist::ContinuousUnivariateDistribution, x::Number; tail else # tail === :right ccdf(dist, x) end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 8d38c0abcc..e62ffe5a71 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,6 +90,7 @@ const tests = [ "multivariate/product", "eachvariate", "univariate/continuous/triangular", + "statsapi.jl", ### missing files compared to /src: # "common", diff --git a/test/statsapi.jl b/test/statsapi.jl new file mode 100644 index 0000000000..3cf9daed4f --- /dev/null +++ b/test/statsapi.jl @@ -0,0 +1,26 @@ +using Distributions +using StatsAPI: pvalue + +using Test + +@testset "pvalue" begin + # For two discrete and two continuous distribution + for dist in (Binomial(10, 0.3), Poisson(0.3), Normal(1.4, 2.1), Gamma(1.9, 0.8)) + # Draw sample + x = rand(dist) + + # Draw 10^6 additional samples + ys = rand(dist, 1_000_000) + + # Check that empirical frequencies match pvalues of left/right tail approximately + @test pvalue(dist, x; tail=:left) ≈ count(y -> y ≤ x, ys) / length(ys) rtol=5e-3 + @test pvalue(dist, x; tail=:right) ≈ count(y -> y ≥ x, ys) / length(ys) rtol=5e-3 + + # Check consistency of pvalues of both tails + @test pvalue(dist, x; tail=:both) == + min(1, 2 * min(pvalue(dist, x; tail=:left), pvalue(dist, x; tail=:right))) + + # Incorrect value for keyword argument + @test_throws ArgumentError("`tail=:l` is invalid") pvalue(dist, x; tail=:l) + end +end From 9e363ca83dc983ff1fcd35419c69d93c37ecf89d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 May 2023 15:03:24 +0200 Subject: [PATCH 3/4] Update test/runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index e62ffe5a71..7eb9020667 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,7 +90,7 @@ const tests = [ "multivariate/product", "eachvariate", "univariate/continuous/triangular", - "statsapi.jl", + "statsapi", ### missing files compared to /src: # "common", From d73145d874a005548c0b294f248230886198fae1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 9 May 2023 23:25:49 +0200 Subject: [PATCH 4/4] Update test/statsapi.jl Co-authored-by: Alex Arslan --- test/statsapi.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/statsapi.jl b/test/statsapi.jl index 3cf9daed4f..37745c5223 100644 --- a/test/statsapi.jl +++ b/test/statsapi.jl @@ -13,8 +13,8 @@ using Test ys = rand(dist, 1_000_000) # Check that empirical frequencies match pvalues of left/right tail approximately - @test pvalue(dist, x; tail=:left) ≈ count(y -> y ≤ x, ys) / length(ys) rtol=5e-3 - @test pvalue(dist, x; tail=:right) ≈ count(y -> y ≥ x, ys) / length(ys) rtol=5e-3 + @test pvalue(dist, x; tail=:left) ≈ mean(≤(x), ys) rtol=5e-3 + @test pvalue(dist, x; tail=:right) ≈ mean(≥(x), ys) rtol=5e-3 # Check consistency of pvalues of both tails @test pvalue(dist, x; tail=:both) ==