diff --git a/NEWS.md b/NEWS.md index b12c14fe1..1f53c57f1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,9 @@ ### Bug fixes +* Fixed reading inverse mass matrix with values written in scientific format in +the CSV. (#394) + ### New features * Added `$sample_mpi()` for MCMC sampling with MPI. (#350) diff --git a/R/read_csv.R b/R/read_csv.R index 3c279891d..0c2cddaac 100644 --- a/R/read_csv.R +++ b/R/read_csv.R @@ -388,8 +388,11 @@ read_csv_metadata <- function(csv_file) { inv_metric_next <- FALSE inv_metric_diagonal_next <- FALSE csv_file_info <- list() - inv_metric_rows <- 0 + csv_file_info$inv_metric <- NULL + inv_metric_rows_to_read <- -1 + inv_metric_rows <- -1 parsing_done <- FALSE + dense_inv_metric <- FALSE if (os_is_windows()) { grep_path <- repair_path(Sys.which("grep.exe")) fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' --color=never ", csv_file) @@ -422,26 +425,28 @@ read_csv_metadata <- function(csv_file) { } } else { parse_key_val <- TRUE - if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) - || grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) { + if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE)) { inv_metric_next <- TRUE parse_key_val <- FALSE + inv_metric_rows <- 1 + inv_metric_rows_to_read <- 1 + dense_inv_metric <- FALSE + } else if (grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) { + inv_metric_next <- TRUE + parse_key_val <- FALSE + dense_inv_metric <- TRUE } else if (inv_metric_next) { inv_metric_split <- strsplit(gsub("# ", "", line), ",") - if ((length(inv_metric_split) == 0) || - ((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0))) || - grepl("[a-zA-z]", line, perl = TRUE) || - inv_metric_split == "#") { - parsing_done <- TRUE - parse_key_val <- TRUE - break; + numeric_inv_metric_split <- rapply(inv_metric_split, as.numeric) + if (inv_metric_rows == -1 && dense_inv_metric) { + inv_metric_rows <- length(inv_metric_split[[1]]) + inv_metric_rows_to_read <- inv_metric_rows } - if (inv_metric_rows == 0) { - csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric) - } else { - csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric)) + csv_file_info$inv_metric <- c(csv_file_info$inv_metric, numeric_inv_metric_split) + inv_metric_rows_to_read <- inv_metric_rows_to_read - 1 + if (inv_metric_rows_to_read == 0) { + inv_metric_next <- FALSE } - inv_metric_rows <- inv_metric_rows + 1 parse_key_val <- FALSE } if (parse_key_val) { diff --git a/tests/testthat/resources/csv/model1-1-warmup.csv b/tests/testthat/resources/csv/model1-1-warmup.csv index c19c0aefd..d0ae22f73 100644 --- a/tests/testthat/resources/csv/model1-1-warmup.csv +++ b/tests/testthat/resources/csv/model1-1-warmup.csv @@ -140,7 +140,7 @@ lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,mu,s # Adaptation terminated # Step size = 0.712907 # Diagonal elements of inverse mass matrix: -# 1.00098, 0.068748 +# 1.00098, 0.068748e-2 -19.4938,0.953779,0.712907,2,3,0,19.4971,8.11498,7.4563 -19.6889,0.983261,0.712907,1,1,0,19.8364,7.96487,7.78375 -18.0516,0.982462,0.712907,2,3,0,20.5179,8.24821,5.4579 @@ -241,8 +241,8 @@ lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,mu,s -13.3724,1,0.712907,1,1,0,13.6117,5.58052,2.40945 -13.3724,0.292728,0.712907,2,3,0,17.7528,5.58052,2.40945 -13.348,0.991998,0.712907,2,3,0,13.5989,4.34492,2.68262 -# +# # Elapsed Time: 0.038029 seconds (Warm-up) # 0.030711 seconds (Sampling) # 0.06874 seconds (Total) -# +# diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index a22bc3526..5090c1e8a 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -160,7 +160,7 @@ test_that("read_cmdstan_csv() returns correct diagonal of inverse mass matrix", csv_files <- c(test_path("resources", "csv", "model1-1-warmup.csv"),test_path("resources", "csv", "model1-2-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) expect_equal(as.vector(csv_output$inv_metric[[as.character(1)]]), - c(1.00098, 0.068748)) + c(1.00098, 0.00068748)) expect_equal(as.vector(csv_output$inv_metric[[as.character(2)]]), c(0.909635, 0.066384)) }) @@ -296,7 +296,7 @@ test_that("read_cmdstan_csv() reads values up to adaptation", { csv_out <- read_cmdstan_csv(csv_files) expect_equal(csv_out$metadata$pi, 3.14) - expect_true(is.null(csv_out$metadata$pi_square)) + expect_false(is.null(csv_out$metadata$pi_square)) }) test_that("remaining_columns_to_read() works", {