From 003db0dc04404d1e8f53b4fdb9f76c229e784969 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 28 Jan 2021 10:33:45 +0100 Subject: [PATCH] Fix entropy of `Dirichlet` --- Project.toml | 2 +- src/multivariate/dirichlet.jl | 4 +++- test/dirichlet.jl | 9 +++++++++ test/runtests.jl | 1 + 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1c4944d76..371373a60 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.24.12" +version = "0.24.13" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 1a5510511..28f5dd3cd 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -104,7 +104,9 @@ end function entropy(d::Dirichlet) α0 = d.alpha0 - en = d.lmnB + (α0 - k) * digamma(α0) - sum(αj -> (αj - 1) * digamma(αj), d.alpha) + α = d.alpha + k = length(d.alpha) + en = d.lmnB + (α0 - k) * digamma(α0) - sum(αj -> (αj - 1) * digamma(αj), α) return en end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 530dd4ab6..d8b70a6db 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -113,3 +113,12 @@ rng = MersenneTwister(123) # @test r.alpha ≈ d.alpha atol=0.25 end end + +@testset "Dirichlet: entropy" begin + α = exp.(rand(2)) + @test entropy(Dirichlet(α)) ≈ entropy(Beta(α...)) + + N = 10 + @test entropy(Dirichlet(N, 1)) ≈ -loggamma(N) + @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N) +end diff --git a/test/runtests.jl b/test/runtests.jl index f59e1ece2..1a2af85a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using PDMats # test dependencies using Test using Distributed using Random +using SpecialFunctions using StatsBase using LinearAlgebra using HypothesisTests