From 99b2b57faf21e34498ad443d6c572525d95986fc Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 2 May 2021 10:36:36 +0200 Subject: [PATCH] Improve type inference of `MixtureModel` (#1308) * Improve type inference of `MixtureModel` * Bump version * Remove comment * Bump minor version --- Project.toml | 2 +- src/mixtures/mixturemodel.jl | 74 +++++------------------------------- test/mixture.jl | 46 +++++++++++----------- 3 files changed, 34 insertions(+), 88 deletions(-) diff --git a/Project.toml b/Project.toml index d4e1d9c1d..16927f0dd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.24.18" +version = "0.25.0" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index cc7a8f5d4..93402c14d 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -19,11 +19,11 @@ A mixture of distributions, parametrized on: * `C` distribution family of the mixture * `CT` the type for probabilities of the prior """ -struct MixtureModel{VF<:VariateForm,VS<:ValueSupport,C<:Distribution,CT<:Real} <: AbstractMixtureModel{VF,VS,C} +struct MixtureModel{VF<:VariateForm,VS<:ValueSupport,C<:Distribution,CT<:Categorical} <: AbstractMixtureModel{VF,VS,C} components::Vector{C} - prior::Categorical{CT} + prior::CT - function MixtureModel{VF,VS,C}(cs::Vector{C}, pri::Categorical{CT}) where {VF,VS,C,CT} + function MixtureModel{VF,VS,C}(cs::Vector{C}, pri::CT) where {VF,VS,C,CT} length(cs) == ncategories(pri) || error("The number of components does not match the length of prior.") new{VF,VS,C,CT}(cs, pri) @@ -171,16 +171,8 @@ minimum(d::MixtureModel) = minimum([minimum(dci) for dci in d.components]) maximum(d::MixtureModel) = maximum([maximum(dci) for dci in d.components]) function mean(d::UnivariateMixture) - K = ncomponents(d) p = probs(d) - m = 0.0 - for i = 1:K - pi = p[i] - if pi > 0.0 - c = component(d, i) - m += mean(c) * pi - end - end + m = sum(pi * mean(component(d, i)) for (i, pi) in enumerate(p) if !iszero(pi)) return m end @@ -281,28 +273,13 @@ end #### Evaluation function insupport(d::AbstractMixtureModel, x::AbstractVector) - K = ncomponents(d) p = probs(d) - @inbounds for i in eachindex(p) - pi = p[i] - if pi > 0.0 && insupport(component(d, i), x) - return true - end - end - return false + return any(insupport(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi)) end function _cdf(d::UnivariateMixture, x::Real) - K = ncomponents(d) p = probs(d) - r = 0.0 - @inbounds for i in eachindex(p) - pi = p[i] - if pi > 0.0 - c = component(d, i) - r += pi * cdf(c, x) - end - end + r = sum(pi * cdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi)) return r end @@ -310,9 +287,8 @@ cdf(d::UnivariateMixture{Continuous}, x::Real) = _cdf(d, x) cdf(d::UnivariateMixture{Discrete}, x::Integer) = _cdf(d, x) function _mixpdf1(d::AbstractMixtureModel, x) - ps = probs(d) - cs = components(d) - return sum((ps[i] > 0) * (ps[i] * pdf(cs[i], x)) for i in eachindex(ps)) + p = probs(d) + return sum(pi * pdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi)) end function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x) @@ -335,39 +311,9 @@ function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x) end function _mixlogpdf1(d::AbstractMixtureModel, x) - # using the formula below for numerical stability - # - # logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) - # = log(sum_i pri[i] * exp(logpdf(cs[i], x))) - # = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) - # = m + log(sum_i exp(logpri[i] + logpdf(cs[i], x) - m)) - # - # m is chosen to be the maximum of logpri[i] + logpdf(cs[i], x) - # such that the argument of exp is in a reasonable range - # - - K = ncomponents(d) p = probs(d) - lp = Vector{eltype(p)}(undef, K) - m = -Inf # m <- the maximum of log(p(cs[i], x)) + log(pri[i]) - @inbounds for i in eachindex(p) - pi = p[i] - if pi > 0.0 - # lp[i] <- log(p(cs[i], x)) + log(pri[i]) - lp_i = logpdf(component(d, i), x) + log(pi) - lp[i] = lp_i - if lp_i > m - m = lp_i - end - end - end - v = 0.0 - @inbounds for i = 1:K - if p[i] > 0.0 - v += exp(lp[i] - m) - end - end - return m + log(v) + lp = logsumexp(log(pi) + logpdf(component(d, i), x) for (i, pi) in enumerate(p) if !iszero(pi)) + return lp end function _mixlogpdf!(r::AbstractArray, d::AbstractMixtureModel, x) diff --git a/test/mixture.jl b/test/mixture.jl index b027324c1..3873df72a 100644 --- a/test/mixture.jl +++ b/test/mixture.jl @@ -18,7 +18,7 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int, end K = ncomponents(g) - pr = probs(g) + pr = @inferred(probs(g)) @assert length(pr) == K # mean @@ -26,7 +26,7 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int, for k = 1:K mu += pr[k] * mean(component(g, k)) end - @test mean(g) ≈ mu + @test @inferred(mean(g)) ≈ mu # evaluation of cdf cf = zeros(T, n) @@ -38,7 +38,7 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int, end for i = 1:n - @test cdf(g, X[i]) ≈ cf[i] + @test @inferred(cdf(g, X[i])) ≈ cf[i] end @test cdf.(g, X) ≈ cf @@ -58,16 +58,16 @@ function test_mixture(g::UnivariateMixture, n::Int, ns::Int, mix_lp0 = log.(mix_p0) for i = 1:n - @test pdf(g, X[i]) ≈ mix_p0[i] - @test logpdf(g, X[i]) ≈ mix_lp0[i] - @test componentwise_pdf(g, X[i]) ≈ vec(P0[i,:]) - @test componentwise_logpdf(g, X[i]) ≈ vec(LP0[i,:]) + @test @inferred(pdf(g, X[i])) ≈ mix_p0[i] + @test @inferred(logpdf(g, X[i])) ≈ mix_lp0[i] + @test @inferred(componentwise_pdf(g, X[i])) ≈ vec(P0[i,:]) + @test @inferred(componentwise_logpdf(g, X[i])) ≈ vec(LP0[i,:]) end - @test pdf.(g, X) ≈ mix_p0 - @test logpdf.(g, X) ≈ mix_lp0 - @test componentwise_pdf(g, X) ≈ P0 - @test componentwise_logpdf(g, X) ≈ LP0 + @test @inferred(map(Base.Fix1(pdf, g), X)) ≈ mix_p0 + @test @inferred(map(Base.Fix1(logpdf, g), X)) ≈ mix_lp0 + @test @inferred(componentwise_pdf(g, X)) ≈ P0 + @test @inferred(componentwise_logpdf(g, X)) ≈ LP0 # sampling if (T <: AbstractFloat) @@ -94,7 +94,7 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int, end K = ncomponents(g) - pr = probs(g) + pr = @inferred(probs(g)) @assert length(pr) == K # mean @@ -102,7 +102,7 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int, for k = 1:K mu .+= pr[k] .* mean(component(g, k)) end - @test mean(g) ≈ mu + @test @inferred(mean(g)) ≈ mu # evaluation P0 = zeros(n, K) @@ -121,20 +121,20 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int, for i = 1:n x_i = X[:,i] - @test pdf(g, x_i) ≈ mix_p0[i] - @test logpdf(g, x_i) ≈ mix_lp0[i] - @test componentwise_pdf(g, x_i) ≈ vec(P0[i,:]) - @test componentwise_logpdf(g, x_i) ≈ vec(LP0[i,:]) + @test @inferred(pdf(g, x_i)) ≈ mix_p0[i] + @test @inferred(logpdf(g, x_i)) ≈ mix_lp0[i] + @test @inferred(componentwise_pdf(g, x_i)) ≈ vec(P0[i,:]) + @test @inferred(componentwise_logpdf(g, x_i)) ≈ vec(LP0[i,:]) end #= @show g @show size(X) @show size(mix_p0) =# - @test pdf(g, X) ≈ mix_p0 - @test logpdf(g, X) ≈ mix_lp0 - @test componentwise_pdf(g, X) ≈ P0 - @test componentwise_logpdf(g, X) ≈ LP0 + @test @inferred(pdf(g, X)) ≈ mix_p0 + @test @inferred(logpdf(g, X)) ≈ mix_lp0 + @test @inferred(componentwise_pdf(g, X)) ≈ P0 + @test @inferred(componentwise_logpdf(g, X)) ≈ LP0 # sampling if ismissing(rng) @@ -172,8 +172,8 @@ end "rand(rng, ...)" => MersenneTwister(123)) @testset "Testing UnivariateMixture" begin - g_u = MixtureModel(Normal, [(0.0, 1.0), (2.0, 1.0), (-4.0, 1.5)], [0.2, 0.5, 0.3]) - @test isa(g_u, MixtureModel{Univariate, Continuous, Normal}) + g_u = MixtureModel(Normal{Float64}, [(0.0, 1.0), (2.0, 1.0), (-4.0, 1.5)], [0.2, 0.5, 0.3]) + @test isa(g_u, MixtureModel{Univariate,Continuous,<:Normal}) @test ncomponents(g_u) == 3 test_mixture(g_u, 1000, 10^6, rng) test_params(g_u)