Skip to content

Commit

Permalink
Merge pull request #290 from JuliaStats/dh/show2
Browse files Browse the repository at this point in the history
Improved implementation of `show` for distributions
  • Loading branch information
lindahua committed Nov 1, 2014
2 parents d9fd2c4 + 13b79fe commit 4699002
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 39 deletions.
2 changes: 2 additions & 0 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ length(d::DirichletCanon) = length(d.alpha)
Base.convert(::Type{Dirichlet}, cf::DirichletCanon) = Dirichlet(cf.alpha)


Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))

# Properties

length(d::Dirichlet) = length(d.alpha)
Expand Down
3 changes: 3 additions & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function GenericMvNormal{Cov<:AbstractPDMat}(Σ::Cov)
GenericMvNormal{Cov}(d, true, zeros(d), Σ)
end

Base.show(io::IO, d::GenericMvNormal) =
show_multline(io, d, [(:dim, d.dim), (, mean(d)), (, cov(d))])

## Construction of multivariate normal with specific covariance type

typealias IsoNormal GenericMvNormal{ScalMat}
Expand Down
101 changes: 62 additions & 39 deletions src/show.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,71 @@
function show(io::IO, d::Distribution)
@printf io "%s distribution\n" typeof(d)
for parameter in typeof(d).names
if isa(d.(parameter), AbstractArray)
param = string(ucfirst(string(parameter)),
":\n",
d.(parameter),
"\n")
else
param = string(ucfirst(string(parameter)),
": ",
d.(parameter),
"\n")


# the name of a distribution
#
# Generally, this should be just the type name, e.g. Normal.
# Under certain circumstances, one may want to specialize
# this function to provide a name that is easier to read,
# especially when the type is parametric.
#
distrname(d::Distribution) = string(typeof(d))

show(io::IO, d::Distribution) = show(io, d, typeof(d).names)

# For some distributions, the fields may contain internal details,
# which we don't want to show, this function allows one to
# specify which fields to show.
#
function show(io::IO, d::Distribution, pnames::(Symbol...))
# decide whether to use one-line or multi-line format
#
# Criteria: if total number of values is greater than 8, or
# there are matrix-valued params, we use multi-line format
#
namevals = (Symbol, Any)[]
multline = false
tlen = 0
for (i, p) in enumerate(pnames)
pv = d.(p)
if !(isa(pv, Number) || isa(pv, (Number...)) || isa(pv, AbstractVector))
multline = true
end
print(io, param)
tlen += length(pv)
push!(namevals, (p, pv))
end
if tlen > 8
multline = true
end

# call the function that actually does the job
multline ? show_multline(io, d, namevals) :
show_oneline(io, d, namevals)
end

function compact_show(io::IO, d::Distribution)
print(io, typeof(d))
print(io, "( ")
for parameter in typeof(d).names
print(io, string(parameter))
print(io, "=")
pv = d.(parameter)
if isa(pv, AbstractVector)
print(io, '[')
if !isempty(pv)
for i = 1 : length(pv)-1
print(io, pv[i])
print(io, ", ")
end
print(io, pv[end])
end
print(io, ']')
else
print(io, pv)
end
print(io, " ")
end
print(io, ")")
function show_oneline(io::IO, d::Distribution, namevals)
print(io, distrname(d))
np = length(namevals)
print(io, '(')
for (i, nv) in enumerate(namevals)
(p, pv) = nv
print(io, p)
print(io, '=')
show(io, pv)
if i < np
print(io, ", ")
end
end
print(io, ')')
end

function show(io::IO, d::UnivariateDistribution)
compact_show(io, d)
function show_multline(io::IO, d::Distribution, namevals)
print(io, distrname(d))
println(io, "(")
for (p, pv) in namevals
print(io, p)
print(io, ": ")
println(io, pv)
end
println(io, ")")
end


0 comments on commit 4699002

Please sign in to comment.