diff --git a/src/common.jl b/src/common.jl index 8effc7f27..4fdb36ecb 100644 --- a/src/common.jl +++ b/src/common.jl @@ -20,7 +20,7 @@ const Matrixvariate = ArrayLikeVariate{2} `F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type `NamedTuple{K}`. """ -abstract type NamedTupleVariate{K} <: VariateForm end +struct NamedTupleVariate{K} <: VariateForm end """ `F <: CholeskyVariate` specifies that the variate or a sample is of type diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 3e77d0a41..4d6b4f03e 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -55,7 +55,7 @@ function _gentype(d::Distribution{CholeskyVariate}) end _gentype(::Distribution) = Any -_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...) +_product_namedtuple_eltype(dists::NamedTuple{K,V}) where {K,V} = __product_promote_type(eltype, V) function Base.show(io::IO, d::ProductNamedTupleDistribution) return show_multline(io, d, collect(pairs(d.dists))) @@ -127,7 +127,7 @@ entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists)) function kldivergence( d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K} ) where {K} - return mapreduce(kldivergence, +, d1.dists, d2.dists) + return sum(map(kldivergence, d1.dists, d2.dists)) end # Sampling