Skip to content

Commit

Permalink
Fix #1241; correct the behavior of logpdf(d::Dirichlet, x) on values …
Browse files Browse the repository at this point in the history
…at the boundary of and outside of the support, and replace hard-coded Float64 values with more general types (#1242)
  • Loading branch information
yurivish authored Dec 20, 2020
1 parent 530b9e0 commit 2892421
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,14 @@ function insupport(d::Dirichlet, x::AbstractVector{T}) where T<:Real
return true
end

function _logpdf(d::Dirichlet, x::AbstractVector{T}) where T<:Real
function _logpdf(d::Dirichlet{S}, x::AbstractVector{T}) where {S, T<:Real}
if !insupport(d, x)
return convert(promote_type(S, T), -Inf)
end
a = d.alpha
s = 0.
s = zero(promote_type(S, T))
for i in 1:length(a)
@inbounds s += (a[i] - 1.0) * log(x[i])
@inbounds s += xlogy(a[i] - one(S), x[i])
end
return s - d.lmnB
end
Expand Down
6 changes: 6 additions & 0 deletions test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ d = Dirichlet(3, 2.0)
@test cov(d) [8 -4 -4; -4 8 -4; -4 -4 8] / (36 * 7)
@test var(d) diag(cov(d))

@test pdf(Dirichlet([1, 1]), [0, 1]) 1.0
@test pdf(Dirichlet([1f0, 1f0]), [0f0, 1f0]) 1.0f0
@test typeof(pdf(Dirichlet([1f0, 1f0]), [0f0, 1f0])) == Float32

@test pdf(d, [-1, 1, 0]) 0.0
@test pdf(d, [0, 0, 1]) 0.0
@test pdf(d, [0.2, 0.3, 0.5]) 3.6
@test pdf(d, [0.4, 0.5, 0.1]) 2.4
@test logpdf(d, [0.2, 0.3, 0.5]) log(3.6)
Expand Down

0 comments on commit 2892421

Please sign in to comment.