Skip to content

Commit

Permalink
Allow general vectors in Dirichlet (#1243)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jan 12, 2021
1 parent 3a29bb1 commit ccebbd7
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 100 deletions.
165 changes: 66 additions & 99 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,46 @@ Dirichlet(alpha) # Dirichlet distribution with parameter vector alpha
Dirichlet(k, a) # Dirichlet distribution with parameter a * ones(k)
```
"""
struct Dirichlet{T<:Real} <: ContinuousMultivariateDistribution
alpha::Vector{T}
struct Dirichlet{T<:Real,Ts<:AbstractVector{T},S<:Real} <: ContinuousMultivariateDistribution
alpha::Ts
alpha0::T
lmnB::T

function Dirichlet{T}(alpha::Vector{T}) where T
alpha0::T = zero(T)
lmnB::T = zero(T)
for i in 1:length(alpha)
ai = alpha[i]
ai > 0 ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
alpha0 += ai
lmnB += loggamma(ai)
end
lmnB -= loggamma(alpha0)
new{T}(alpha, alpha0, lmnB)
end
lmnB::S

function Dirichlet{T}(d::Integer, alpha::T) where T
alpha0 = alpha * d
new{T}(fill(alpha, d), alpha0, loggamma(alpha) * d - loggamma(alpha0))
function Dirichlet{T}(alpha::AbstractVector{T}; check_args=true) where T
if check_args && !all(x -> x > zero(x), alpha)
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end
alpha0 = sum(alpha)
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
new{T,typeof(alpha),typeof(lmnB)}(alpha, alpha0, lmnB)
end
end

Dirichlet(alpha::Vector{T}) where {T<:Real} = Dirichlet{T}(alpha)
Dirichlet(d::Integer, alpha::T) where {T<:Real} = Dirichlet{T}(d, alpha)
Dirichlet(alpha::Vector{T}) where {T<:Integer} =
Dirichlet{Float64}(convert(Vector{Float64},alpha))
Dirichlet(d::Integer, alpha::Integer) = Dirichlet{Float64}(d, Float64(alpha))
function Dirichlet(alpha::AbstractVector{<:Real}; check_args=true)
Dirichlet{eltype(alpha)}(alpha; check_args=check_args)
end
Dirichlet(d::Integer, alpha::Real; kwargs...) = Dirichlet(Fill(alpha, d); kwargs...)

struct DirichletCanon
alpha::Vector{Float64}
struct DirichletCanon{T<:Real,Ts<:AbstractVector{T}}
alpha::Ts
end

length(d::DirichletCanon) = length(d.alpha)

Base.eltype(::Type{Dirichlet{T}}) where {T} = T
Base.eltype(::Type{<:Dirichlet{T}}) where {T} = T

#### Conversions
convert(::Type{Dirichlet{Float64}}, cf::DirichletCanon) = Dirichlet(cf.alpha)
convert(::Type{Dirichlet{T}}, alpha::Vector{S}) where {T<:Real, S<:Real} =
Dirichlet(convert(Vector{T}, alpha))
convert(::Type{Dirichlet{T}}, d::Dirichlet{S}) where {T<:Real, S<:Real} =
Dirichlet(convert(Vector{T}, d.alpha))


convert(::Type{Dirichlet{T}}, cf::DirichletCanon) where {T<:Real} =
Dirichlet(convert(AbstractVector{T}, cf.alpha))
convert(::Type{Dirichlet{T}}, alpha::AbstractVector{<:Real}) where {T<:Real} =
Dirichlet(convert(AbstractVector{T}, alpha))
convert(::Type{Dirichlet{T}}, d::Dirichlet{<:Real}) where {T<:Real} =
Dirichlet(convert(AbstractVector{T}, d.alpha))

convert(::Type{Dirichlet{T}}, cf::DirichletCanon{T}) where {T<:Real} = Dirichlet(cf.alpha)
convert(::Type{Dirichlet{T}}, alpha::AbstractVector{T}) where {T<:Real} =
Dirichlet(alpha)
convert(::Type{Dirichlet{T}}, d::Dirichlet{T}) where {T<:Real} = d

Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))

Expand All @@ -78,70 +71,58 @@ params(d::Dirichlet) = (d.alpha,)
@inline partype(d::Dirichlet{T}) where {T<:Real} = T

function var(d::Dirichlet)
α = d.alpha
α0 = d.alpha0
c = 1.0 / (α0 * α0 * (α0 + 1.0))

k = length(α)
v = Vector{Float64}(undef, k)
for i = 1:k
@inbounds αi = α[i]
@inbounds v[i] = αi * (α0 - αi) * c
c = inv(α0^2 * (α0 + 1))
v = map(d.alpha) do αi
αi * (α0 - αi) * c
end
return v
end

function cov(d::Dirichlet)
α = d.alpha
α0 = d.alpha0
c = 1.0 / (α0 * α0 * (α0 + 1.0))
c = inv(α0^2 * (α0 + 1))

T = typeof(zero(eltype(α))^2 * c)
k = length(α)
C = Matrix{Float64}(undef, k, k)

C = Matrix{T}(undef, k, k)
for j = 1:k
αj = α[j]
αjc = αj * c
for i = 1:j-1
for i in 1:(j-1)
@inbounds C[i,j] = C[j,i]
end
@inbounds C[j,j] = (α0 - αj) * αjc
for i in (j+1):k
@inbounds C[i,j] = - α[i] * αjc
end
@inbounds C[j,j] = αj * (α0 - αj) * c
end

for j = 1:k-1, i = j+1:k
@inbounds C[i,j] = C[j,i]
end
return C
end

function entropy(d::Dirichlet)
α = d.alpha
α0 = d.alpha0
k = length(α)

en = d.lmnB + (α0 - k) * digamma(α0)
for j in 1:k
@inbounds αj = α[j]
en -= (αj - 1.0) * digamma(αj)
end
en = d.lmnB + (α0 - k) * digamma(α0) - sum(αj -> (αj - 1) * digamma(αj), d.alpha)
return en
end


function dirichlet_mode!(r::Vector{T}, α::Vector{T}, α0::T) where T <: Real
function dirichlet_mode!(r::AbstractVector{<:Real}, α::AbstractVector{<:Real}, α0::Real)
k = length(α)
s = α0 - k
for i = 1:k
@inbounds αi = α[i]
if αi <= one(T)
error("Dirichlet has a mode only when alpha[i] > 1 for all i" )
end
@inbounds r[i] = (αi - one(T)) / s
end
inv_s = inv(α0 - k)
@. r = inv_s *- 1)
return r
end

dirichlet_mode::Vector{T}, α0::T) where {T <: Real} = dirichlet_mode!(Vector{T}(undef, length(α)), α, α0)
function dirichlet_mode::AbstractVector{<:Real}, α0::Real)
all(αi < 1 for αi in α) || error("Dirichlet has a mode only when alpha[i] > 1 for all i")
inv_s = inv(α0 - length(α))
r = map(α) do αi
inv_s * (αi - 1)
end
return r
end

mode(d::Dirichlet) = dirichlet_mode(d.alpha, d.alpha0)
mode(d::DirichletCanon) = dirichlet_mode(d.alpha, sum(d.alpha))
Expand All @@ -151,34 +132,16 @@ modes(d::Dirichlet) = [mode(d)]

# Evaluation

function insupport(d::Dirichlet, x::AbstractVector{T}) where T<:Real
n = length(x)
if length(d.alpha) != n
return false
end
s = 0.0
for i in 1:n
xi = x[i]
if xi < 0.0
return false
end
s += xi
end
if abs(s - 1.0) > 1e-8
return false
end
return true
function insupport(d::Dirichlet, x::AbstractVector{<:Real})
return length(d) == length(x) && !any(x -> x < zero(x), x) && sum(x) 1
end

function _logpdf(d::Dirichlet{S}, x::AbstractVector{T}) where {S, T<:Real}
function _logpdf(d::Dirichlet, x::AbstractVector{<:Real})
if !insupport(d, x)
return convert(promote_type(S, T), -Inf)
return xlogy(one(eltype(d.alpha)), zero(eltype(x))) - d.lmnB
end
a = d.alpha
s = zero(promote_type(S, T))
for i in 1:length(a)
@inbounds s += xlogy(a[i] - one(S), x[i])
end
s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(d.alpha, x))
return s - d.lmnB
end

Expand All @@ -187,13 +150,17 @@ end
function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
s = 0.0
n = length(x)
α = d.alpha
for i in 1:n
@inbounds s += (x[i] = rand(rng, Gamma(α[i])))
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
end
multiply!(x, inv(s)) # this returns x
multiply!(x, inv(sum(x))) # this returns x
end

function _rand!(rng::AbstractRNG,
d::Dirichlet{T,<:FillArrays.AbstractFill{T}},
x::AbstractVector{<:Real}) where {T<:Real}
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
multiply!(x, inv(sum(x))) # this returns x
end

#######################################
Expand Down
2 changes: 1 addition & 1 deletion test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ d = Dirichlet(v)
@test length(d) == length(v)
@test d.alpha == v
@test d.alpha0 == sum(v)
@test d == typeof(d)(params(d)...)
@test d == Dirichlet{eltype(d)}(params(d)...)
@test d == deepcopy(d)

@test mean(d) v / sum(v)
Expand Down

0 comments on commit ccebbd7

Please sign in to comment.