diff --git a/src/truncate.jl b/src/truncate.jl index 2f0ce87cab..804c709fad 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -45,33 +45,26 @@ function truncated(d::UnivariateDistribution, ::Nothing, u::Real) logucdf = logtp = logcdf(d, u) ucdf = tp = exp(logucdf) - Truncated(d, promote(oftype(float(u), -Inf), u, oftype(ucdf, -Inf), zero(ucdf), ucdf, tp, logtp)...) + Truncated(d, nothing, promote(u, oftype(ucdf, -Inf), zero(ucdf), ucdf, tp, logtp)...) end function truncated(d::UnivariateDistribution, l::Real, ::Nothing) # (log)lcdf = (log) P(X < l) where X ~ d - loglcdf = if value_support(typeof(d)) === Discrete - logsubexp(logcdf(d, l), logpdf(d, l)) - else - logcdf(d, l) - end + loglcdf = _logcdf_noninclusive(d, l) lcdf = exp(loglcdf) # (log)tp = (log) P(l ≤ X) where X ∼ d logtp = log1mexp(loglcdf) tp = exp(logtp) - Truncated(d, promote(l, oftype(float(l), Inf), loglcdf, lcdf, one(lcdf), tp, logtp)...) + l, loglcdf, lcdf, ucdf, tp, logtp = promote(l, loglcdf, lcdf, one(lcdf), tp, logtp) + Truncated(d, l, nothing, loglcdf, lcdf, ucdf, tp, logtp) end truncated(d::UnivariateDistribution, ::Nothing, ::Nothing) = d function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real} l <= u || error("the lower bound must be less or equal than the upper bound") # (log)lcdf = (log) P(X < l) where X ~ d - loglcdf = if value_support(typeof(d)) === Discrete - logsubexp(logcdf(d, l), logpdf(d, l)) - else - logcdf(d, l) - end + loglcdf = _logcdf_noninclusive(d, l) lcdf = exp(loglcdf) # (log)ucdf = (log) P(X ≤ u) where X ~ d @@ -90,10 +83,10 @@ end Generic wrapper for a truncated distribution. """ -struct Truncated{D<:UnivariateDistribution, S<:ValueSupport, T <: Real} <: UnivariateDistribution{S} +struct Truncated{D<:UnivariateDistribution, S<:ValueSupport, T<: Real, TL<:Union{T,Nothing}, TU<:Union{T,Nothing}} <: UnivariateDistribution{S} untruncated::D # the original distribution (untruncated) - lower::T # lower bound - upper::T # upper bound + lower::TL # lower bound + upper::TU # upper bound loglcdf::T # log-cdf of lower bound (exclusive): log P(X < lower) lcdf::T # cdf of lower bound (exclusive): P(X < lower) ucdf::T # cdf of upper bound (inclusive): P(X ≤ upper) @@ -101,29 +94,54 @@ struct Truncated{D<:UnivariateDistribution, S<:ValueSupport, T <: Real} <: Univa tp::T # the probability of the truncated part, i.e. ucdf - lcdf logtp::T # log(tp), i.e. log(ucdf - lcdf) - function Truncated(d::UnivariateDistribution, l::T, u::T, loglcdf::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real} - new{typeof(d), value_support(typeof(d)), T}(d, l, u, loglcdf, lcdf, ucdf, tp, logtp) + function Truncated(d::UnivariateDistribution, l::TL, u::TU, loglcdf::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real, TL <: Union{T,Nothing}, TU <: Union{T,Nothing}} + new{typeof(d), value_support(typeof(d)), T, TL, TU}(d, l, u, loglcdf, lcdf, ucdf, tp, logtp) end end +const LeftTruncated{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Truncated{D,S,T,T,Nothing} +const RightTruncated{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Truncated{D,S,T,Nothing,T} + ### Constructors of `Truncated` are deprecated - users should call `truncated` @deprecate Truncated(d::UnivariateDistribution, l::Real, u::Real) truncated(d, l, u) @deprecate Truncated(d::UnivariateDistribution, l::T, u::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real} Truncated(d, l, u, log(lcdf), lcdf, ucdf, tp, logtp) +function truncated(d::Truncated, l::T, u::T) where {T<:Real} + return truncated( + d.untruncated, + d.lower === nothing ? l : max(l, d.lower), + d.upper === nothing ? u : min(u, d.upper), + ) +end +function truncated(d::Truncated, ::Nothing, u::Real) + return truncated(d.untruncated, d.lower, d.upper === nothing ? u : min(u, d.upper)) +end +function truncated(d::Truncated, l::Real, ::Nothing) + return truncated(d.untruncated, d.lower === nothing ? l : max(l, d.lower), d.upper) +end + params(d::Truncated) = tuple(params(d.untruncated)..., d.lower, d.upper) -partype(d::Truncated) = partype(d.untruncated) -Base.eltype(::Type{Truncated{D, S, T} } ) where {D, S, T} = T +partype(d::Truncated{<:UnivariateDistribution,<:ValueSupport,T}) where {T<:Real} = promote_type(partype(d.untruncated), T) + +Base.eltype(::Type{<:Truncated{D}}) where {D<:UnivariateDistribution} = eltype(D) +Base.eltype(d::Truncated) = eltype(d.untruncated) ### range and support +islowerbounded(d::RightTruncated) = islowerbounded(d.untruncated) islowerbounded(d::Truncated) = islowerbounded(d.untruncated) || isfinite(d.lower) + +isupperbounded(d::LeftTruncated) = isupperbounded(d.untruncated) isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper) +minimum(d::RightTruncated) = minimum(d.untruncated) minimum(d::Truncated) = max(minimum(d.untruncated), d.lower) + +maximum(d::LeftTruncated) = maximum(d.untruncated) maximum(d::Truncated) = min(maximum(d.untruncated), d.upper) -function insupport(d::Truncated{D,<:Union{Discrete,Continuous}}, x::Real) where {D<:UnivariateDistribution} - return d.lower <= x <= d.upper && insupport(d.untruncated, x) +function insupport(d::Truncated{<:UnivariateDistribution,<:Union{Discrete,Continuous}}, x::Real) + return _in_closed_interval(x, d.lower, d.upper) && insupport(d.untruncated, x) end ### evaluation @@ -132,19 +150,19 @@ quantile(d::Truncated, p::Real) = quantile(d.untruncated, d.lcdf + p * d.tp) function pdf(d::Truncated, x::Real) result = pdf(d.untruncated, x) / d.tp - return d.lower <= x <= d.upper ? result : zero(result) + return _in_closed_interval(x, d.lower, d.upper) ? result : zero(result) end function logpdf(d::Truncated, x::Real) result = logpdf(d.untruncated, x) - d.logtp - return d.lower <= x <= d.upper ? result : oftype(result, -Inf) + return _in_closed_interval(x, d.lower, d.upper) ? result : oftype(result, -Inf) end function cdf(d::Truncated, x::Real) result = (cdf(d.untruncated, x) - d.lcdf) / d.tp - return if x < d.lower + return if d.lower !== nothing && x < d.lower zero(result) - elseif x >= d.upper + elseif d.upper !== nothing && x >= d.upper one(result) else result @@ -153,9 +171,9 @@ end function logcdf(d::Truncated, x::Real) result = logsubexp(logcdf(d.untruncated, x), d.loglcdf) - d.logtp - return if x < d.lower + return if d.lower !== nothing && x < d.lower oftype(result, -Inf) - elseif x >= d.upper + elseif d.upper !== nothing && x >= d.upper zero(result) else result @@ -164,9 +182,9 @@ end function ccdf(d::Truncated, x::Real) result = (d.ucdf - cdf(d.untruncated, x)) / d.tp - return if x <= d.lower + return if d.lower !== nothing && x <= d.lower one(result) - elseif x > d.upper + elseif d.upper !== nothing && x > d.upper zero(result) else result @@ -175,9 +193,9 @@ end function logccdf(d::Truncated, x::Real) result = logsubexp(logccdf(d.untruncated, x), log1p(-d.ucdf)) - d.logtp - return if x <= d.lower + return if d.lower !== nothing && x <= d.lower zero(result) - elseif x > d.upper + elseif d.upper !== nothing && x > d.upper oftype(result, -Inf) else result @@ -189,10 +207,12 @@ end function rand(rng::AbstractRNG, d::Truncated) d0 = d.untruncated tp = d.tp + lower = d.lower + upper = d.upper if tp > 0.25 while true r = rand(rng, d0) - if d.lower <= r <= d.upper + if _in_closed_interval(r, lower, upper) return r end end @@ -212,16 +232,12 @@ function show(io::IO, d::Truncated) uml, namevals = _use_multline_show(d0) uml ? show_multline(io, d0, namevals) : show_oneline(io, d0, namevals) - if d.lower > -Inf - if d.upper < Inf - print(io, "; lower=$(d.lower), upper=$(d.upper))") - else - print(io, "; lower=$(d.lower))") - end - elseif d.upper < Inf + if d.lower === nothing print(io, "; upper=$(d.upper))") + elseif d.upper === nothing + print(io, "; lower=$(d.lower))") else - print(io, ")") + print(io, "; lower=$(d.lower), upper=$(d.upper))") end uml && println(io) end @@ -236,3 +252,10 @@ include(joinpath("truncated", "exponential.jl")) include(joinpath("truncated", "uniform.jl")) include(joinpath("truncated", "loguniform.jl")) include(joinpath("truncated", "discrete_uniform.jl")) + +#### Utilities + +# utilities to handle closed intervals represented with possibly `nothing` bounds +_in_closed_interval(x::Real, l::Real, u::Real) = l ≤ x ≤ u +_in_closed_interval(x::Real, ::Nothing, u::Real) = x ≤ u +_in_closed_interval(x::Real, l::Real, ::Nothing) = x ≥ l diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index dac92f11a1..6b79b5dae3 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -12,17 +12,12 @@ TruncatedNormal ### statistics -minimum(d::Truncated{Normal{T},Continuous}) where {T <: Real} = d.lower -maximum(d::Truncated{Normal{T},Continuous}) where {T <: Real} = d.upper - - -function mode(d::Truncated{Normal{T},Continuous}) where T <: Real +function mode(d::Truncated{<:Normal{<:Real},Continuous}) μ = mean(d.untruncated) - d.upper < μ ? d.upper : - d.lower > μ ? d.lower : μ + return clamp(μ, extrema(d)...) end -modes(d::Truncated{Normal{T},Continuous}) where {T <: Real} = [mode(d)] +modes(d::Truncated{<:Normal{<:Real},Continuous}) = [mode(d)] # do not export. Used in mean # computes mean of standard normal distribution truncated to [a, b] @@ -94,39 +89,42 @@ function _tnvar(a::Real, b::Real) end end -function mean(d::Truncated{Normal{T},Continuous}) where T <: Real +function mean(d::Truncated{<:Normal{<:Real},Continuous}) d0 = d.untruncated μ = mean(d0) σ = std(d0) if iszero(σ) return mode(d) else - a = (d.lower - μ) / σ - b = (d.upper - μ) / σ + lower, upper = extrema(d) + a = (lower - μ) / σ + b = (upper - μ) / σ return μ + _tnmom1(a, b) * σ end end -function var(d::Truncated{Normal{T},Continuous}) where T <: Real +function var(d::Truncated{<:Normal{<:Real},Continuous}) d0 = d.untruncated μ = mean(d0) σ = std(d0) if iszero(σ) return σ else - a = (d.lower - μ) / σ - b = (d.upper - μ) / σ + lower, upper = extrema(d) + a = (lower - μ) / σ + b = (upper - μ) / σ return _tnvar(a, b) * σ^2 end end -function entropy(d::Truncated{Normal{T},Continuous}) where T <: Real +function entropy(d::Truncated{<:Normal{<:Real},Continuous}) d0 = d.untruncated z = d.tp μ = mean(d0) σ = std(d0) - a = (d.lower - μ) / σ - b = (d.upper - μ) / σ + lower, upper = extrema(d) + a = (lower - μ) / σ + b = (upper - μ) / σ aφa = isinf(a) ? 0.0 : a * normpdf(a) bφb = isinf(b) ? 0.0 : b * normpdf(b) 0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z) @@ -138,17 +136,18 @@ end ## Use specialized sampler, as quantile-based method is inaccurate in ## tail regions of the Normal, issue #343 -function rand(rng::AbstractRNG, d::Truncated{Normal{T},Continuous}) where T <: Real +function rand(rng::AbstractRNG, d::Truncated{<:Normal{<:Real},Continuous}) d0 = d.untruncated μ = mean(d0) σ = std(d0) if isfinite(μ) - a = (d.lower - μ) / σ - b = (d.upper - μ) / σ + lower, upper = extrema(d) + a = (lower - μ) / σ + b = (upper - μ) / σ z = randnt(rng, a, b, d.tp) return μ + σ * z else - return clamp(μ, d.lower, d.upper) + return clamp(μ, extrema(d)...) end end diff --git a/test/truncate.jl b/test/truncate.jl index 1ca48e4c5b..b9c0b42635 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -78,7 +78,7 @@ function verify_and_test(d::UnivariateDistribution, dct::Dict, n_tsamples::Int) end @test cdf(d, x) ≈ cf atol=sqrt(eps()) # NOTE: some distributions use pdf() in StatsFuns.jl which have no generic support yet - if !(typeof(d) in [Distributions.Truncated{Distributions.NoncentralChisq{Float64},Distributions.Continuous, Float64}, + if !any(T -> d isa T, [Distributions.Truncated{Distributions.NoncentralChisq{Float64},Distributions.Continuous, Float64}, Distributions.Truncated{Distributions.NoncentralF{Float64},Distributions.Continuous, Float64}, Distributions.Truncated{Distributions.NoncentralT{Float64},Distributions.Continuous, Float64}, Distributions.Truncated{Distributions.StudentizedRange{Float64},Distributions.Continuous, Float64}, @@ -133,12 +133,18 @@ for (μ, lower, upper) in [(0, -1, 1), (1, 2, 4)] end for bound in (-2, 1) d = @test_deprecated Distributions.Truncated(Normal(), Float64(bound), Inf) - @test truncated(Normal(); lower=bound) == d @test truncated(Normal(); lower=bound, upper=Inf) == d + d_nothing = truncated(Normal(); lower=bound) + @test truncated(Normal(); lower=bound, upper=nothing) == d_nothing + @test extrema(d_nothing) == promote(bound, Inf) + d = @test_deprecated Distributions.Truncated(Normal(), -Inf, Float64(bound)) - @test truncated(Normal(); upper=bound) == d @test truncated(Normal(); lower=-Inf, upper=bound) == d + + d_nothing = truncated(Normal(); upper=bound) + @test truncated(Normal(); lower=nothing, upper=bound) == d_nothing + @test extrema(d_nothing) == promote(-Inf, bound) end @test truncated(Normal()) === Normal()