From 4d1858c655a21263c60ef8c5e22a4b77c7382cdb Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 8 Nov 2023 14:47:41 -0500 Subject: [PATCH] Prevent pathfinder when 0 parameters, warn if too many PSIS requested --- src/cmdstan/command.hpp | 12 ++++++++++++ src/test/interface/pathfinder_test.cpp | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index c42d96f79f..aaa3266278 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -306,6 +306,18 @@ int command(int argc, const char *argv[]) { int num_draws = get_arg_val(*pathfinder_arg, "num_draws"); int num_psis_draws = get_arg_val(*pathfinder_arg, "num_psis_draws"); + + if (num_psis_draws > num_draws * num_chains) { + logger.warn( + "Warning: Number of PSIS draws is larger than the total number of " + "draws returned by the single Pathfinders. This is likely " + "unintentional and leads to re-sampling from the same draws."); + } + if (model.num_params_r() == 0) { + throw std::invalid_argument( + "Model has 0 parameters, cannot run Pathfinder."); + } + if (num_chains == 1) { save_single_paths = save_single_paths || !diagnostic_file.empty(); return_code = stan::services::pathfinder::pathfinder_lbfgs_single< diff --git a/src/test/interface/pathfinder_test.cpp b/src/test/interface/pathfinder_test.cpp index 639c17ff59..9231c152a0 100644 --- a/src/test/interface/pathfinder_test.cpp +++ b/src/test/interface/pathfinder_test.cpp @@ -17,6 +17,7 @@ class CmdStan : public testing::Test { eight_schools_model = {"src", "test", "test-models", "eight_schools"}; eight_schools_data = {"src", "test", "test-models", "eight_schools.data.json"}; + empty_model = {"src", "test", "test-models", "empty"}; arg_output = {"test", "output"}; arg_diags = {"test", "diagnostics"}; output_csv = {"test", "output.csv"}; @@ -38,6 +39,7 @@ class CmdStan : public testing::Test { std::vector multi_normal_model; std::vector eight_schools_model; std::vector eight_schools_data; + std::vector empty_model; std::vector arg_output; std::vector arg_diags; std::vector output_csv; @@ -241,3 +243,21 @@ TEST_F(CmdStan, pathfinder_lbfgs_iterations) { EXPECT_EQ(1, count_matches("\"3\":{\"iter\":3,", output)); EXPECT_EQ(0, count_matches("\"4\":{\"iter\":4,", output)); } + +TEST_F(CmdStan, pathfinder_empty_model) { + std::stringstream ss; + ss << convert_model_path(empty_model) << " method=pathfinder"; + run_command_output out = run_command(ss.str()); + ASSERT_TRUE(out.hasError); + EXPECT_EQ(1, count_matches("Model has 0 parameters", out.output)); +} + +TEST_F(CmdStan, pathfinder_too_many_PSIS_draws) { + std::stringstream ss; + ss << convert_model_path(multi_normal_model) << " method=pathfinder" + << " num_paths=1 num_draws=10 num_psis_draws=11"; + run_command_output out = run_command(ss.str()); + ASSERT_FALSE(out.hasError); + EXPECT_EQ( + 1, count_matches("Warning: Number of PSIS draws is larger", out.output)); +}