diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 31f1931fb..babf66a5f 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -371,19 +371,31 @@ pdf(d::UnivariateMixture, x::Real) = _mixpdf1(d, x) logpdf(d::UnivariateMixture, x::Real) = _mixlogpdf1(d, x) function gradlogpdf(d::UnivariateMixture, x::Real) - cp = components(d) - pr = probs(d) - pdfx1 = pdf(cp[1], x) - pdfx = pr[1] * pdfx1 - _glp = pdfx * gradlogpdf(cp[1], x) - glp = (!iszero(pr[1])) && (!iszero(pdfx)) ? _glp : zero(_glp) - @inbounds for i in Iterators.drop(eachindex(pr, cp), 1) - if !iszero(pr[i]) - pdfxi = pdf(cp[i], x) + ps = probs(d) + cs = components(d) + + # `d` is expected to have at least one distribution, otherwise this will just error + psi, idxps = iterate(ps) + csi, idxcs = iterate(cs) + pdfx1 = pdf(csi, x) + pdfx = psi * pdfx1 + glp = pdfx * gradlogpdf(csi, x) + if iszero(psi) || iszero(pdfx) + glp = zero(glp) + end + + while true + iterps = iterate(ps, idxps) + itercs = iterate(cs, idxcs) + ( (iterps !== nothing) && (itercs !== nothing) ) || break + psi, idxps = iterps + csi, idxcs = itercs + if !iszero(psi) + pdfxi = pdf(csi, x) if !iszero(pdfxi) - pipdfxi = pr[i] * pdfxi + pipdfxi = psi * pdfxi pdfx += pipdfxi - glp += pipdfxi * gradlogpdf(cp[i], x) + glp += pipdfxi * gradlogpdf(csi, x) end end end @@ -403,27 +415,37 @@ _pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real _logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixlogpdf!(r, d, x) function gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real}) - cp = components(d) - pr = probs(d) - pdfx1 = pdf(cp[1], x) - pdfx = pr[1] * pdfx1 - glp = pdfx * gradlogpdf(cp[1], x) - if ( iszero(pr[1]) || iszero(pdfx) ) + ps = probs(d) + cs = components(d) + + # `d` is expected to have at least one distribution, otherwise this will just error + psi, idxps = iterate(ps) + csi, idxcs = iterate(cs) + pdfx1 = pdf(csi, x) + pdfx = psi * pdfx1 + glp = pdfx * gradlogpdf(csi, x) + if iszero(psi) || iszero(pdfx) glp .= zero(eltype(glp)) end - @inbounds for i in Iterators.drop(eachindex(pr, cp), 1) - if !iszero(pr[i]) - pdfxi = pdf(cp[i], x) + + while true + iterps = iterate(ps, idxps) + itercs = iterate(cs, idxcs) + ( (iterps !== nothing) && (itercs !== nothing) ) || break + psi, idxps = iterps + csi, idxcs = itercs + if !iszero(psi) + pdfxi = pdf(csi, x) if !iszero(pdfxi) - pipdfxi = pr[i] * pdfxi + pipdfxi = psi * pdfxi pdfx += pipdfxi - glp .+= pipdfxi * gradlogpdf(cp[i], x) + glp .+= pipdfxi * gradlogpdf(csi, x) end end end if !iszero(pdfx) # else glp is already zero glp ./= pdfx - end + end return glp end