diff --git a/Project.toml b/Project.toml index 5ede357c..fb5b93a1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "6.0.4" +version = "6.0.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -30,16 +30,20 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5" AxisArrays = "0.4.4" +Dates = "<0.0.1, 1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" Formatting = "0.4" IteratorInterfaceExtensions = "0.1.1, 1" KernelDensity = "0.6.2" +LinearAlgebra = "<0.0.1, 1" MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2" +Random = "<0.0.1, 1" RecipesBase = "0.7, 0.8, 1.0" +Statistics = "<0.0.1, 1" StatsBase = "0.33.2, 0.34" StatsFuns = "0.8, 0.9, 1" TableTraits = "0.4, 1" diff --git a/docs/Project.toml b/docs/Project.toml index 9defdea2..e021ab6c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,7 +10,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" [compat] -CairoMakie = "0.6, 0.7, 0.8, 0.9, 0.10" +CairoMakie = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11" CategoricalArrays = "0.8, 0.9, 0.10" DataFrames = "0.22, 1" Documenter = "0.26, 0.27, 1" diff --git a/src/plot.jl b/src/plot.jl index 1575846c..9ef7f561 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -64,7 +64,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag) # Chains are already appended in `c` if desired, hence we use `append_chains=false` ac = autocor(c; sections = nothing, lags = lags, append_chains=false) - ac_mat = convert(Array, ac) + ac_mat = convert(Array{Float64}, ac) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) elseif st ∈ supportedplots diff --git a/src/summarize.jl b/src/summarize.jl index 70bba5d8..53af6364 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -104,8 +104,11 @@ function Base.lastindex(c::ChainDataFrame, i::Integer) end end -function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame +function Base.convert(::Type{Array}, c::ChainDataFrame) T = promote_eltype_namedtuple_tail(c.nt) + return convert(Array{T}, c) +end +function Base.convert(::Type{Array{T}}, c::ChainDataFrame) where {T} arr = Array{T, 2}(undef, c.nrows, c.ncols - 1) for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1)) @@ -115,9 +118,13 @@ function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame return arr end -function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:NamedTuple +function Base.convert(::Type{Array}, cs::Vector{ChainDataFrame{NamedTuple{K,V}}}) where {K,V} + T = promote_eltype_tuple_type(Base.tuple_type_tail(V)) + return convert(Array{T}, cs) +end +function Base.convert(::Type{Array{T}}, cs::Vector{<:ChainDataFrame}) where {T} return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c - reshape(convert(Array, c), Val(3)) + reshape(convert(Array{T}, c), Val(3)) end end diff --git a/test/Project.toml b/test/Project.toml index 2883bba3..331b2b4a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -25,17 +24,22 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] AbstractMCMC = "2.2.1, 3.0, 4, 5" DataFrames = "0.22.4, 1.0" +Dates = "<0.0.1, 1" Distributions = "0.24.12, 0.25" Documenter = "0.26, 0.27, 1" FFTW = "1.1" IteratorInterfaceExtensions = "1" KernelDensity = "0.6.2" +Logging = "<0.0.1, 1" MCMCChains = "6" MLJBase = "0.18, 0.19, 0.20, 0.21, 1" MLJDecisionTreeInterface = "0.3, 0.4" +Random = "<0.0.1, 1" +Serialization = "<0.0.1, 1" StatsBase = "0.33.2, 0.34" StatsPlots = "0.14.17, 0.15" TableTraits = "1" Tables = "1.3.1" +Test = "<0.0.1, 1" UnicodePlots = "2, 3" julia = "1.6"