Skip to content

Commit

Permalink
New implementation of Wishart
Browse files Browse the repository at this point in the history
  • Loading branch information
lindahua committed Nov 8, 2014
1 parent 9dca937 commit c988945
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 111 deletions.
6 changes: 6 additions & 0 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ end
@Base.deprecate logpmf logpdf
@Base.deprecate logpmf! logpmf!
@Base.deprecate pmf pdf


#### Deprecate on 0.6 (to be removed on 0.7)

@Base.deprecate expected_logdet meanlogdet

22 changes: 0 additions & 22 deletions src/matrix/inversewishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,3 @@ function rand!(IW::InverseWishart, X::Array{Matrix{Float64}})
end

var(IW::InverseWishart) = error("Not yet implemented")

# because X == X' keeps failing due to floating point nonsense
function isApproxSymmmetric(a::Matrix{Float64})
tmp = true
for j in 2:size(a, 1)
for i in 1:(j - 1)
tmp &= abs(a[i, j] - a[j, i]) < 1e-8
end
end
return tmp
end

# because isposdef keeps giving the wrong answer for samples
# from Wishart and InverseWisharts
hasCholesky(a::Matrix{Float64}) = isa(trycholfact(a), Cholesky)

function trycholfact(a::Matrix{Float64})
try cholfact(a)
catch e
return e
end
end
149 changes: 76 additions & 73 deletions src/matrix/wishart.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,99 @@
##############################################################################
# Wishart distribution
#
# Wishart Distribution
# following the Wikipedia parameterization
#
# Parameters nu and S such that E(X) = nu * S
# See the rwish and dwish implementation in R's MCMCPack
# This parametrization differs from Bernardo & Smith p 435
# in this way: (nu, S) = (2.0 * alpha, 0.5 * beta^-1)
#
##############################################################################

immutable Wishart <: ContinuousMatrixDistribution
nu::Float64
Schol::Cholesky{Float64}
function Wishart(n::Real, Sc::Cholesky{Float64})
if n > size(Sc, 1) - 1
new(float64(n), Sc)
else
error("Wishart parameters must be df > p - 1")
end
end

immutable Wishart{ST<:AbstractPDMat} <: ContinuousMatrixDistribution
df::Float64 # degree of freedom
S::ST # the scale matrix
c0::Float64 # the logarithm of normalizing constant in pdf
end

Wishart(nu::Real, S::Matrix{Float64}) = Wishart(nu, cholfact(S))
#### Constructors

show(io::IO, d::Wishart) = show_multline(io, d, [(:nu, d.nu), (:S, full(d.Schol))])
function Wishart{ST<:AbstractPDMat}(df::Real, S::ST)
p = dim(S)
df > p - 1 || error("df should be greater than dim - 1.")
Wishart{ST}(df, S, _wishart_c0(df, S))
end

Wishart(df::Real, S::Matrix{Float64}) = Wishart(df, PDMat(S))

dim(W::Wishart) = size(W.Schol, 1)
size(W::Wishart) = size(W.Schol)
Wishart(df::Real, S::Cholesky) = Wishart(df, PDMat(S))

function insupport(W::Wishart, X::Matrix{Float64})
return size(X) == size(W) && isApproxSymmmetric(X) && hasCholesky(X)
end
# This just checks if X could come from any Wishart
function insupport(::Type{Wishart}, X::Matrix{Float64})
return size(X, 1) == size(X, 2) && isApproxSymmmetric(X) && hasCholesky(X)
function _wishart_c0(df::Float64, S::AbstractPDMat)
h_df = df / 2
p = dim(S)
h_df * (logdet(S) + p * logtwo) + lpgamma(p, h_df)
end

mean(w::Wishart) = w.nu * (w.Schol[:U]' * w.Schol[:U])

function expected_logdet(W::Wishart)
logd = 0.
d = dim(W)
#### Properties

for i=1:d
logd += digamma(0.5 * (W.nu + 1 - i))
end
insupport(::Type{Wishart}, X::Matrix{Float64}) = isposdef(X)
insupport(d::Wishart, X::Matrix{Float64}) = size(X) == size(d) && isposdef(X)

logd += d * log(2)
logd += logdet(W.Schol)
dim(d::Wishart) = dim(d.S)
size(d::Wishart) = (p = dim(d); (p, p))

return logd
end

function lognorm(W::Wishart)
d = dim(W)
return (W.nu / 2) * logdet(W.Schol) + (d * W.nu / 2) * log(2) + lpgamma(d, W.nu / 2)
end
#### Show

show(io::IO, d::Wishart) = show_multline(io, d, [(:df, d.df), (:S, full(d.S))])


#### Statistics

mean(d::Wishart) = d.df * full(d.S)

function _logpdf{T<:Real}(W::Wishart, X::DenseMatrix{T})
Xchol = trycholfact(X)
if size(X) == size(W) && isApproxSymmmetric(X) && isa(Xchol, Cholesky)
d = dim(W)
logd = -lognorm(W)
logd += 0.5 * (W.nu - d - 1.0) * logdet(Xchol)
logd -= 0.5 * trace(W.Schol \ X)
return logd
else
return -Inf
function meanlogdet(d::Wishart)
p = dim(d)
df = d.df
v = logdet(d.S) + p * logtwo
for i = 1:p
v += digamma(0.5 * (df - (i - 1)))
end
return v
end

function rand(w::Wishart)
p = size(w.Schol, 1)
X = zeros(p, p)
for ii in 1:p
X[ii, ii] = sqrt(rand(Chisq(w.nu - ii + 1)))
end
if p > 1
for col in 2:p
for row in 1:(col - 1)
X[row, col] = randn()
end
end
end
Z = X * w.Schol[:U]
return At_mul_B(Z, Z)
function entropy(d::Wishart)
p = dim(d)
df = d.df
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
end

function entropy(W::Wishart)
d = dim(W)
return lognorm(W) - (W.nu - d - 1) / 2 * expected_logdet(W) + W.nu * d / 2

#### Evaluation

function _logpdf(d::Wishart, X::DenseMatrix{Float64})
Xcf = cholfact(X)
df = d.df
p = dim(d)
0.5 * ((df - (p + 1)) * logdet(Xcf) - trace(d.S \ X)) - d.c0
end

var(w::Wishart) = error("Not yet implemented")

#### Sampling

function rand(d::Wishart)
Z = unwhiten!(d.S, _wishart_genA(dim(d), d.df))
A_mul_Bt(Z, Z)
end

function _wishart_genA(p::Int, df::Float64)
# Generate the matrix A in the Bartlett decomposition
#
# A is a lower triangular matrix, with
#
# A(i, j) ~ sqrt of Chisq(df - i + 1) when i == j
# ~ Normal() when i > j
#
A = zeros(p, p)
for i = 1:p
@inbounds A[i,i] = sqrt(rand(Chisq(df - i + 1.0)))
end
for j = 1:p-1, i = j+1:p
@inbounds A[i,j] = randn()
end
return A
end
26 changes: 26 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,29 @@ function simpson(f::AbstractVector{Float64}, h::Float64)
return s * h / 3.0
end


# because X == X' keeps failing due to floating point nonsense
function isApproxSymmmetric(a::Matrix{Float64})
tmp = true
for j in 2:size(a, 1)
for i in 1:(j - 1)
tmp &= abs(a[i, j] - a[j, i]) < 1e-8
end
end
return tmp
end

# because isposdef keeps giving the wrong answer for samples
# from Wishart and InverseWisharts
hasCholesky(a::Matrix{Float64}) = isa(trycholfact(a), Cholesky)

function trycholfact(a::Matrix{Float64})
try cholfact(a)
catch e
return e
end
end




1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ tests = [
"conjugates",
"conjugates_normal",
"conjugates_mvnormal",
"wishart",
"mixture",
"gradlogpdf"]

Expand Down
15 changes: 0 additions & 15 deletions test/wishart.jl

This file was deleted.

0 comments on commit c988945

Please sign in to comment.