Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update stansummary to use split, rank-normalized Rhat, ESS #1301

Merged
merged 10 commits into from
Nov 4, 2024
110 changes: 56 additions & 54 deletions src/cmdstan/diagnose.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#include <cmdstan/return_codes.hpp>
#include <cmdstan/stansummary_helper.hpp>
#include <stan/mcmc/chains.hpp>
#include <stan/mcmc/chainset.hpp>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <ios>
#include <iostream>

double RHAT_MAX = 1.05;
using cmdstan::return_codes;

double RHAT_MAX = 1.01499; // round to 1.01

void diagnose_usage() {
std::cout << "USAGE: diagnose <filename 1> [<filename 2> ... <filename N>]"
Expand All @@ -26,7 +29,7 @@ void diagnose_usage() {
int main(int argc, const char *argv[]) {
if (argc == 1) {
diagnose_usage();
return 0;
return return_codes::OK;
}

// Parse any arguments specifying filenames
Expand All @@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) {

if (!filenames.size()) {
std::cout << "No valid input files, exiting." << std::endl;
return 0;
return return_codes::NOT_OK;
}

std::cout << std::fixed << std::setprecision(2);

// Parse specified files
std::cout << "Processing csv files: " << filenames[0];
ifstream.open(filenames[0].c_str());

stan::io::stan_csv stan_csv
= stan::io::stan_csv_reader::parse(ifstream, &std::cout);
stan::mcmc::chains<> chains(stan_csv);
ifstream.close();

if (filenames.size() > 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;

for (std::vector<std::string>::size_type chain = 1; chain < filenames.size();
++chain) {
std::cout << filenames[chain];
ifstream.open(filenames[chain].c_str());
stan_csv = stan::io::stan_csv_reader::parse(ifstream, &std::cout);
chains.add(stan_csv);
ifstream.close();
if (chain < filenames.size() - 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;
std::vector<stan::io::stan_csv> csv_parsed;
for (int i = 0; i < filenames.size(); ++i) {
std::ifstream infile;
std::stringstream out;
stan::io::stan_csv sample;
infile.open(filenames[i].c_str());
try {
sample = stan::io::stan_csv_reader::parse(infile, &out);
// csv_reader warnings are errors - fail fast.
if (!out.str().empty()) {
throw std::invalid_argument(out.str());
}
csv_parsed.push_back(sample);
} catch (const std::invalid_argument &e) {
std::cout << "Cannot parse input csv file: " << filenames[i] << e.what()
<< "." << std::endl;
return return_codes::NOT_OK;
}
}

stan::mcmc::chainset chains(csv_parsed);
stan::io::stan_csv_metadata metadata = csv_parsed[0].metadata;
std::vector<std::string> param_names = csv_parsed[0].header;
size_t num_params = param_names.size();
int num_samples = chains.num_samples();
std::vector<std::string> bad_n_eff_names;
std::vector<std::string> bad_rhat_names;
bool has_errors = false;

for (int i = 0; i < chains.num_params(); ++i) {
if (chains.param_name(i) == std::string("treedepth__")) {
for (int i = 0; i < num_params; ++i) {
if (param_names[i] == std::string("treedepth__")) {
std::cout << "Checking sampler transitions treedepth." << std::endl;
int max_limit = stan_csv.metadata.max_depth;
int max_limit = metadata.max_depth;
long n_max = 0;
Eigen::VectorXd t_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd t_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
for (long n = 0; n < t_samples.size(); ++n) {
if (t_samples(n) >= max_limit) {
++n_max;
Expand All @@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) {
std::cout << "Treedepth satisfactory for all transitions." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("divergent__")) {
} else if (param_names[i] == std::string("divergent__")) {
std::cout << "Checking sampler transitions for divergences." << std::endl;
int n_divergent = chains.samples(i).sum();
if (n_divergent > 0) {
Expand All @@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) {
std::cout << "No divergent transitions found." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("energy__")) {
} else if (param_names[i] == std::string("energy__")) {
std::cout << "Checking E-BFMI - sampler transitions HMC potential energy."
<< std::endl;
Eigen::VectorXd e_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd e_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
double delta_e_sq_mean = 0;
double e_mean = 0;
double e_var = 0;
e_mean += e_samples(0);
e_var += e_samples(0) * (e_samples(0) - e_mean);
double e_mean = chains.mean(i);
double e_var = chains.variance(i);
for (long n = 1; n < e_samples.size(); ++n) {
double e = e_samples(n);
double delta_e_sq = (e - e_samples(n - 1)) * (e - e_samples(n - 1));
double d = delta_e_sq - delta_e_sq_mean;
delta_e_sq_mean += d / n;
d = e - e_mean;
e_mean += d / (n + 1);
e_var += d * (e - e_mean);
}

e_var /= static_cast<double>(e_samples.size() - 1);
double e_bfmi = delta_e_sq_mean / e_var;
double e_bfmi_threshold = 0.3;
if (e_bfmi < e_bfmi_threshold) {
Expand All @@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) {
} else {
std::cout << "E-BFMI satisfactory." << std::endl << std::endl;
}
} else if (chains.param_name(i).find("__") == std::string::npos) {
double n_eff = chains.effective_sample_size(i);
} else if (param_names[i].find("__") == std::string::npos) {
auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i);
double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail;
if (n_eff / num_samples < 0.001)
bad_n_eff_names.push_back(chains.param_name(i));
bad_n_eff_names.push_back(param_names[i]);

double split_rhat = chains.split_potential_scale_reduction(i);
auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i);
double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
if (split_rhat > RHAT_MAX)
bad_rhat_names.push_back(chains.param_name(i));
bad_rhat_names.push_back(param_names[i]);
}
}
if (bad_n_eff_names.size() > 0) {
Expand All @@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) {
<< " may be substantially lower than quoted." << std::endl
<< std::endl;
} else {
std::cout << "Effective sample size satisfactory." << std::endl
std::cout << "Rank-normalized split effective sample size satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}

if (bad_rhat_names.size() > 0) {
has_errors = true;
std::cout << "The following parameters had split R-hat greater than "
std::cout << "The following parameters had rank-normalized split R-hat "
"greater than "
<< RHAT_MAX << ":" << std::endl;
std::cout << " ";
for (size_t n = 0; n < bad_rhat_names.size() - 1; ++n)
Expand All @@ -207,13 +208,14 @@ int main(int argc, const char *argv[]) {
<< " effective parameterization." << std::endl
<< std::endl;
} else {
std::cout << "Split R-hat values satisfactory all parameters." << std::endl
std::cout << "Rank-normalized split R-hat values satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}
if (!has_errors)
std::cout << "Processing complete, no problems detected." << std::endl;
else
std::cout << "Processing complete." << std::endl;

return 0;
return return_codes::OK;
}
7 changes: 3 additions & 4 deletions src/cmdstan/stansummary.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <cmdstan/return_codes.hpp>
#include <cmdstan/stansummary_helper.hpp>
#include <stan/mcmc/chains.hpp>
#include <stan/io/ends_with.hpp>
#include <algorithm>
#include <fstream>
Expand Down Expand Up @@ -34,7 +33,7 @@ Example: stansummary model_chain_1.csv model_chain_2.csv
-c, --csv_filename [file] Write statistics to a csv file.
-h, --help Produce help message, then exit.
-p, --percentiles [values] Percentiles to report as ordered set of
comma-separated numbers from (0.1,99.9), inclusive.
comma-separated numbers from (0.0,100.0), inclusive.
Default is 5,50,95.
-s, --sig_figs [n] Significant figures reported. Default is 2.
Must be an integer from (1, 18), inclusive.
Expand Down Expand Up @@ -140,8 +139,8 @@ Example: stansummary model_chain_1.csv model_chain_2.csv

// check for stan csv file parse errors written to output stream
std::stringstream cout_ss;
stan::mcmc::chains<> chains = parse_csv_files(
filenames, metadata, warmup_times, sampling_times, thin, &std::cout);
auto chains = parse_csv_files(filenames, metadata, warmup_times,
sampling_times, thin, &std::cout);

// Get column headers for sampler, model params
size_t max_name_length = 0;
Expand Down
Loading