From 82def13b4ed4ae442d1fd655dfe3644b5113da74 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 May 2023 20:36:23 +0200 Subject: [PATCH] Forward missing arguments to `summarystats` (#424) * Forward missing arguments to `summarystats` * Fix test --- Project.toml | 2 +- src/stats.jl | 9 +++++---- src/summarize.jl | 4 ++-- test/summarize_tests.jl | 10 ++++++++++ 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 7835d05a..5c5bca73 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "6.0.2" +version = "6.0.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/stats.jl b/src/stats.jl index 7fd1e04d..e31044ea 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -289,6 +289,7 @@ function summarystats( append_chains::Bool = true, autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(), maxlag = 250, + name = "Summary Statistics", kwargs... ) # Store everything. @@ -339,10 +340,10 @@ function summarystats( # Summarize. summary_df = summarize( _chains, funs...; - func_names = func_names, - append_chains = append_chains, - additional_df = additional_df, - name = "Summary Statistics", + func_names, + append_chains, + additional_df, + name, sections = nothing ) diff --git a/src/summarize.jl b/src/summarize.jl index 08062baf..70bba5d8 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -122,7 +122,7 @@ function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:Na end """ - summarize(chains, funs...[; sections, func_names = []]) + summarize(chains, funs...[; sections, func_names = [], name = "", append_chains = true]) Summarize `chains` in a `ChainsDataFrame`. @@ -143,7 +143,7 @@ function summarize( ) # If we weren't given any functions, fall back to summary stats. if isempty(funs) - return summarystats(chains; sections = sections) + return summarystats(chains; sections, append_chains, name) end # Generate a chain to work on. diff --git a/test/summarize_tests.jl b/test/summarize_tests.jl index 741fccd3..d847ec43 100644 --- a/test/summarize_tests.jl +++ b/test/summarize_tests.jl @@ -26,8 +26,18 @@ using Statistics: std @test parm_df[[:a, :b], :][:,:parameters] == [:a, :b] all_sections_df = summarize(chns, sections=[:parameters, :internals]) + @test all_sections_df isa ChainDataFrame @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] @test size(all_sections_df) == (8, 8) + @test all_sections_df.name == "" + + all_sections_dfs = summarize(chns, sections=[:parameters, :internals], name = "Summary", append_chains = false) + @test all_sections_dfs isa Vector{<:ChainDataFrame} + for (i, all_sections_df) in enumerate(all_sections_dfs) + @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] + @test size(all_sections_df) == (8, 8) + @test all_sections_df.name == "Summary (Chain $i)" + end two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std) @test two_parms_two_funs_df[:, :parameters] == [:a, :b]