diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 28f06f39..04a33fc6 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -210,6 +210,8 @@ function link( return Y end +clamp0to1(x::T) where T = clamp(x, zero(T), one(T)) + function invlink( d::SimplexDistribution, y::AbstractVector{T}, @@ -219,18 +221,18 @@ function invlink( ϵ = _eps(T) z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1))) - x[1] = (z - ϵ) / (one(T) - 2ϵ) + x[1] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 sum_tmp = zero(T) @inbounds for k = 2:(K - 1) z = StatsFuns.logistic(y[k] + log(one(T) / (K - k))) sum_tmp += x[k-1] - x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ + x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 end sum_tmp += x[K - 1] if proj - x[K] = one(T) - sum_tmp + x[K] = one(T) - sum_tmp |> clamp0to1 else - x[K] = one(T) - sum_tmp - y[K] + x[K] = one(T) - sum_tmp - y[K] |> clamp0to1 end return x end @@ -246,17 +248,17 @@ function invlink( ϵ = _eps(T) @inbounds for n in 1:size(X, 2) sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1))) - X[1, n] = (z - ϵ) / (one(T) - 2ϵ) + X[1, n] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 for k in 2:(K - 1) z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k))) sum_tmp += X[k - 1] - X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ + X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 end sum_tmp += X[K - 1, n] if proj - X[K, n] = one(T) - sum_tmp + X[K, n] = one(T) - sum_tmp |> clamp0to1 else - X[K, n] = one(T) - sum_tmp - Y[K, n] + X[K, n] = one(T) - sum_tmp - Y[K, n] |> clamp0to1 end end @@ -279,7 +281,7 @@ function logpdf_with_trans( lp += log(z + ϵ) + log(one(T) - z + ϵ) @inbounds for k in 2:(K - 1) sum_tmp += x[k-1] - z = x[k] / (one(T) - sum_tmp) + z = x[k] / (one(T) - sum_tmp + ϵ) lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ) end end