Skip to content

Commit

Permalink
Improve type inference of MixtureModel (#1308)
Browse files Browse the repository at this point in the history
* Improve type inference of `MixtureModel`

* Bump version

* Remove comment

* Bump minor version
  • Loading branch information
devmotion authored May 2, 2021
1 parent 5604316 commit 99b2b57
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 88 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.24.18"
version = "0.25.0"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
74 changes: 10 additions & 64 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ A mixture of distributions, parametrized on:
* `C` distribution family of the mixture
* `CT` the type for probabilities of the prior
"""
struct MixtureModel{VF<:VariateForm,VS<:ValueSupport,C<:Distribution,CT<:Real} <: AbstractMixtureModel{VF,VS,C}
struct MixtureModel{VF<:VariateForm,VS<:ValueSupport,C<:Distribution,CT<:Categorical} <: AbstractMixtureModel{VF,VS,C}
components::Vector{C}
prior::Categorical{CT}
prior::CT

function MixtureModel{VF,VS,C}(cs::Vector{C}, pri::Categorical{CT}) where {VF,VS,C,CT}
function MixtureModel{VF,VS,C}(cs::Vector{C}, pri::CT) where {VF,VS,C,CT}
length(cs) == ncategories(pri) ||
error("The number of components does not match the length of prior.")
new{VF,VS,C,CT}(cs, pri)
Expand Down Expand Up @@ -171,16 +171,8 @@ minimum(d::MixtureModel) = minimum([minimum(dci) for dci in d.components])
maximum(d::MixtureModel) = maximum([maximum(dci) for dci in d.components])

function mean(d::UnivariateMixture)
K = ncomponents(d)
p = probs(d)
m = 0.0
for i = 1:K
pi = p[i]
if pi > 0.0
c = component(d, i)
m += mean(c) * pi
end
end
m = sum(pi * mean(component(d, i)) for (i, pi) in enumerate(p) if !iszero(pi))
return m
end

Expand Down Expand Up @@ -281,38 +273,22 @@ end
#### Evaluation

function insupport(d::AbstractMixtureModel, x::AbstractVector)
K = ncomponents(d)
p = probs(d)
@inbounds for i in eachindex(p)
pi = p[i]
if pi > 0.0 && insupport(component(d, i), x)
return true
end
end
return false
return any(insupport(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi))
end

function _cdf(d::UnivariateMixture, x::Real)
K = ncomponents(d)
p = probs(d)
r = 0.0
@inbounds for i in eachindex(p)
pi = p[i]
if pi > 0.0
c = component(d, i)
r += pi * cdf(c, x)
end
end
r = sum(pi * cdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi))
return r
end

cdf(d::UnivariateMixture{Continuous}, x::Real) = _cdf(d, x)
cdf(d::UnivariateMixture{Discrete}, x::Integer) = _cdf(d, x)

function _mixpdf1(d::AbstractMixtureModel, x)
ps = probs(d)
cs = components(d)
return sum((ps[i] > 0) * (ps[i] * pdf(cs[i], x)) for i in eachindex(ps))
p = probs(d)
return sum(pi * pdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi))
end

function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
Expand All @@ -335,39 +311,9 @@ function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
end

function _mixlogpdf1(d::AbstractMixtureModel, x)
# using the formula below for numerical stability
#
# logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x))
# = log(sum_i pri[i] * exp(logpdf(cs[i], x)))
# = log(sum_i exp(logpri[i] + logpdf(cs[i], x)))
# = m + log(sum_i exp(logpri[i] + logpdf(cs[i], x) - m))
#
# m is chosen to be the maximum of logpri[i] + logpdf(cs[i], x)
# such that the argument of exp is in a reasonable range
#

K = ncomponents(d)
p = probs(d)
lp = Vector{eltype(p)}(undef, K)
m = -Inf # m <- the maximum of log(p(cs[i], x)) + log(pri[i])
@inbounds for i in eachindex(p)
pi = p[i]
if pi > 0.0
# lp[i] <- log(p(cs[i], x)) + log(pri[i])
lp_i = logpdf(component(d, i), x) + log(pi)
lp[i] = lp_i
if lp_i > m
m = lp_i
end
end
end
v = 0.0
@inbounds for i = 1:K
if p[i] > 0.0
v += exp(lp[i] - m)
end
end
return m + log(v)
lp = logsumexp(log(pi) + logpdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi))
return lp
end

function _mixlogpdf!(r::AbstractArray, d::AbstractMixtureModel, x)
Expand Down
46 changes: 23 additions & 23 deletions test/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int,
end

K = ncomponents(g)
pr = probs(g)
pr = @inferred(probs(g))
@assert length(pr) == K

# mean
mu = 0.0
for k = 1:K
mu += pr[k] * mean(component(g, k))
end
@test mean(g) mu
@test @inferred(mean(g)) mu

# evaluation of cdf
cf = zeros(T, n)
Expand All @@ -38,7 +38,7 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int,
end

for i = 1:n
@test cdf(g, X[i]) cf[i]
@test @inferred(cdf(g, X[i])) cf[i]
end
@test cdf.(g, X) cf

Expand All @@ -58,16 +58,16 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int,
mix_lp0 = log.(mix_p0)

for i = 1:n
@test pdf(g, X[i]) mix_p0[i]
@test logpdf(g, X[i]) mix_lp0[i]
@test componentwise_pdf(g, X[i]) vec(P0[i,:])
@test componentwise_logpdf(g, X[i]) vec(LP0[i,:])
@test @inferred(pdf(g, X[i])) mix_p0[i]
@test @inferred(logpdf(g, X[i])) mix_lp0[i]
@test @inferred(componentwise_pdf(g, X[i])) vec(P0[i,:])
@test @inferred(componentwise_logpdf(g, X[i])) vec(LP0[i,:])
end

@test pdf.(g, X) mix_p0
@test logpdf.(g, X) mix_lp0
@test componentwise_pdf(g, X) P0
@test componentwise_logpdf(g, X) LP0
@test @inferred(map(Base.Fix1(pdf, g), X)) mix_p0
@test @inferred(map(Base.Fix1(logpdf, g), X)) mix_lp0
@test @inferred(componentwise_pdf(g, X)) P0
@test @inferred(componentwise_logpdf(g, X)) LP0

# sampling
if (T <: AbstractFloat)
Expand All @@ -94,15 +94,15 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int,
end

K = ncomponents(g)
pr = probs(g)
pr = @inferred(probs(g))
@assert length(pr) == K

# mean
mu = zeros(length(g))
for k = 1:K
mu .+= pr[k] .* mean(component(g, k))
end
@test mean(g) mu
@test @inferred(mean(g)) mu

# evaluation
P0 = zeros(n, K)
Expand All @@ -121,20 +121,20 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int,

for i = 1:n
x_i = X[:,i]
@test pdf(g, x_i) mix_p0[i]
@test logpdf(g, x_i) mix_lp0[i]
@test componentwise_pdf(g, x_i) vec(P0[i,:])
@test componentwise_logpdf(g, x_i) vec(LP0[i,:])
@test @inferred(pdf(g, x_i)) mix_p0[i]
@test @inferred(logpdf(g, x_i)) mix_lp0[i]
@test @inferred(componentwise_pdf(g, x_i)) vec(P0[i,:])
@test @inferred(componentwise_logpdf(g, x_i)) vec(LP0[i,:])
end
#=
@show g
@show size(X)
@show size(mix_p0)
=#
@test pdf(g, X) mix_p0
@test logpdf(g, X) mix_lp0
@test componentwise_pdf(g, X) P0
@test componentwise_logpdf(g, X) LP0
@test @inferred(pdf(g, X)) mix_p0
@test @inferred(logpdf(g, X)) mix_lp0
@test @inferred(componentwise_pdf(g, X)) P0
@test @inferred(componentwise_logpdf(g, X)) LP0

# sampling
if ismissing(rng)
Expand Down Expand Up @@ -172,8 +172,8 @@ end
"rand(rng, ...)" => MersenneTwister(123))

@testset "Testing UnivariateMixture" begin
g_u = MixtureModel(Normal, [(0.0, 1.0), (2.0, 1.0), (-4.0, 1.5)], [0.2, 0.5, 0.3])
@test isa(g_u, MixtureModel{Univariate, Continuous, Normal})
g_u = MixtureModel(Normal{Float64}, [(0.0, 1.0), (2.0, 1.0), (-4.0, 1.5)], [0.2, 0.5, 0.3])
@test isa(g_u, MixtureModel{Univariate,Continuous,<:Normal})
@test ncomponents(g_u) == 3
test_mixture(g_u, 1000, 10^6, rng)
test_params(g_u)
Expand Down

0 comments on commit 99b2b57

Please sign in to comment.