Skip to content

Commit

Permalink
implement mohamed's solution (#12#issuecomment-451416440)
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Jan 11, 2019
1 parent 9bc1e3c commit 521c9aa
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ using StatsFuns
using LinearAlgebra
using MappedArrays

export TransformDistribution,
export TransformDistribution,
RealDistribution,
PositiveDistribution,
UnitDistribution,
SimplexDistribution,
PDMatDistribution,
link,
invlink,
link,
invlink,
logpdf_with_trans

_eps(::Type{T}) where {T} = eps(T)
Expand Down Expand Up @@ -154,8 +154,8 @@ end
const SimplexDistribution = Union{Dirichlet}

function link(
d::SimplexDistribution,
x::AbstractVector{T},
d::SimplexDistribution,
x::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
y, K = similar(x), length(x)
Expand Down Expand Up @@ -183,8 +183,8 @@ end

# Vectorised implementation of the above.
function link(
d::SimplexDistribution,
X::AbstractMatrix{T},
d::SimplexDistribution,
X::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
Y, K, N = similar(X), size(X, 1), size(X, 2)
Expand All @@ -211,8 +211,8 @@ function link(
end

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
d::SimplexDistribution,
y::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
x, K = similar(y), length(y)
Expand All @@ -237,8 +237,8 @@ end

# Vectorised implementation of the above.
function invlink(
d::SimplexDistribution,
Y::AbstractMatrix{T},
d::SimplexDistribution,
Y::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
X, K, N = similar(Y), size(Y, 1), size(Y, 2)
Expand Down Expand Up @@ -276,11 +276,11 @@ function logpdf_with_trans(

sum_tmp = zero(eltype(x))
z = x[1]
lp += log(z + ϵ) + log(one(T) - z + ϵ)
lp += log(max(0, z + ϵ)) + log(max(0, one(T) - z + ϵ))
@inbounds for k in 2:(K - 1)
sum_tmp += x[k-1]
z = x[k] / (one(T) - sum_tmp)
lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ)
lp += log(max(0, z + ϵ)) + log(max(0, one(T) - z + ϵ)) + log(max(0, one(T) - sum_tmp + ϵ))
end
end
return lp
Expand Down Expand Up @@ -332,8 +332,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{T}) where {T<:Real}
end

function logpdf_with_trans(
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
transform::Bool
)
lp = logpdf(d, X)
Expand Down

0 comments on commit 521c9aa

Please sign in to comment.