diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 28f06f39..f06cd02b 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -6,14 +6,14 @@ using StatsFuns using LinearAlgebra using MappedArrays -export TransformDistribution, +export TransformDistribution, RealDistribution, PositiveDistribution, UnitDistribution, SimplexDistribution, PDMatDistribution, - link, - invlink, + link, + invlink, logpdf_with_trans _eps(::Type{T}) where {T} = eps(T) @@ -154,8 +154,8 @@ end const SimplexDistribution = Union{Dirichlet} function link( - d::SimplexDistribution, - x::AbstractVector{T}, + d::SimplexDistribution, + x::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} y, K = similar(x), length(x) @@ -183,8 +183,8 @@ end # Vectorised implementation of the above. function link( - d::SimplexDistribution, - X::AbstractMatrix{T}, + d::SimplexDistribution, + X::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} Y, K, N = similar(X), size(X, 1), size(X, 2) @@ -211,8 +211,8 @@ function link( end function invlink( - d::SimplexDistribution, - y::AbstractVector{T}, + d::SimplexDistribution, + y::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} x, K = similar(y), length(y) @@ -237,8 +237,8 @@ end # Vectorised implementation of the above. function invlink( - d::SimplexDistribution, - Y::AbstractMatrix{T}, + d::SimplexDistribution, + Y::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) @@ -276,11 +276,11 @@ function logpdf_with_trans( sum_tmp = zero(eltype(x)) z = x[1] - lp += log(z + ϵ) + log(one(T) - z + ϵ) + lp += log(max(0, z + ϵ)) + log(max(0, one(T) - z + ϵ)) @inbounds for k in 2:(K - 1) sum_tmp += x[k-1] z = x[k] / (one(T) - sum_tmp) - lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ) + lp += log(max(0, z + ϵ)) + log(max(0, one(T) - z + ϵ)) + log(max(0, one(T) - sum_tmp + ϵ)) end end return lp @@ -332,8 +332,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{T}) where {T<:Real} end function logpdf_with_trans( - d::PDMatDistribution, - X::AbstractMatrix{<:Real}, + d::PDMatDistribution, + X::AbstractMatrix{<:Real}, transform::Bool ) lp = logpdf(d, X)