Skip to content

Commit

Permalink
fix #12
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Jan 12, 2019
1 parent 9bc1e3c commit ecb7e94
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ function link(
return Y
end

clamp0to1(x::T) where T = clamp(x, zero(T), one(T))

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
Expand All @@ -219,18 +221,18 @@ function invlink(

ϵ = _eps(T)
z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1)))
x[1] = (z - ϵ) / (one(T) - 2ϵ)
x[1] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1
sum_tmp = zero(T)
@inbounds for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] + log(one(T) / (K - k)))
sum_tmp += x[k-1]
x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ
x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1
end
sum_tmp += x[K - 1]
if proj
x[K] = one(T) - sum_tmp
x[K] = one(T) - sum_tmp |> clamp0to1
else
x[K] = one(T) - sum_tmp - y[K]
x[K] = one(T) - sum_tmp - y[K] |> clamp0to1
end
return x
end
Expand All @@ -246,17 +248,17 @@ function invlink(
ϵ = _eps(T)
@inbounds for n in 1:size(X, 2)
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1)))
X[1, n] = (z - ϵ) / (one(T) - 2ϵ)
X[1, n] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1
for k in 2:(K - 1)
z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k)))
sum_tmp += X[k - 1]
X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ
X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1
end
sum_tmp += X[K - 1, n]
if proj
X[K, n] = one(T) - sum_tmp
X[K, n] = one(T) - sum_tmp |> clamp0to1
else
X[K, n] = one(T) - sum_tmp - Y[K, n]
X[K, n] = one(T) - sum_tmp - Y[K, n] |> clamp0to1
end
end

Expand All @@ -279,7 +281,7 @@ function logpdf_with_trans(
lp += log(z + ϵ) + log(one(T) - z + ϵ)
@inbounds for k in 2:(K - 1)
sum_tmp += x[k-1]
z = x[k] / (one(T) - sum_tmp)
z = x[k] / (one(T) - sum_tmp + ϵ)
lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ)
end
end
Expand Down

0 comments on commit ecb7e94

Please sign in to comment.