diff --git a/Project.toml b/Project.toml index 336b24c55..1e8a71b70 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.44" +version = "0.25.45" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/make.jl b/docs/make.jl index 78ecef710..39e51d4db 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,7 @@ makedocs( "types.md", "univariate.md", "truncate.md", + "censored.md", "multivariate.md", "matrix.md", "reshape.md", diff --git a/docs/src/censored.md b/docs/src/censored.md new file mode 100644 index 000000000..3059da015 --- /dev/null +++ b/docs/src/censored.md @@ -0,0 +1,53 @@ +# Censored Distributions + +In *censoring* of data, values exceeding an upper limit (right censoring) or falling below a lower limit (left censoring), or both (interval censoring) are replaced by the corresponding limit itself. +The package provides the `censored` function, which creates the most appropriate distribution to represent a censored version of a given distribution. + +A censored distribution can be constructed using the following signature: + +```@docs +censored +``` + +In the general case, this will create a `Distributions.Censored{typeof(d0)}` structure, defined as follows: + +```@docs +Distributions.Censored +``` + +In general, `censored` should be called instead of the constructor of `Censored`, which is not exported. 
+ +Many functions, including those for the evaluation of pdf and sampling, are defined for all censored univariate distributions: + +- [`maximum(::UnivariateDistribution)`](@ref) +- [`minimum(::UnivariateDistribution)`](@ref) +- [`insupport(::UnivariateDistribution, x::Any)`](@ref) +- [`pdf(::UnivariateDistribution, ::Real)`](@ref) +- [`logpdf(::UnivariateDistribution, ::Real)`](@ref) +- [`cdf(::UnivariateDistribution, ::Real)`](@ref) +- [`logcdf(::UnivariateDistribution, ::Real)`](@ref) +- [`logdiffcdf(::UnivariateDistribution, ::T, ::T) where {T <: Real}`](@ref) +- [`ccdf(::UnivariateDistribution, ::Real)`](@ref) +- [`logccdf(::UnivariateDistribution, ::Real)`](@ref) +- [`quantile(::UnivariateDistribution, ::Real)`](@ref) +- [`cquantile(::UnivariateDistribution, ::Real)`](@ref) +- [`invlogcdf(::UnivariateDistribution, ::Real)`](@ref) +- [`invlogccdf(::UnivariateDistribution, ::Real)`](@ref) +- [`median(::UnivariateDistribution)`](@ref) +- [`rand(::UnivariateDistribution)`](@ref) +- [`rand!(::UnivariateDistribution, ::AbstractArray)`](@ref) + +Some functions to compute statistics are available for the censored distribution if they are also available for its truncation: +- [`mean(::UnivariateDistribution)`](@ref) +- [`var(::UnivariateDistribution)`](@ref) +- [`std(::UnivariateDistribution)`](@ref) +- [`entropy(::UnivariateDistribution)`](@ref) + +For example, these functions are available for the following uncensored distributions: +- `DiscreteUniform` +- `Exponential` +- `LogUniform` +- `Normal` +- `Uniform` + +[`mode`](@ref) is not implemented for censored distributions. 
diff --git a/docs/src/index.md b/docs/src/index.md index 300c0f724..4fc944bc0 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -7,4 +7,4 @@ The [*Distributions*](https://github.com/JuliaStats/Distributions.jl) package pr * Probability density/mass functions (pdf) and their logarithm (logpdf) * Moment-generating functions and characteristic functions * Maximum likelihood estimation -* Distribution composition (Cartesian product of distributions, truncated distributions) +* Distribution composition and derived distributions (Cartesian product of distributions, truncated distributions, censored distributions) diff --git a/src/Distributions.jl b/src/Distributions.jl index 3cf0a2812..2e4ef62e5 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -182,6 +182,7 @@ export canonform, # get canonical form of a distribution ccdf, # complementary cdf, i.e. 1 - cdf cdf, # cumulative distribution function + censored, # censor a distribution with a lower and upper bound cf, # characteristic function cquantile, # complementary quantile (i.e. using prob in right hand tail) component, # get the k-th component of a mixture model @@ -295,6 +296,7 @@ include("samplers.jl") # others include("reshaped.jl") include("truncate.jl") +include("censored.jl") include("conversion.jl") include("convolution.jl") include("qq.jl") @@ -334,7 +336,7 @@ information. 
Supported distributions: Arcsine, Bernoulli, Beta, BetaBinomial, BetaPrime, Binomial, Biweight, - Categorical, Cauchy, Chi, Chisq, Cosine, DiagNormal, DiagNormalCanon, + Categorical, Cauchy, Censored, Chi, Chisq, Cosine, DiagNormal, DiagNormalCanon, Dirichlet, DiscreteUniform, DoubleExponential, EdgeworthMean, EdgeworthSum, EdgeworthZ, Erlang, Epanechnikov, Exponential, FDist, FisherNoncentralHypergeometric, diff --git a/src/censored.jl b/src/censored.jl new file mode 100644 index 000000000..2d5e1d67f --- /dev/null +++ b/src/censored.jl @@ -0,0 +1,462 @@ +""" + censored(d0::UnivariateDistribution; [lower::Real], [upper::Real]) + censored(d0::UnivariateDistribution, lower::Real, upper::Real) + +A _censored distribution_ `d` of a distribution `d0` to the interval +``[l, u]=```[lower, upper]` has the probability density (mass) function: + +```math +f(x; d_0, l, u) = \\begin{cases} + P_{Z \\sim d_0}(Z \\le l), & x = l \\\\ + f_{d_0}(x), & l < x < u \\\\ + P_{Z \\sim d_0}(Z \\ge u), & x = u \\\\ + \\end{cases}, \\quad x \\in [l, u] +``` +where ``f_{d_0}(x)`` is the probability density (mass) function of ``d_0``. + +If ``Z \\sim d_0``, and `X = clamp(Z, l, u)`, then ``X \\sim d``. Note that this implies +that even if ``d_0`` is continuous, its censored form assigns positive probability to the +bounds ``l`` and ``u``. Therefore, a censored continuous distribution has atoms and is a +mixture of discrete and continuous components. + +The function falls back to constructing a [`Distributions.Censored`](@ref) wrapper. 
+ +# Usage + +```julia +censored(d0; lower=l) # d0 left-censored to the interval [l, Inf) +censored(d0; upper=u) # d0 right-censored to the interval (-Inf, u] +censored(d0; lower=l, upper=u) # d0 interval-censored to the interval [l, u] +censored(d0, l, u) # d0 interval-censored to the interval [l, u] +``` + +# Implementation + +To implement a specialized censored form for distributions of type `D`, instead of +overloading a method with one of the above signatures, one or more of the following methods +should be implemented: +- `censored(d0::D, l::T, u::T) where {T <: Real}` +- `censored(d0::D, ::Nothing, u::Real)` +- `censored(d0::D, l::Real, ::Nothing)` +""" +function censored end +function censored(d0::UnivariateDistribution, l::T, u::T) where {T<:Real} + return Censored(d0, l, u) +end +function censored(d0::UnivariateDistribution, ::Nothing, u::Real) + return Censored(d0, nothing, u) +end +function censored(d0::UnivariateDistribution, l::Real, ::Nothing) + return Censored(d0, l, nothing) +end +censored(d0::UnivariateDistribution, l::Real, u::Real) = censored(d0, promote(l, u)...) +censored(d0::UnivariateDistribution, ::Nothing, ::Nothing) = d0 +function censored( + d0::UnivariateDistribution; + lower::Union{Real,Nothing} = nothing, + upper::Union{Real,Nothing} = nothing, +) + return censored(d0, lower, upper) +end + +""" + Censored + +Generic wrapper for a [`censored`](@ref) distribution. 
+""" +struct Censored{ + D<:UnivariateDistribution, + S<:ValueSupport, + T<:Real, + TL<:Union{T,Nothing}, + TU<:Union{T,Nothing}, +} <: UnivariateDistribution{S} + uncensored::D # the original distribution (uncensored) + lower::TL # lower bound + upper::TU # upper bound + function Censored(d0::UnivariateDistribution, lower::T, upper::T; check_args::Bool=true) where {T<:Real} + @check_args(Censored, lower ≤ upper) + new{typeof(d0), value_support(typeof(d0)), T, T, T}(d0, lower, upper) + end + function Censored(d0::UnivariateDistribution, l::Nothing, u::Real; check_args::Bool=true) + new{typeof(d0), value_support(typeof(d0)), typeof(u), Nothing, typeof(u)}(d0, l, u) + end + function Censored(d0::UnivariateDistribution, l::Real, u::Nothing; check_args::Bool=true) + new{typeof(d0), value_support(typeof(d0)), typeof(l), typeof(l), Nothing}(d0, l, u) + end +end + +const LeftCensored{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Censored{D,S,T,T,Nothing} +const RightCensored{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Censored{D,S,T,Nothing,T} + +function censored(d::Censored, l::T, u::T) where {T<:Real} + return censored( + d.uncensored, + d.lower === nothing ? l : max(l, d.lower), + d.upper === nothing ? u : min(u, d.upper), + ) +end +function censored(d::Censored, ::Nothing, u::Real) + return censored(d.uncensored, d.lower, d.upper === nothing ? u : min(u, d.upper)) +end +function censored(d::Censored, l::Real, ::Nothing) + return censored(d.uncensored, d.lower === nothing ? 
l : max(l, d.lower), d.upper) +end + +function params(d::Censored) + d0params = params(d.uncensored) + return (d0params..., d.lower, d.upper) +end + +function partype(d::Censored{<:UnivariateDistribution,<:ValueSupport,T}) where {T} + return promote_type(partype(d.uncensored), T) +end + +Base.eltype(::Type{<:Censored{D,S,T}}) where {D,S,T} = promote_type(T, eltype(D)) + +#### Range and Support + +isupperbounded(d::LeftCensored) = isupperbounded(d.uncensored) +isupperbounded(d::Censored) = isupperbounded(d.uncensored) || _ccdf_inclusive(d.uncensored, d.upper) > 0 + +islowerbounded(d::RightCensored) = islowerbounded(d.uncensored) +islowerbounded(d::Censored) = islowerbounded(d.uncensored) || cdf(d.uncensored, d.lower) > 0 + +maximum(d::LeftCensored) = max(maximum(d.uncensored), d.lower) +maximum(d::Censored) = min(maximum(d.uncensored), d.upper) + +minimum(d::RightCensored) = min(minimum(d.uncensored), d.upper) +minimum(d::Censored) = max(minimum(d.uncensored), d.lower) + +function insupport(d::Censored, x::Real) + d0 = d.uncensored + lower = d.lower + upper = d.upper + return ( + (_in_open_interval(x, lower, upper) && insupport(d0, x)) || + (x == lower && cdf(d0, lower) > 0) || + (x == upper && _ccdf_inclusive(d0, upper) > 0) + ) +end + +#### Show + +function show(io::IO, ::MIME"text/plain", d::Censored) + print(io, "Censored(") + d0 = d.uncensored + uml, namevals = _use_multline_show(d0) + uml ? 
show_multline(io, d0, namevals; newline=false) : show_oneline(io, d0, namevals) + if d.lower === nothing + print(io, "; upper=$(d.upper))") + elseif d.upper === nothing + print(io, "; lower=$(d.lower))") + else + print(io, "; lower=$(d.lower), upper=$(d.upper))") + end +end + +_use_multline_show(d::Censored) = _use_multline_show(d.uncensored) + + +#### Statistics + +quantile(d::Censored, p::Real) = _clamp(quantile(d.uncensored, p), d.lower, d.upper) + +median(d::Censored) = _clamp(median(d.uncensored), d.lower, d.upper) + +# the expectations use the following relation: +# 𝔼_{x ~ d}[h(x)] = P_{x ~ d₀}(x < l) h(l) + P_{x ~ d₀}(x > u) h(u) +# + P_{x ~ d₀}(l ≤ x ≤ u) 𝔼_{x ~ τ}[h(x)], +# where d₀ is the uncensored distribution, d is d₀ censored to [l, u], +# and τ is d₀ truncated to [l, u] + +function mean(d::LeftCensored) + lower = d.lower + log_prob_lower = _logcdf_noninclusive(d.uncensored, lower) + log_prob_interval = log1mexp(log_prob_lower) + μ = xexpy(lower, log_prob_lower) + xexpy(mean(_to_truncated(d)), log_prob_interval) + return μ +end +function mean(d::RightCensored) + upper = d.upper + log_prob_upper = logccdf(d.uncensored, upper) + log_prob_interval = log1mexp(log_prob_upper) + μ = xexpy(upper, log_prob_upper) + xexpy(mean(_to_truncated(d)), log_prob_interval) + return μ +end +function mean(d::Censored) + d0 = d.uncensored + lower = d.lower + upper = d.upper + log_prob_lower = _logcdf_noninclusive(d0, lower) + log_prob_upper = logccdf(d0, upper) + log_prob_interval = log1mexp(logaddexp(log_prob_lower, log_prob_upper)) + μ = (xexpy(lower, log_prob_lower) + xexpy(upper, log_prob_upper) + + xexpy(mean(_to_truncated(d)), log_prob_interval)) + return μ +end + +function var(d::LeftCensored) + lower = d.lower + log_prob_lower = _logcdf_noninclusive(d.uncensored, lower) + log_prob_interval = log1mexp(log_prob_lower) + dtrunc = _to_truncated(d) + μ_interval = mean(dtrunc) + μ = xexpy(lower, log_prob_lower) + xexpy(μ_interval, 
log_prob_interval) + v_interval = var(dtrunc) + abs2(μ_interval - μ) + v = xexpy(abs2(lower - μ), log_prob_lower) + xexpy(v_interval, log_prob_interval) + return v +end +function var(d::RightCensored) + upper = d.upper + log_prob_upper = logccdf(d.uncensored, upper) + log_prob_interval = log1mexp(log_prob_upper) + dtrunc = _to_truncated(d) + μ_interval = mean(dtrunc) + μ = xexpy(upper, log_prob_upper) + xexpy(μ_interval, log_prob_interval) + v_interval = var(dtrunc) + abs2(μ_interval - μ) + v = xexpy(abs2(upper - μ), log_prob_upper) + xexpy(v_interval, log_prob_interval) + return v +end +function var(d::Censored) + d0 = d.uncensored + lower = d.lower + upper = d.upper + log_prob_lower = _logcdf_noninclusive(d0, lower) + log_prob_upper = logccdf(d0, upper) + log_prob_interval = log1mexp(logaddexp(log_prob_lower, log_prob_upper)) + dtrunc = _to_truncated(d) + μ_interval = mean(dtrunc) + μ = (xexpy(lower, log_prob_lower) + xexpy(upper, log_prob_upper) + + xexpy(μ_interval, log_prob_interval)) + v_interval = var(dtrunc) + abs2(μ_interval - μ) + v = (xexpy(abs2(lower - μ), log_prob_lower) + xexpy(abs2(upper - μ), log_prob_upper) + + xexpy(v_interval, log_prob_interval)) + return v +end + +# this expectation also uses the following relation: +# 𝔼_{x ~ τ}[-log d(x)] = H[τ] - log P_{x ~ d₀}(l ≤ x ≤ u) +# + (P_{x ~ d₀}(x = l) (log P_{x ~ d₀}(x = l) - log P_{x ~ d₀}(x ≤ l)) + +# P_{x ~ d₀}(x = u) (log P_{x ~ d₀}(x = u) - log P_{x ~ d₀}(x ≥ u)) +# ) / P_{x ~ d₀}(l ≤ x ≤ u), +# where H[τ] is the entropy of τ. 
+ +function entropy(d::LeftCensored) + d0 = d.uncensored + lower = d.lower + log_prob_lower_inc = logcdf(d0, lower) + if value_support(typeof(d0)) === Discrete + logpl = logpdf(d0, lower) + log_prob_lower = logsubexp(log_prob_lower_inc, logpl) + xlogx_pl = xexpx(logpl) + else + log_prob_lower = log_prob_lower_inc + xlogx_pl = 0 + end + log_prob_interval = log1mexp(log_prob_lower) + entropy_bound = -xexpx(log_prob_lower_inc) + dtrunc = _to_truncated(d) + entropy_interval = xexpy(entropy(dtrunc), log_prob_interval) - xexpx(log_prob_interval) + xlogx_pl + return entropy_interval + entropy_bound +end +function entropy(d::RightCensored) + d0 = d.uncensored + upper = d.upper + log_prob_upper = logccdf(d0, upper) + if value_support(typeof(d0)) === Discrete + logpu = logpdf(d0, upper) + log_prob_upper_inc = logaddexp(log_prob_upper, logpu) + xlogx_pu = xexpx(logpu) + else + log_prob_upper_inc = log_prob_upper + xlogx_pu = 0 + end + log_prob_interval = log1mexp(log_prob_upper) + entropy_bound = -xexpx(log_prob_upper_inc) + dtrunc = _to_truncated(d) + entropy_interval = xexpy(entropy(dtrunc), log_prob_interval) - xexpx(log_prob_interval) + xlogx_pu + return entropy_interval + entropy_bound +end +function entropy(d::Censored) + d0 = d.uncensored + lower = d.lower + upper = d.upper + log_prob_lower_inc = logcdf(d0, lower) + log_prob_upper = logccdf(d0, upper) + if value_support(typeof(d0)) === Discrete + logpl = logpdf(d0, lower) + logpu = logpdf(d0, upper) + log_prob_lower = logsubexp(log_prob_lower_inc, logpl) + log_prob_upper_inc = logaddexp(log_prob_upper, logpu) + xlogx_pl = xexpx(logpl) + xlogx_pu = xexpx(logpu) + else + log_prob_lower = log_prob_lower_inc + log_prob_upper_inc = log_prob_upper + xlogx_pl = xlogx_pu = 0 + end + log_prob_interval = log1mexp(logaddexp(log_prob_lower, log_prob_upper)) + entropy_bound = -(xexpx(log_prob_lower_inc) + xexpx(log_prob_upper_inc)) + dtrunc = _to_truncated(d) + entropy_interval = xexpy(entropy(dtrunc), log_prob_interval) - 
xexpx(log_prob_interval) + xlogx_pl + xlogx_pu + return entropy_interval + entropy_bound +end + + +#### Evaluation + +function pdf(d::Censored, x::Real) + d0 = d.uncensored + lower = d.lower + upper = d.upper + px = float(pdf(d0, x)) + return if _in_open_interval(x, lower, upper) + px + elseif x == lower + x == upper ? one(px) : oftype(px, cdf(d0, x)) + elseif x == upper + if value_support(typeof(d0)) === Discrete + oftype(px, ccdf(d0, x) + px) + else + oftype(px, ccdf(d0, x)) + end + else # not in support + zero(px) + end +end + +function logpdf(d::Censored, x::Real) + d0 = d.uncensored + lower = d.lower + upper = d.upper + logpx = logpdf(d0, x) + return if _in_open_interval(x, lower, upper) + logpx + elseif x == lower + x == upper ? zero(logpx) : oftype(logpx, logcdf(d0, x)) + elseif x == upper + if value_support(typeof(d0)) === Discrete + oftype(logpx, logaddexp(logccdf(d0, x), logpx)) + else + oftype(logpx, logccdf(d0, x)) + end + else # not in support + oftype(logpx, -Inf) + end +end + +function loglikelihood(d::Censored, x::AbstractArray{<:Real}) + d0 = d.uncensored + lower = d.lower + upper = d.upper + logpx = logpdf(d0, first(x)) + log_prob_lower = lower === nothing ? zero(logpx) : oftype(logpx, logcdf(d0, lower)) + log_prob_upper = upper === nothing ? 
zero(logpx) : oftype(logpx, _logccdf_inclusive(d0, upper)) + logzero = oftype(logpx, -Inf) + return sum(x) do xi + _in_open_interval(xi, lower, upper) && return logpdf(d0, xi) + xi == lower && return log_prob_lower + xi == upper && return log_prob_upper + return logzero + end +end + +function cdf(d::Censored, x::Real) + lower = d.lower + upper = d.upper + result = cdf(d.uncensored, x) + return if lower !== nothing && x < lower + zero(result) + elseif upper === nothing || x < upper + result + else + one(result) + end +end + +function logcdf(d::Censored, x::Real) + lower = d.lower + upper = d.upper + result = logcdf(d.uncensored, x) + return if lower !== nothing && x < lower + oftype(result, -Inf) + elseif upper === nothing || x < upper + result + else + zero(result) + end +end + +function ccdf(d::Censored, x::Real) + lower = d.lower + upper = d.upper + result = ccdf(d.uncensored, x) + return if lower !== nothing && x < lower + one(result) + elseif upper === nothing || x < upper + result + else + zero(result) + end +end + +function logccdf(d::Censored, x::Real) + lower = d.lower + upper = d.upper + result = logccdf(d.uncensored, x) + return if lower !== nothing && x < lower + zero(result) + elseif upper === nothing || x < upper + result + else + oftype(result, -Inf) + end +end + + +#### Sampling + +rand(rng::AbstractRNG, d::Censored) = _clamp(rand(rng, d.uncensored), d.lower, d.upper) + + +#### Utilities + +# utilities to handle intervals represented with possibly `nothing` bounds + +_in_open_interval(x::Real, l::Real, u::Real) = l < x < u +_in_open_interval(x::Real, ::Nothing, u::Real) = x < u +_in_open_interval(x::Real, l::Real, ::Nothing) = x > l + +_clamp(x, l, u) = clamp(x, l, u) +_clamp(x, ::Nothing, u) = min(x, u) +_clamp(x, l, ::Nothing) = max(x, l) + +_to_truncated(d::Censored) = truncated(d.uncensored, d.lower, d.upper) + +# utilities for non-inclusive CDF p(x < u) and inclusive CCDF (p ≥ u) + 
+_logcdf_noninclusive(d::UnivariateDistribution, x) = logcdf(d, x) +function _logcdf_noninclusive(d::DiscreteUnivariateDistribution, x) + return logsubexp(logcdf(d, x), logpdf(d, x)) +end + +_ccdf_inclusive(d::UnivariateDistribution, x) = ccdf(d, x) +_ccdf_inclusive(d::DiscreteUnivariateDistribution, x) = ccdf(d, x) + pdf(d, x) + +_logccdf_inclusive(d::UnivariateDistribution, x) = logccdf(d, x) +function _logccdf_inclusive(d::DiscreteUnivariateDistribution, x) + return logaddexp(logccdf(d, x), logpdf(d, x)) +end + +# like xlogx but for input on log scale, safe when x == -Inf +function xexpx(x::Real) + result = x * exp(x) + return x == -Inf ? zero(result) : result +end + +# x * exp(y) with correct limit for y == -Inf +function xexpy(x::Real, y::Real) + result = x * exp(y) + return y == -Inf && !isnan(x) ? zero(result) : result +end diff --git a/src/common.jl b/src/common.jl index d677ee256..3602372a5 100644 --- a/src/common.jl +++ b/src/common.jl @@ -97,7 +97,9 @@ for func in (:(==), :isequal, :isapprox) for f in fields isdefined(s1, f) && isdefined(s2, f) || return false - $func(getfield(s1, f), getfield(s2, f); kwargs...) || return false + # perform equivalence check to support types that have no defined equality, such + # as `missing` + getfield(s1, f) === getfield(s2, f) || $func(getfield(s1, f), getfield(s2, f); kwargs...) || return false end return true diff --git a/src/show.jl b/src/show.jl index 8c5db7d06..31bc5cd92 100644 --- a/src/show.jl +++ b/src/show.jl @@ -67,7 +67,7 @@ function show_oneline(io::IO, d::Distribution, namevals) print(io, ')') end -function show_multline(io::IO, d::Distribution, namevals) +function show_multline(io::IO, d::Distribution, namevals; newline=true) print(io, distrname(d)) println(io, "(") for (p, pv) in namevals @@ -75,5 +75,5 @@ function show_multline(io::IO, d::Distribution, namevals) print(io, ": ") println(io, pv) end - println(io, ")") + newline ? 
println(io, ")") : print(io, ")") end diff --git a/src/truncate.jl b/src/truncate.jl index d0bdc4446..b9f2a4385 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -229,3 +229,4 @@ include(joinpath("truncated", "normal.jl")) include(joinpath("truncated", "exponential.jl")) include(joinpath("truncated", "uniform.jl")) include(joinpath("truncated", "loguniform.jl")) +include(joinpath("truncated", "discrete_uniform.jl")) diff --git a/src/truncated/discrete_uniform.jl b/src/truncated/discrete_uniform.jl new file mode 100644 index 000000000..a22ac390d --- /dev/null +++ b/src/truncated/discrete_uniform.jl @@ -0,0 +1,15 @@ +##### +##### Shortcut for truncating discrete uniform distributions. +##### + +function truncated(d::DiscreteUniform, l::T, u::T) where {T <: Real} + a = ceil(Int, max(l, d.a)) + b = floor(Int, min(u, d.b)) + return DiscreteUniform(a, b) +end +function truncated(d::DiscreteUniform, l::Real, ::Nothing) + return DiscreteUniform(ceil(Int, max(l, d.a)), d.b) +end +function truncated(d::DiscreteUniform, ::Nothing, u::Real) + return DiscreteUniform(d.a, floor(Int, min(u, d.b))) +end diff --git a/test/censored.jl b/test/censored.jl new file mode 100644 index 000000000..eaad72cdc --- /dev/null +++ b/test/censored.jl @@ -0,0 +1,386 @@ +# Testing censored distributions + +module TestCensored + +using Distributions, Test +using Distributions: Censored + +function _as_mixture(d::Censored) + d0 = d.uncensored + dtrunc = if d0 isa DiscreteUniform || d0 isa Poisson + truncated( + d0, + d.lower === nothing ? -Inf : floor(d.lower) + 1, + d.upper === nothing ? Inf : ceil(d.upper) - 1, + ) + elseif d0 isa ContinuousDistribution + truncated( + d0, + d.lower === nothing ? -Inf : nextfloat(float(d.lower)), + d.upper === nothing ? Inf : prevfloat(float(d.upper)), + ) + else + error("truncation to open interval not implemented for $d0") + end + prob_lower = d.lower === nothing ? 
0 : cdf(d0, d.lower) + prob_upper = if d.upper === nothing + 0 + elseif d0 isa ContinuousDistribution + ccdf(d0, d.upper) + else + ccdf(d0, d.upper) + pdf(d0, d.upper) + end + prob_interval = 1 - (prob_lower + prob_upper) + components = Distribution[dtrunc] + probs = [prob_interval] + if prob_lower > 0 + # workaround for MixtureModel currently not supporting mixtures of discrete and + # continuous components + push!(components, d0 isa DiscreteDistribution ? Dirac(d.lower) : Normal(d.lower, 0)) + push!(probs, prob_lower) + end + if prob_upper > 0 + push!(components, d0 isa DiscreteDistribution ? Dirac(d.upper) : Normal(d.upper, 0)) + push!(probs, prob_upper) + end + return MixtureModel(map(identity, components), probs) +end + +@testset "censored" begin + d0 = Normal(0, 1) + @test_throws ArgumentError censored(d0, 1, -1) + + # bound argument constructors + d = censored(d0, -1, 1.0) + @test d isa Censored + @test d.lower === -1.0 + @test d.upper === 1.0 + + d = censored(d0, nothing, -1) + @test d isa Censored + @test d.lower === nothing + @test d.upper == -1 + + d = censored(d0, 1, nothing) + @test d isa Censored + @test d.upper === nothing + @test d.lower == 1 + + d = censored(d0, nothing, nothing) + @test d === d0 + + # bound keyword constructors + d = censored(d0; lower=-2, upper=1.5) + @test d isa Censored + @test d.lower === -2.0 + @test d.upper === 1.5 + + d = censored(d0; upper=true) + @test d isa Censored + @test d.lower === nothing + @test d.upper === true + + d = censored(d0; lower=-3) + @test d isa Censored + @test d.upper === nothing + @test d.lower === -3 + + d = censored(d0) + @test d === d0 +end + +@testset "Censored" begin + @testset "basic" begin + # check_args + @test_throws ArgumentError Censored(Normal(0, 1), 2, 1) + @test_throws ArgumentError Censored(Normal(0, 1), 2, 1; check_args=true) + Censored(Normal(0, 1), 2, 1; check_args=false) + Censored(Normal(0, 1), nothing, 1; check_args=true) + Censored(Normal(0, 1), 2, nothing; check_args=true) + + d 
= Censored(Normal(0.0, 1.0), -1, 2) + @test d isa Censored + @test eltype(d) === Float64 + @test params(d) === (params(Normal(0.0, 1.0))..., -1, 2) + @test partype(d) === Float64 + @test @inferred extrema(d) == (-1, 2) + @test @inferred islowerbounded(d) + @test @inferred isupperbounded(d) + @test @inferred insupport(d, 0.1) + @test insupport(d, -1) + @test insupport(d, 2) + @test !insupport(d, -1.1) + @test !insupport(d, 2.1) + @test sprint(show, "text/plain", d) == "Censored($(Normal(0.0, 1.0)); lower=-1, upper=2)" + + d = Censored(Cauchy(0, 1), nothing, 2) + @test d isa Censored + @test eltype(d) === Base.promote_type(eltype(Cauchy(0, 1)), Int) + @test params(d) === (params(Cauchy(0, 1))..., nothing, 2) + @test partype(d) === Float64 + @test extrema(d) == (-Inf, 2.0) + @test @inferred !islowerbounded(d) + @test @inferred isupperbounded(d) + @test @inferred insupport(d, 0.1) + @test insupport(d, -3) + @test insupport(d, 2) + @test !insupport(d, 2.1) + @test sprint(show, "text/plain", d) == "Censored($(Cauchy(0.0, 1.0)); upper=2)" + + d = Censored(Gamma(1, 2), 2, nothing) + @test d isa Censored + @test eltype(d) === Base.promote_type(eltype(Gamma(1, 2)), Int) + @test params(d) === (params(Gamma(1, 2))..., 2, nothing) + @test partype(d) === Float64 + @test extrema(d) == (2.0, Inf) + @test @inferred islowerbounded(d) + @test @inferred !isupperbounded(d) + @test @inferred insupport(d, 2.1) + @test insupport(d, 2.0) + @test !insupport(d, 1.9) + @test sprint(show, "text/plain", d) == "Censored($(Gamma(1, 2)); lower=2)" + + d = Censored(Binomial(10, 0.2), -1.5, 9.5) + @test extrema(d) === (0.0, 9.5) + @test @inferred islowerbounded(d) + @test @inferred isupperbounded(d) + @test @inferred !insupport(d, -1.5) + @test insupport(d, 0) + @test insupport(d, 9.5) + @test !insupport(d, 10) + + @test censored(Censored(Normal(), 1, nothing), nothing, 2) == Censored(Normal(), 1, 2) + @test censored(Censored(Normal(), nothing, 1), -1, nothing) == Censored(Normal(), -1, 1) + @test 
censored(Censored(Normal(), 1, 2), 1.5, 2.5) == Censored(Normal(), 1.5, 2.0) + @test censored(Censored(Normal(), 1, 3), 1.5, 2.5) == Censored(Normal(), 1.5, 2.5) + @test censored(Censored(Normal(), 1, 2), 0.5, 2.5) == Censored(Normal(), 1.0, 2.0) + @test censored(Censored(Normal(), 1, 2), 0.5, 1.5) == Censored(Normal(), 1.0, 1.5) + + @test censored(Censored(Normal(), nothing, 1), nothing, 1) == Censored(Normal(), nothing, 1) + @test censored(Censored(Normal(), nothing, 1), nothing, 2) == Censored(Normal(), nothing, 1) + @test censored(Censored(Normal(), nothing, 1), nothing, 1.5) == Censored(Normal(), nothing, 1) + @test censored(Censored(Normal(), nothing, 1.5), nothing, 1) == Censored(Normal(), nothing, 1) + + @test censored(Censored(Normal(), 1, nothing), 1, nothing) == Censored(Normal(), 1, nothing) + @test censored(Censored(Normal(), 1, nothing), 2, nothing) == Censored(Normal(), 2, nothing) + @test censored(Censored(Normal(), 1, nothing), 1.5, nothing) == Censored(Normal(), 1.5, nothing) + @test censored(Censored(Normal(), 1.5, nothing), 1, nothing) == Censored(Normal(), 1.5, nothing) + end + + @testset "Uniform" begin + d0 = Uniform(0, 10) + bounds = [ + (nothing, 8), + (-Inf, 8), + (nothing, Inf), + (2, nothing), + (2, Inf), + (-Inf, nothing), + (2, 8), + (3.5, nothing), + (3.5, Inf), + (-Inf, Inf), + ] + @testset "lower = $(lower === nothing ? "nothing" : lower), upper = $(upper === nothing ? 
"nothing" : upper)" for (lower, upper) in bounds + d = censored(d0, lower, upper) + dmix = _as_mixture(d) + l, u = extrema(d) + if lower === nothing || !isfinite(lower) + @test l == minimum(d0) + else + @test l == lower + end + if upper === nothing || !isfinite(upper) + @test u == maximum(d0) + else + @test u == upper + end + @testset for f in [cdf, logcdf, ccdf, logccdf] + @test @inferred(f(d, l)) ā‰ˆ f(dmix, l) atol=1e-8 + @test @inferred(f(d, l - 0.1)) ā‰ˆ f(dmix, l - 0.1) atol=1e-8 + @test @inferred(f(d, u)) ā‰ˆ f(dmix, u) atol=1e-8 + @test @inferred(f(d, u + 0.1)) ā‰ˆ f(dmix, u + 0.1) atol=1e-8 + @test @inferred(f(d, 5)) ā‰ˆ f(dmix, 5) + end + @testset for f in [mean, var] + @test @inferred(f(d)) ā‰ˆ f(dmix) + end + @test @inferred(median(d)) ā‰ˆ clamp(median(d0), l, u) + @inferred quantile(d, 0.5) + @test quantile.(d, 0:0.01:1) ā‰ˆ clamp.(quantile.(d0, 0:0.01:1), l, u) + # special-case pdf/logpdf/loglikelihood since when replacing Dirac(Ī¼) with + # Normal(Ī¼, 0), they are infinite + if lower === nothing || !isfinite(lower) + @test @inferred(pdf(d, l)) ā‰ˆ pdf(d0, l) + @test @inferred(logpdf(d, l)) ā‰ˆ logpdf(d0, l) + else + @test @inferred(pdf(d, l)) ā‰ˆ cdf(d0, l) + @test @inferred(logpdf(d, l)) ā‰ˆ logcdf(d0, l) + end + if upper === nothing || !isfinite(upper) + @test @inferred(pdf(d, u)) ā‰ˆ pdf(d0, u) + @test @inferred(logpdf(d, u)) ā‰ˆ logpdf(d0, u) + else + @test @inferred(pdf(d, u)) ā‰ˆ ccdf(d0, u) + @test @inferred(logpdf(d, u)) ā‰ˆ logccdf(d0, u) + end + # rand + x = rand(d, 10_000) + @test all(x -> insupport(d, x), x) + # loglikelihood + @test @inferred(loglikelihood(d, x)) ā‰ˆ sum(x -> logpdf(d, x), x) + @test loglikelihood(d, [x; -1]) == -Inf + # entropy + @test @inferred(entropy(d)) ā‰ˆ mean(x -> -logpdf(d, x), x) atol = 1e-1 + end + end + + @testset "Normal" begin + d0 = Normal() + bounds = [(nothing, 0.2), (-0.1, nothing), (-0.1, 0.2)] + @testset "lower = $(lower === nothing ? "nothing" : lower), upper = $(upper === nothing ? 
"nothing" : upper)" for (lower, upper) in bounds + d = censored(d0, lower, upper) + dmix = _as_mixture(d) + l, u = extrema(d) + @testset for f in [cdf, logcdf, ccdf, logccdf] + @test f(d, l) ā‰ˆ f(dmix, l) atol=1e-8 + @test f(d, l - 0.1) ā‰ˆ f(dmix, l - 0.1) atol=1e-8 + @test f(d, u) ā‰ˆ f(dmix, u) atol=1e-8 + @test f(d, u + 0.1) ā‰ˆ f(dmix, u + 0.1) atol=1e-8 + @test f(d, 5) ā‰ˆ f(dmix, 5) + end + @testset for f in [mean, var] + @test f(d) ā‰ˆ f(dmix) + end + @test median(d) ā‰ˆ clamp(median(d0), l, u) + @test quantile.(d, 0:0.01:1) ā‰ˆ clamp.(quantile.(d0, 0:0.01:1), l, u) + # special-case pdf/logpdf/loglikelihood since when replacing Dirac(Ī¼) with + # Normal(Ī¼, 0), they are infinite + if lower === nothing + @test pdf(d, l) ā‰ˆ pdf(d0, l) + @test logpdf(d, l) ā‰ˆ logpdf(d0, l) + else + @test pdf(d, l) ā‰ˆ cdf(d0, l) + @test logpdf(d, l) ā‰ˆ logcdf(d0, l) + end + if upper === nothing + @test pdf(d, u) ā‰ˆ pdf(d0, u) + @test logpdf(d, u) ā‰ˆ logpdf(d0, u) + else + @test pdf(d, u) ā‰ˆ ccdf(d0, u) + @test logpdf(d, u) ā‰ˆ logccdf(d0, u) + end + # rand + x = rand(d, 10_000) + @test all(x -> insupport(d, x), x) + # loglikelihood + @test loglikelihood(d, x) ā‰ˆ sum(x -> logpdf(d, x), x) + # entropy + @test entropy(d) ā‰ˆ mean(x -> -logpdf(d, x), x) atol = 1e-1 + end + end + + @testset "DiscreteUniform" begin + d0 = DiscreteUniform(0, 10) + bounds = [ + (nothing, 8), + (-Inf, 8), + (nothing, Inf), + (2, nothing), + (2, Inf), + (-Inf, nothing), + (2, 8), + (3.5, nothing), + (3.5, Inf), + (-Inf, Inf), + ] + @testset "lower = $(lower === nothing ? "nothing" : lower), upper = $(upper === nothing ? 
"nothing" : upper)" for (lower, upper) in bounds
+            d = censored(d0, lower, upper)
+            dmix = _as_mixture(d)
+            @test extrema(d) == extrema(dmix)
+            l, u = extrema(d)
+            @testset for f in [pdf, logpdf, cdf, logcdf, ccdf, logccdf]
+                @test @inferred(f(d, l)) ≈ f(dmix, l) atol=1e-8
+                @test @inferred(f(d, l - 0.1)) ≈ f(dmix, l - 0.1) atol=1e-8
+                @test @inferred(f(d, u)) ≈ f(dmix, u) atol=1e-8
+                @test @inferred(f(d, u + 0.1)) ≈ f(dmix, u + 0.1) atol=1e-8
+                @test @inferred(f(d, 5)) ≈ f(dmix, 5)
+            end
+            @testset for f in [mean, var]
+                @test @inferred(f(d)) ≈ f(dmix)
+            end
+            @test @inferred(median(d)) ≈ clamp(median(d0), l, u)
+            @inferred quantile(d, 0.5)
+            @test quantile.(d, 0:0.01:1) ≈ clamp.(quantile.(d0, 0:0.01:1), l, u)
+            # rand: every draw must lie in the support of the censored distribution
+            x = rand(d, 10_000)
+            @test all(x -> insupport(d, x), x)
+            # loglikelihood: must agree with the mixture representation
+            @test @inferred(loglikelihood(d, x)) ≈ loglikelihood(dmix, x)
+            # mean, std: pmf-weighted sums over the sampled values, then sample moments
+            μ = @inferred mean(d)
+            xall = unique(x)
+            @test μ ≈ sum(x -> pdf(d, x) * x, xall)
+            @test mean(x) ≈ μ atol = 1e-1
+            v = @inferred var(d)
+            @test v ≈ sum(x -> pdf(d, x) * abs2(x - μ), xall)
+            @test std(x) ≈ sqrt(v) atol = 1e-1
+            # entropy: pmf-weighted sum of -logpdf over the sampled values
+            @test @inferred(entropy(d)) ≈ sum(x -> pdf(d, x) * -logpdf(d, x), xall)
+        end
+    end
+
+    @testset "Poisson" begin
+        d0 = Poisson(20)
+        bounds = [(nothing, 12), (2, nothing), (2, 12), (8, nothing)]
+        @testset "lower = $(lower === nothing ? "nothing" : lower), upper = $(upper === nothing ? 
"nothing" : upper)" for (lower, upper) in bounds
+            d = censored(d0, lower, upper)
+            dmix = _as_mixture(d)
+            @test extrema(d) == extrema(dmix)
+            l, u = extrema(d)
+            @testset for f in [pdf, logpdf, cdf, logcdf, ccdf, logccdf]
+                @test f(d, l) ≈ f(dmix, l) atol=1e-8
+                @test f(d, l - 0.1) ≈ f(dmix, l - 0.1) atol=1e-8
+                @test f(d, u) ≈ f(dmix, u) atol=1e-8
+                @test f(d, u + 0.1) ≈ f(dmix, u + 0.1) atol=1e-8
+                @test f(d, 5) ≈ f(dmix, 5)
+            end
+            @test median(d) ≈ clamp(median(d0), l, u)
+            @test quantile.(d, 0:0.01:0.99) ≈ clamp.(quantile.(d0, 0:0.01:0.99), l, u)
+            x = rand(d, 100)
+            @test loglikelihood(d, x) ≈ loglikelihood(dmix, x)
+            # rand: every draw must lie in the support of the censored distribution
+            x = rand(d, 10_000)
+            @test all(x -> insupport(d, x), x)
+            # mean, std: sample moments vs pmf-weighted sums over sampled values (atol >> sampling error)
+            @test mean(x) ≈ sum(k -> pdf(d, k) * k, unique(x)) atol = 2e-1
+            @test std(x) ≈ sqrt(sum(k -> pdf(d, k) * abs2(k - mean(x)), unique(x))) atol = 2e-1
+        end
+    end
+
+    @testset "mixed types are still type-inferrible" begin
+        bounds = [(nothing, 8), (2, nothing), (2, 8)]
+        @testset "lower = $(lower === nothing ? "nothing" : lower), upper = $(upper === nothing ? "nothing" : upper), uncensored partype=$T0, partype=$T" for (lower, upper) in bounds,
+            T in (Int, Float32, Float64), T0 in (Int, Float32, Float64)
+            d0 = Uniform(T0(0), T0(10))
+            d = censored(d0, lower === nothing ? nothing : T(lower), upper === nothing ? 
nothing : T(upper))
+            l, u = extrema(d)  # NOTE(review): l and u are not used below in this testset
+            @testset for f in [pdf, logpdf, cdf, logcdf, ccdf, logccdf]
+                @inferred f(d, 3)
+                @inferred f(d, 4f0)
+                @inferred f(d, 5.0)
+            end
+            @testset for f in [median, mean, var, entropy]
+                @inferred f(d)
+            end
+            @inferred quantile(d, 0.3f0)
+            @inferred quantile(d, 0.5)
+            x = randn(Float32, 100)
+            @inferred loglikelihood(d, x)
+            x = randn(100)
+            @inferred loglikelihood(d, x)
+        end
+    end
+end
+
+end # module
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 314dea3ea..a08d4620a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -18,6 +18,8 @@ const tests = [
     "truncnormal",
     "truncated_exponential",
     "truncated_uniform",
+    "truncated_discrete_uniform",
+    "censored",
     "normal",
     "laplace",
     "cauchy",
diff --git a/test/truncate.jl b/test/truncate.jl
index d08a6fc77..457cad218 100644
--- a/test/truncate.jl
+++ b/test/truncate.jl
@@ -39,7 +39,7 @@ function verify_and_test_drive(jsonfile, selected, n_tsamples::Int,lower::Int,up
         println(" testing truncated($(ex),$lower,$upper)")
         d = truncated(eval(Meta.parse(ex)),lower,upper)
-        if dtype != Uniform && dtype != TruncatedNormal # Uniform is truncated to Uniform
+        if dtype != Uniform && dtype != DiscreteUniform && dtype != TruncatedNormal # truncating Uniform/DiscreteUniform returns the same family
             @assert isa(dtype, Type) && dtype <: UnivariateDistribution
             @test isa(d, dtypet)
             # verification and testing
diff --git a/test/truncated_discrete_uniform.jl b/test/truncated_discrete_uniform.jl
new file mode 100644
index 000000000..acc87896f
--- /dev/null
+++ b/test/truncated_discrete_uniform.jl
@@ -0,0 +1,20 @@
+using Distributions, Test
+
+@testset "truncated DiscreteUniform" begin
+    # only check that truncation returns the equivalent DiscreteUniform
+    bounds = [(1, 10), (-3, 7), (-5, -2)]
+    @testset "lower=$lower, upper=$upper" for (lower, upper) in bounds
+        d = DiscreteUniform(lower, upper)
+        @test truncated(d, -Inf, Inf) == d
+        @test truncated(d, nothing, nothing) === d
+        @test truncated(d, lower - 0.1, Inf) == 
d
+        @test truncated(d, lower - 0.1, nothing) == d
+        @test truncated(d, -Inf, upper + 0.1) == d
+        @test truncated(d, nothing, upper + 0.1) == d
+        @test truncated(d, lower + 0.3, Inf) == DiscreteUniform(lower + 1, upper)  # fractional lower bound moves up to the next integer
+        @test truncated(d, lower + 0.3, nothing) == DiscreteUniform(lower + 1, upper)
+        @test truncated(d, -Inf, upper - 0.5) == DiscreteUniform(lower, upper - 1)  # fractional upper bound moves down to the previous integer
+        @test truncated(d, nothing, upper - 0.5) == DiscreteUniform(lower, upper - 1)
+        @test truncated(d, lower + 1.5, upper - 1) == DiscreteUniform(lower + 2, upper - 1)  # both bounds at once
+    end
+end