Skip to content

Commit

Permalink
Store nothing bounds in Truncated (#1720)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 15, 2023
1 parent 9cf6a74 commit 32cea9b
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 64 deletions.
103 changes: 63 additions & 40 deletions src/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,26 @@ function truncated(d::UnivariateDistribution, ::Nothing, u::Real)
logucdf = logtp = logcdf(d, u)
ucdf = tp = exp(logucdf)

Truncated(d, promote(oftype(float(u), -Inf), u, oftype(ucdf, -Inf), zero(ucdf), ucdf, tp, logtp)...)
Truncated(d, nothing, promote(u, oftype(ucdf, -Inf), zero(ucdf), ucdf, tp, logtp)...)
end
function truncated(d::UnivariateDistribution, l::Real, ::Nothing)
# (log)lcdf = (log) P(X < l) where X ~ d
loglcdf = if value_support(typeof(d)) === Discrete
logsubexp(logcdf(d, l), logpdf(d, l))
else
logcdf(d, l)
end
loglcdf = _logcdf_noninclusive(d, l)
lcdf = exp(loglcdf)

# (log)tp = (log) P(l ≤ X) where X ∼ d
logtp = log1mexp(loglcdf)
tp = exp(logtp)

Truncated(d, promote(l, oftype(float(l), Inf), loglcdf, lcdf, one(lcdf), tp, logtp)...)
l, loglcdf, lcdf, ucdf, tp, logtp = promote(l, loglcdf, lcdf, one(lcdf), tp, logtp)
Truncated(d, l, nothing, loglcdf, lcdf, ucdf, tp, logtp)
end
# With neither a lower nor an upper bound there is nothing to truncate:
# return the distribution `d` itself rather than wrapping it.
truncated(d::UnivariateDistribution, ::Nothing, ::Nothing) = d
function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real}
l <= u || error("the lower bound must be less or equal than the upper bound")

# (log)lcdf = (log) P(X < l) where X ~ d
loglcdf = if value_support(typeof(d)) === Discrete
logsubexp(logcdf(d, l), logpdf(d, l))
else
logcdf(d, l)
end
loglcdf = _logcdf_noninclusive(d, l)
lcdf = exp(loglcdf)

# (log)ucdf = (log) P(X ≤ u) where X ~ d
Expand All @@ -90,40 +83,65 @@ end
Generic wrapper for a truncated distribution.
"""
struct Truncated{D<:UnivariateDistribution, S<:ValueSupport, T <: Real} <: UnivariateDistribution{S}
struct Truncated{D<:UnivariateDistribution, S<:ValueSupport, T<: Real, TL<:Union{T,Nothing}, TU<:Union{T,Nothing}} <: UnivariateDistribution{S}
untruncated::D # the original distribution (untruncated)
lower::T # lower bound
upper::T # upper bound
lower::TL # lower bound
upper::TU # upper bound
loglcdf::T # log-cdf of lower bound (exclusive): log P(X < lower)
lcdf::T # cdf of lower bound (exclusive): P(X < lower)
ucdf::T # cdf of upper bound (inclusive): P(X ≤ upper)

tp::T # the probability of the truncated part, i.e. ucdf - lcdf
logtp::T # log(tp), i.e. log(ucdf - lcdf)

function Truncated(d::UnivariateDistribution, l::T, u::T, loglcdf::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real}
new{typeof(d), value_support(typeof(d)), T}(d, l, u, loglcdf, lcdf, ucdf, tp, logtp)
function Truncated(d::UnivariateDistribution, l::TL, u::TU, loglcdf::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real, TL <: Union{T,Nothing}, TU <: Union{T,Nothing}}
new{typeof(d), value_support(typeof(d)), T, TL, TU}(d, l, u, loglcdf, lcdf, ucdf, tp, logtp)
end
end

# Aliases for one-sided truncations, selected by the types of the stored bounds:
# - `LeftTruncated`: the lower bound has type `T`, the upper bound is `nothing`.
# - `RightTruncated`: the lower bound is `nothing`, the upper bound has type `T`.
const LeftTruncated{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Truncated{D,S,T,T,Nothing}
const RightTruncated{D<:UnivariateDistribution,S<:ValueSupport,T<:Real} = Truncated{D,S,T,Nothing,T}

### Constructors of `Truncated` are deprecated - users should call `truncated`
# Three-argument form: forwarded unchanged to `truncated(d, l, u)`.
@deprecate Truncated(d::UnivariateDistribution, l::Real, u::Real) truncated(d, l, u)
# Seven-argument form without `loglcdf`: forwarded with `log(lcdf)` inserted as the
# `loglcdf` argument.
@deprecate Truncated(d::UnivariateDistribution, l::T, u::T, lcdf::T, ucdf::T, tp::T, logtp::T) where {T <: Real} Truncated(d, l, u, log(lcdf), lcdf, ucdf, tp, logtp)

function truncated(d::Truncated, l::T, u::T) where {T<:Real}
    # Truncating an already truncated distribution: intersect the new interval
    # [l, u] with the stored bounds (a stored `nothing` imposes no constraint)
    # and truncate the underlying untruncated distribution once.
    lower = d.lower === nothing ? l : max(l, d.lower)
    upper = d.upper === nothing ? u : min(u, d.upper)
    return truncated(d.untruncated, lower, upper)
end
function truncated(d::Truncated, ::Nothing, u::Real)
    # Only a new upper bound is given: keep the tighter of `u` and the stored
    # upper bound, and pass the stored lower bound (possibly `nothing`) through.
    upper = d.upper === nothing ? u : min(u, d.upper)
    return truncated(d.untruncated, d.lower, upper)
end
function truncated(d::Truncated, l::Real, ::Nothing)
    # Only a new lower bound is given: keep the tighter of `l` and the stored
    # lower bound, and pass the stored upper bound (possibly `nothing`) through.
    lower = d.lower === nothing ? l : max(l, d.lower)
    return truncated(d.untruncated, lower, d.upper)
end

# Parameters of the untruncated distribution followed by the stored lower and upper
# bounds; a bound that was not specified appears as `nothing`.
params(d::Truncated) = tuple(params(d.untruncated)..., d.lower, d.upper)
partype(d::Truncated) = partype(d.untruncated)
Base.eltype(::Type{Truncated{D, S, T} } ) where {D, S, T} = T
partype(d::Truncated{<:UnivariateDistribution,<:ValueSupport,T}) where {T<:Real} = promote_type(partype(d.untruncated), T)

# The element type is taken from the untruncated distribution (type `D`), both for
# the type-level and the instance-level method.
Base.eltype(::Type{<:Truncated{D}}) where {D<:UnivariateDistribution} = eltype(D)
Base.eltype(d::Truncated) = eltype(d.untruncated)

### range and support

# `RightTruncated` stores `nothing` as its lower bound, so only the untruncated
# distribution can bound it from below.
islowerbounded(d::RightTruncated) = islowerbounded(d.untruncated)
# Otherwise a finite stored lower bound is sufficient as well.
islowerbounded(d::Truncated) = islowerbounded(d.untruncated) || isfinite(d.lower)

# `LeftTruncated` stores `nothing` as its upper bound, so only the untruncated
# distribution can bound it from above.
isupperbounded(d::LeftTruncated) = isupperbounded(d.untruncated)
# Otherwise a finite stored upper bound is sufficient as well.
isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper)

# No stored lower bound: the minimum is that of the untruncated distribution.
minimum(d::RightTruncated) = minimum(d.untruncated)
# Otherwise the larger of the untruncated minimum and the stored lower bound.
minimum(d::Truncated) = max(minimum(d.untruncated), d.lower)

# No stored upper bound: the maximum is that of the untruncated distribution.
maximum(d::LeftTruncated) = maximum(d.untruncated)
# Otherwise the smaller of the untruncated maximum and the stored upper bound.
maximum(d::Truncated) = min(maximum(d.untruncated), d.upper)

function insupport(d::Truncated{D,<:Union{Discrete,Continuous}}, x::Real) where {D<:UnivariateDistribution}
return d.lower <= x <= d.upper && insupport(d.untruncated, x)
function insupport(d::Truncated{<:UnivariateDistribution,<:Union{Discrete,Continuous}}, x::Real)
return _in_closed_interval(x, d.lower, d.upper) && insupport(d.untruncated, x)
end

### evaluation
Expand All @@ -132,19 +150,19 @@ quantile(d::Truncated, p::Real) = quantile(d.untruncated, d.lcdf + p * d.tp)

function pdf(d::Truncated, x::Real)
result = pdf(d.untruncated, x) / d.tp
return d.lower <= x <= d.upper ? result : zero(result)
return _in_closed_interval(x, d.lower, d.upper) ? result : zero(result)
end

function logpdf(d::Truncated, x::Real)
result = logpdf(d.untruncated, x) - d.logtp
return d.lower <= x <= d.upper ? result : oftype(result, -Inf)
return _in_closed_interval(x, d.lower, d.upper) ? result : oftype(result, -Inf)
end

function cdf(d::Truncated, x::Real)
result = (cdf(d.untruncated, x) - d.lcdf) / d.tp
return if x < d.lower
return if d.lower !== nothing && x < d.lower
zero(result)
elseif x >= d.upper
elseif d.upper !== nothing && x >= d.upper
one(result)
else
result
Expand All @@ -153,9 +171,9 @@ end

function logcdf(d::Truncated, x::Real)
result = logsubexp(logcdf(d.untruncated, x), d.loglcdf) - d.logtp
return if x < d.lower
return if d.lower !== nothing && x < d.lower
oftype(result, -Inf)
elseif x >= d.upper
elseif d.upper !== nothing && x >= d.upper
zero(result)
else
result
Expand All @@ -164,9 +182,9 @@ end

function ccdf(d::Truncated, x::Real)
result = (d.ucdf - cdf(d.untruncated, x)) / d.tp
return if x <= d.lower
return if d.lower !== nothing && x <= d.lower
one(result)
elseif x > d.upper
elseif d.upper !== nothing && x > d.upper
zero(result)
else
result
Expand All @@ -175,9 +193,9 @@ end

function logccdf(d::Truncated, x::Real)
result = logsubexp(logccdf(d.untruncated, x), log1p(-d.ucdf)) - d.logtp
return if x <= d.lower
return if d.lower !== nothing && x <= d.lower
zero(result)
elseif x > d.upper
elseif d.upper !== nothing && x > d.upper
oftype(result, -Inf)
else
result
Expand All @@ -189,10 +207,12 @@ end
function rand(rng::AbstractRNG, d::Truncated)
d0 = d.untruncated
tp = d.tp
lower = d.lower
upper = d.upper
if tp > 0.25
while true
r = rand(rng, d0)
if d.lower <= r <= d.upper
if _in_closed_interval(r, lower, upper)
return r
end
end
Expand All @@ -212,16 +232,12 @@ function show(io::IO, d::Truncated)
uml, namevals = _use_multline_show(d0)
uml ? show_multline(io, d0, namevals) :
show_oneline(io, d0, namevals)
if d.lower > -Inf
if d.upper < Inf
print(io, "; lower=$(d.lower), upper=$(d.upper))")
else
print(io, "; lower=$(d.lower))")
end
elseif d.upper < Inf
if d.lower === nothing
print(io, "; upper=$(d.upper))")
elseif d.upper === nothing
print(io, "; lower=$(d.lower))")
else
print(io, ")")
print(io, "; lower=$(d.lower), upper=$(d.upper))")
end
uml && println(io)
end
Expand All @@ -236,3 +252,10 @@ include(joinpath("truncated", "exponential.jl"))
include(joinpath("truncated", "uniform.jl"))
include(joinpath("truncated", "loguniform.jl"))
include(joinpath("truncated", "discrete_uniform.jl"))

#### Utilities

# Utilities to handle closed intervals represented with possibly `nothing` bounds.
# A `nothing` bound means "unbounded on that side", so only the other bound is
# checked; both comparisons are inclusive.
_in_closed_interval(x::Real, l::Real, u::Real) = l ≤ x ≤ u
_in_closed_interval(x::Real, ::Nothing, u::Real) = x ≤ u
_in_closed_interval(x::Real, l::Real, ::Nothing) = x ≥ l
41 changes: 20 additions & 21 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@ TruncatedNormal

### statistics

minimum(d::Truncated{Normal{T},Continuous}) where {T <: Real} = d.lower
maximum(d::Truncated{Normal{T},Continuous}) where {T <: Real} = d.upper


function mode(d::Truncated{Normal{T},Continuous}) where T <: Real
function mode(d::Truncated{<:Normal{<:Real},Continuous})
μ = mean(d.untruncated)
d.upper < μ ? d.upper :
d.lower > μ ? d.lower : μ
return clamp(μ, extrema(d)...)
end

modes(d::Truncated{Normal{T},Continuous}) where {T <: Real} = [mode(d)]
modes(d::Truncated{<:Normal{<:Real},Continuous}) = [mode(d)]

# do not export. Used in mean
# computes mean of standard normal distribution truncated to [a, b]
Expand Down Expand Up @@ -94,39 +89,42 @@ function _tnvar(a::Real, b::Real)
end
end

function mean(d::Truncated{Normal{T},Continuous}) where T <: Real
function mean(d::Truncated{<:Normal{<:Real},Continuous})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
if iszero(σ)
return mode(d)
else
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
lower, upper = extrema(d)
a = (lower - μ) / σ
b = (upper - μ) / σ
return μ + _tnmom1(a, b) * σ
end
end

function var(d::Truncated{Normal{T},Continuous}) where T <: Real
function var(d::Truncated{<:Normal{<:Real},Continuous})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
if iszero(σ)
return σ
else
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
lower, upper = extrema(d)
a = (lower - μ) / σ
b = (upper - μ) / σ
return _tnvar(a, b) * σ^2
end
end

function entropy(d::Truncated{Normal{T},Continuous}) where T <: Real
function entropy(d::Truncated{<:Normal{<:Real},Continuous})
d0 = d.untruncated
z = d.tp
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
lower, upper = extrema(d)
a = (lower - μ) / σ
b = (upper - μ) / σ
aφa = isinf(a) ? 0.0 : a * normpdf(a)
bφb = isinf(b) ? 0.0 : b * normpdf(b)
0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z)
Expand All @@ -138,17 +136,18 @@ end
## Use specialized sampler, as quantile-based method is inaccurate in
## tail regions of the Normal, issue #343

function rand(rng::AbstractRNG, d::Truncated{Normal{T},Continuous}) where T <: Real
function rand(rng::AbstractRNG, d::Truncated{<:Normal{<:Real},Continuous})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
if isfinite(μ)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
lower, upper = extrema(d)
a = (lower - μ) / σ
b = (upper - μ) / σ
z = randnt(rng, a, b, d.tp)
return μ + σ * z
else
return clamp(μ, d.lower, d.upper)
return clamp(μ, extrema(d)...)
end
end

Expand Down
12 changes: 9 additions & 3 deletions test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function verify_and_test(d::UnivariateDistribution, dct::Dict, n_tsamples::Int)
end
@test cdf(d, x) ≈ cf atol=sqrt(eps())
# NOTE: some distributions use pdf() in StatsFuns.jl which have no generic support yet
if !(typeof(d) in [Distributions.Truncated{Distributions.NoncentralChisq{Float64},Distributions.Continuous, Float64},
if !any(T -> d isa T, [Distributions.Truncated{Distributions.NoncentralChisq{Float64},Distributions.Continuous, Float64},
Distributions.Truncated{Distributions.NoncentralF{Float64},Distributions.Continuous, Float64},
Distributions.Truncated{Distributions.NoncentralT{Float64},Distributions.Continuous, Float64},
Distributions.Truncated{Distributions.StudentizedRange{Float64},Distributions.Continuous, Float64},
Expand Down Expand Up @@ -133,12 +133,18 @@ for (μ, lower, upper) in [(0, -1, 1), (1, 2, 4)]
end
for bound in (-2, 1)
d = @test_deprecated Distributions.Truncated(Normal(), Float64(bound), Inf)
@test truncated(Normal(); lower=bound) == d
@test truncated(Normal(); lower=bound, upper=Inf) == d

d_nothing = truncated(Normal(); lower=bound)
@test truncated(Normal(); lower=bound, upper=nothing) == d_nothing
@test extrema(d_nothing) == promote(bound, Inf)

d = @test_deprecated Distributions.Truncated(Normal(), -Inf, Float64(bound))
@test truncated(Normal(); upper=bound) == d
@test truncated(Normal(); lower=-Inf, upper=bound) == d

d_nothing = truncated(Normal(); upper=bound)
@test truncated(Normal(); lower=nothing, upper=bound) == d_nothing
@test extrema(d_nothing) == promote(-Inf, bound)
end
@test truncated(Normal()) === Normal()

Expand Down

0 comments on commit 32cea9b

Please sign in to comment.