Skip to content

Commit

Permalink
use iterate
Browse files Browse the repository at this point in the history
  • Loading branch information
rmsrosa committed Jan 27, 2024
1 parent b50f5eb commit 1d3c281
Showing 1 changed file with 45 additions and 23 deletions.
68 changes: 45 additions & 23 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,19 +371,31 @@ pdf(d::UnivariateMixture, x::Real) = _mixpdf1(d, x)
logpdf(d::UnivariateMixture, x::Real) = _mixlogpdf1(d, x)

function gradlogpdf(d::UnivariateMixture, x::Real)
cp = components(d)
pr = probs(d)
pdfx1 = pdf(cp[1], x)
pdfx = pr[1] * pdfx1
_glp = pdfx * gradlogpdf(cp[1], x)
glp = (!iszero(pr[1])) && (!iszero(pdfx)) ? _glp : zero(_glp)
@inbounds for i in Iterators.drop(eachindex(pr, cp), 1)
if !iszero(pr[i])
pdfxi = pdf(cp[i], x)
ps = probs(d)
cs = components(d)

# `d` is expected to have at least one distribution, otherwise this will just error
psi, idxps = iterate(ps)
csi, idxcs = iterate(cs)
pdfx1 = pdf(csi, x)
pdfx = psi * pdfx1
glp = pdfx * gradlogpdf(csi, x)
if iszero(psi) || iszero(pdfx)
glp = zero(glp)
end

while true
iterps = iterate(ps, idxps)
itercs = iterate(cs, idxcs)
( (iterps !== nothing) && (itercs !== nothing) ) || break
psi, idxps = iterps
csi, idxcs = itercs
if !iszero(psi)
pdfxi = pdf(csi, x)
if !iszero(pdfxi)
pipdfxi = pr[i] * pdfxi
pipdfxi = psi * pdfxi
pdfx += pipdfxi
glp += pipdfxi * gradlogpdf(cp[i], x)
glp += pipdfxi * gradlogpdf(csi, x)
end
end
end
Expand All @@ -403,27 +415,37 @@ _pdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real
_logpdf!(r::AbstractArray{<:Real}, d::MultivariateMixture, x::AbstractMatrix{<:Real}) = _mixlogpdf!(r, d, x)

function gradlogpdf(d::MultivariateMixture, x::AbstractVector{<:Real})
cp = components(d)
pr = probs(d)
pdfx1 = pdf(cp[1], x)
pdfx = pr[1] * pdfx1
glp = pdfx * gradlogpdf(cp[1], x)
if ( iszero(pr[1]) || iszero(pdfx) )
ps = probs(d)
cs = components(d)

# `d` is expected to have at least one distribution, otherwise this will just error
psi, idxps = iterate(ps)
csi, idxcs = iterate(cs)
pdfx1 = pdf(csi, x)
pdfx = psi * pdfx1
glp = pdfx * gradlogpdf(csi, x)
if iszero(psi) || iszero(pdfx)
glp .= zero(eltype(glp))
end
@inbounds for i in Iterators.drop(eachindex(pr, cp), 1)
if !iszero(pr[i])
pdfxi = pdf(cp[i], x)

while true
iterps = iterate(ps, idxps)
itercs = iterate(cs, idxcs)
( (iterps !== nothing) && (itercs !== nothing) ) || break
psi, idxps = iterps
csi, idxcs = itercs
if !iszero(psi)
pdfxi = pdf(csi, x)
if !iszero(pdfxi)
pipdfxi = pr[i] * pdfxi
pipdfxi = psi * pdfxi
pdfx += pipdfxi
glp .+= pipdfxi * gradlogpdf(cp[i], x)
glp .+= pipdfxi * gradlogpdf(csi, x)
end
end
end
if !iszero(pdfx) # else glp is already zero
glp ./= pdfx
end
end
return glp
end

Expand Down

0 comments on commit 1d3c281

Please sign in to comment.