diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index a1b87d351..059b68a2d 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -170,11 +170,14 @@ function insupport(d::Dirichlet, x::AbstractVector{T}) where T<:Real return true end -function _logpdf(d::Dirichlet, x::AbstractVector{T}) where T<:Real +function _logpdf(d::Dirichlet{S}, x::AbstractVector{T}) where {S, T<:Real} + if !insupport(d, x) + return convert(promote_type(S, T), -Inf) + end a = d.alpha - s = 0. + s = zero(promote_type(S, T)) for i in 1:length(a) - @inbounds s += (a[i] - 1.0) * log(x[i]) + @inbounds s += xlogy(a[i] - one(S), x[i]) end return s - d.lmnB end diff --git a/test/dirichlet.jl b/test/dirichlet.jl index 5034098e4..098a10a9f 100644 --- a/test/dirichlet.jl +++ b/test/dirichlet.jl @@ -22,6 +22,12 @@ d = Dirichlet(3, 2.0) @test cov(d) ≈ [8 -4 -4; -4 8 -4; -4 -4 8] / (36 * 7) @test var(d) ≈ diag(cov(d)) +@test pdf(Dirichlet([1, 1]), [0, 1]) ≈ 1.0 +@test pdf(Dirichlet([1f0, 1f0]), [0f0, 1f0]) ≈ 1.0f0 +@test typeof(pdf(Dirichlet([1f0, 1f0]), [0f0, 1f0])) == Float32 + +@test pdf(d, [-1, 1, 0]) ≈ 0.0 +@test pdf(d, [0, 0, 1]) ≈ 0.0 @test pdf(d, [0.2, 0.3, 0.5]) ≈ 3.6 @test pdf(d, [0.4, 0.5, 0.1]) ≈ 2.4 @test logpdf(d, [0.2, 0.3, 0.5]) ≈ log(3.6)