Skip to content

Commit

Permalink
Improve type-stability everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jan 19, 2022
1 parent ac3b799 commit f024f45
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions src/censored.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,62 +312,69 @@ function entropy(d::Censored)

# truncation contains ~ no probability
if log_prob_trunc < log(eps(one(log_prob_trunc)))
return entropy_bound
entropy_trunc = zero(entropy_bound)
else
entropy_trunc = oftype(entropy_bound, entropy(_to_truncated(d)))
end

dtrunc = _to_truncated(d)
entropy_interval =
exp(log_prob_trunc) * entropy(dtrunc) - xexpx(log_prob_trunc) + xlogx_pl + xlogx_pu
entropy_interval = exp(log_prob_trunc) * entropy_trunc - xexpx(log_prob_trunc) + xlogx_pl + xlogx_pu
return entropy_bound + entropy_interval
end


#### Evaluation

function pdf(d::Censored{<:Any,<:Any,T}, x::Real) where {T}
function pdf(d::Censored{<:Any,<:Any,T}, x0::Real) where {T}
d0 = d.uncensored
lower = d.lower
upper = d.upper
S = Base.promote_eltype(T, x)
S = Base.promote_eltype(T, x0)
x = S(x0)
px = float(pdf(d0, x))
return if lower !== missing && x == lower
result = cdf(d0, S(x))
_eqnotmissing(x, upper) ? one(result) : result
_eqnotmissing(x, upper) ? one(px) : oftype(px, cdf(d0, x))
elseif _eqnotmissing(x, upper)
_ccdf_inclusive(d0, S(x))
if value_support(typeof(d0)) === Discrete
oftype(px, ccdf(d0, x) + px)
else
oftype(px, ccdf(d0, x))
end
else
result = pdf(d0, S(x))
_in_open_interval(x, lower, upper) ? result : zero(result)
_in_open_interval(x, lower, upper) ? px : zero(px)
end
end

function logpdf(d::Censored{<:Any,<:Any,T}, x::Real) where {T}
function logpdf(d::Censored{<:Any,<:Any,T}, x0::Real) where {T}
d0 = d.uncensored
lower = d.lower
upper = d.upper
S = Base.promote_eltype(T, x)
S = Base.promote_eltype(T, x0)
x = S(x0)
logpx = float(logpdf(d0, x))
return if lower !== missing && x == lower
result = logcdf(d0, S(x))
_eqnotmissing(x, upper) ? zero(result) : result
_eqnotmissing(x, upper) ? zero(logpx) : oftype(logpx, logcdf(d0, x))
elseif _eqnotmissing(x, upper)
_logccdf_inclusive(d0, S(x))
if value_support(typeof(d0)) === Discrete
oftype(logpx, logaddexp(logccdf(d0, x), logpx))
else
oftype(logpx, logccdf(d0, x))
end
else
result = logpdf(d0, S(x))
_in_open_interval(x, lower, upper) ? result : oftype(result, -Inf)
_in_open_interval(x, lower, upper) ? logpx : oftype(logpx, -Inf)
end
end

function loglikelihood(d::Censored{<:Any,<:Any,T}, x::AbstractArray{<:Real}) where {T}
d0 = d.uncensored
lower = d.lower
upper = d.upper
S = float(Base.promote_eltype(T, first(x)))
log_prob_lower = lower === missing ? 0 : logcdf(d0, S(lower))
log_prob_upper = upper === missing ? 0 : _logccdf_inclusive(d0, S(upper))
logzero = S(-Inf)

S = Base.promote_eltype(T, x)
x1 = S(first(x))
logpx1 = float(logpdf(d0, x1))
log_prob_lower = lower === missing ? zero(logpx1) : oftype(logpx1, logcdf(d0, S(lower)))
log_prob_upper = upper === missing ? zero(logpx1) : oftype(logpx1, _logccdf_inclusive(d0, S(upper)))
logzero = oftype(logpx1, -Inf)
return sum(x) do xi
R = float(Base.promote_eltype(T, xi))
_in_open_interval(xi, lower, upper) && return logpdf(d0, R(xi))
_in_open_interval(xi, lower, upper) && return oftype(logpx1, logpdf(d0, oftype(x1, xi)))
_eqnotmissing(xi, lower) && return log_prob_lower
_eqnotmissing(xi, upper) && return log_prob_upper
return logzero
Expand Down

0 comments on commit f024f45

Please sign in to comment.