From 719e3e163c74db2d813d989b23563a2c73ff4f84 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 12:57:59 -0600 Subject: [PATCH 01/43] move common weight-processing code into validate_weights --- R/weight_draws.R | 52 ++++++++++++------------------------------------ 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/R/weight_draws.R b/R/weight_draws.R index fa8bfd8..6ae12f9 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -56,14 +56,7 @@ weight_draws <- function(x, weights, ...) { #' @rdname weight_draws #' @export weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, ".log_weight"] <- log_weights @@ -78,13 +71,7 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = F #' @rdname weight_draws #' @export weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, , ".log_weight"] <- log_weights @@ -99,27 +86,14 @@ weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FA #' @rdname weight_draws #' @export weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } - x$.log_weight <- log_weights + x$.log_weight <- validate_weights(weights, ndraws(x), log, pareto_smooth) x } #' @rdname weight_draws #' @export weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) niterations <- niterations(x) for (i in seq_len(nchains(x))) { sel <- (1 + (i - 1) * niterations):(i * niterations) @@ -131,13 +105,7 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FAL #' @rdname weight_draws #' @export weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - - pareto_smooth <- as_one_logical(pareto_smooth) - log <- as_one_logical(log) - log_weights <- validate_weights(weights, x, log = log) - if (pareto_smooth) { - log_weights <- pareto_smooth_log_weights(log_weights) - } + log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) x$.log_weight <- rvar(log_weights) x } @@ -178,10 +146,12 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { } # validate weights and return log weights -validate_weights <- function(weights, draws, log = FALSE) { +validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { checkmate::assert_numeric(weights) checkmate::assert_flag(log) - if (length(weights) != ndraws(draws)) { + checkmate::assert_flag(pareto_smooth) + + if (length(weights) != ndraws) { stop_no_call("Number of weights must match the number of draws.") } if (!log) { @@ -190,6 +160,10 @@ validate_weights <- function(weights, draws, log = FALSE) { } weights <- log(weights) } + if (pareto_smooth) { + weights <- pareto_smooth_log_weights(weights) + } + weights } From c1c95957f92deca4db0c86308722b2285fadd2f2 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 13:31:08 -0600 Subject: [PATCH 02/43] basic setup of weighting for rvars --- R/rvar-.R | 54 ++++++++++++++++++++++++++++++++++++++++------------- man/rvar.Rd | 4 +++- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 129f357..aa5da1e 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -90,7 +90,11 @@ #' x #' #' @export -rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE) { +rvar <- function( + x = double(), dim = NULL, dimnames = NULL, + nchains = NULL, with_chains = FALSE, + weights = NULL, log = FALSE +) { if (is_rvar(x)) { nchains <- nchains %||% nchains(x) with_chains = FALSE @@ -105,7 +109,7 @@ rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with nchains <- nchains %||% 1L } - out <- new_rvar(x, .nchains = nchains) + out <- new_rvar(x, .nchains = nchains, weights = weights, log = log) if (!is.null(dim)) { dim(out) <- dim @@ -118,7 +122,7 @@ rvar <- function(x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with } #' @importFrom vctrs new_vctr -new_rvar <- function(x = double(), .nchains = 1L) { +new_rvar <- function(x = double(), .nchains = 1L, weights = NULL, log = FALSE) { if (is.null(x)) { x <- double() } @@ -128,11 +132,17 @@ new_rvar <- function(x = double(), .nchains = 1L) { .ndraws <- dim(x)[[1]] .nchains <- as_one_integer(.nchains) check_nchains_compat_with_ndraws(.nchains, .ndraws) + if (is.null(weights)) { + log_weight <- NULL + } else { + log_weight <- validate_weights(weights, .ndraws, log = log, pareto_smooth = FALSE) + } structure( list(), draws = x, nchains = .nchains, + log_weight = log_weight, class = get_rvar_class(x), cache = new.env(parent = emptyenv()) ) @@ -252,14 +262,14 @@ rep.rvar <- function(x, times = 1, length.out = NA, each = 1, ...) { dim = dim(draws) dim[[2]] = dim[[2]] * times dim(rep_draws) = dim - out <- new_rvar(rep_draws, .nchains = nchains(x)) + draws_of(x) <- rep_draws } else { # use `length.out` rep_draws = rep_len(draws, length.out * ndraws(x)) dim(rep_draws) = c(ndraws(x), length(rep_draws) / ndraws(x)) - out <- new_rvar(rep_draws, .nchains = nchains(x)) + draws_of(x) <- rep_draws } - out + x } #' @rawNamespace S3method(rep.int,rvar,rep_int_rvar) @@ -537,6 +547,23 @@ nchains2_common <- function(nchains_x, nchains_y) { } } +# find common weights for two rvars +weights2_common <- function(weights_x, weights_y) { + if (is.null(weights_x)) { + weights_y + } else if (is.null(weights_y)) { + weights_x + } else if (identical(weights_x, weights_y)) { + weights_x + } else { + stop_no_call( + "Random variables have different weights and cannot be used together:\n", + "<", vctrs::vec_ptype_abbr(weights_x), "> ", paste(head(weights_x, 5), collapse = ", "), " ...\n", + "<", vctrs::vec_ptype_abbr(weights_y), "> ", paste(head(weights_y, 5), collapse = ", "), " ..." + ) + } +} + # check that the given number of chains is compatible with the given number of draws check_nchains_compat_with_ndraws <- function(nchains, ndraws) { # except with constants, nchains must divide the number of draws @@ -548,7 +575,7 @@ check_nchains_compat_with_ndraws <- function(nchains, ndraws) { } } -# given two rvars, conform their number of chains +# given a list of rvars, conform their number of chains # so they can be used together (or throw an error if they can't be) conform_rvar_nchains <- function(rvars) { # find the number of chains to use, treating constants as having any number of chains @@ -562,15 +589,16 @@ conform_rvar_nchains <- function(rvars) { rvars } -# given two rvars, conform their number of draws +# given a list of rvars, conform their number of draws # so they can be used together (or throw an error if they can't be) # @param keep_constants keep constants as 1-draw rvars conform_rvar_ndraws <- function(rvars, keep_constants = FALSE) { - # broadcast to a common number of chains. If keep_constants = TRUE, - # constants will not be broadcast. + # broadcast to a common number of draws and the same set of weights. + # If keep_constants = TRUE, constants will not be broadcast or re-weighted. .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) + log_weight = Reduce(weights2_common, lapply(rvars, attr, "log_weight")) for (i in seq_along(rvars)) { - rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants) + rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants, log_weight = log_weight) } rvars @@ -726,7 +754,7 @@ broadcast_array <- function(x, dim, broadcast_scalars = TRUE) { } # broadcast the draws dimension of an rvar to the requested size -broadcast_draws <- function(x, .ndraws, keep_constants = FALSE) { +broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, log_weight = NULL) { ndraws_x = ndraws(x) if ( (ndraws_x == 1 && keep_constants) || @@ -738,7 +766,7 @@ broadcast_draws <- function(x, .ndraws, keep_constants = FALSE) { new_dim <- dim(draws) new_dim[1] <- .ndraws - new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x)) + new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x), weights = log_weight, log = TRUE) } } diff --git a/man/rvar.Rd b/man/rvar.Rd index 3c144c7..9782c3e 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -9,7 +9,9 @@ rvar( dim = NULL, dimnames = NULL, nchains = NULL, - with_chains = FALSE + with_chains = FALSE, + weights = NULL, + log = FALSE ) } \arguments{ From 243492a55aade1ce941e30bf694221e192d28df1 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 14:56:42 -0600 Subject: [PATCH 03/43] add log_weights() to get internal weights for easier programming --- NAMESPACE | 5 +++++ R/rvar-.R | 30 +++++++++++++++--------------- R/rvar-bind.R | 6 ++++-- R/rvar-cast.R | 5 ++++- R/rvar-factor.R | 32 ++++++++++++++++++++++++++++---- R/rvar-math.R | 29 +++++++++++++++++------------ R/rvar-rfun.R | 7 +++++-- R/rvar-slice.R | 9 +++++---- R/weight_draws.R | 44 ++++++++++++++++++++++++++++++++++++++++---- man/rvar.Rd | 3 +-- man/weight_draws.Rd | 6 ++++++ man/weights.draws.Rd | 14 ++++++++++++++ 12 files changed, 144 insertions(+), 46 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2d81d65..fa735ef 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -165,6 +165,8 @@ S3method(iteration_ids,draws_rvars) S3method(iteration_ids,rvar) S3method(length,rvar) S3method(levels,rvar) +S3method(log_weights,draws) +S3method(log_weights,rvar) S3method(mad,default) S3method(mad,rvar) S3method(mad,rvar_ordered) @@ -382,7 +384,9 @@ S3method(weight_draws,draws_df) S3method(weight_draws,draws_list) S3method(weight_draws,draws_matrix) S3method(weight_draws,draws_rvars) +S3method(weight_draws,rvar) S3method(weights,draws) +S3method(weights,rvar) export("%**%") export("%in%") export("draws_of<-") @@ -441,6 +445,7 @@ export(is_rvar) export(is_rvar_factor) export(is_rvar_ordered) export(iteration_ids) +export(log_weights) export(mad) export(match) export(mcse_mean) diff --git a/R/rvar-.R b/R/rvar-.R index aa5da1e..9dcf23a 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -93,7 +93,7 @@ rvar <- function( x = double(), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE, - weights = NULL, log = FALSE + log_weights = NULL ) { if (is_rvar(x)) { nchains <- nchains %||% nchains(x) @@ -109,7 +109,7 @@ rvar <- function( nchains <- nchains %||% 1L } - out <- new_rvar(x, .nchains = nchains, weights = weights, log = log) + out <- new_rvar(x, .nchains = nchains, .log_weights = log_weights) if (!is.null(dim)) { dim(out) <- dim @@ -122,7 +122,7 @@ rvar <- function( } #' @importFrom vctrs new_vctr -new_rvar <- function(x = double(), .nchains = 1L, weights = NULL, log = FALSE) { +new_rvar <- function(x = double(), .nchains = 1L, .log_weights = NULL) { if (is.null(x)) { x <- double() } @@ -132,17 +132,15 @@ new_rvar <- function(x = double(), .nchains = 1L, weights = NULL, log = FALSE) { .ndraws <- dim(x)[[1]] .nchains <- as_one_integer(.nchains) check_nchains_compat_with_ndraws(.nchains, .ndraws) - if (is.null(weights)) { - log_weight <- NULL - } else { - log_weight <- validate_weights(weights, .ndraws, log = log, pareto_smooth = FALSE) + if (!is.null(.log_weights)) { + .log_weights <- validate_weights(.log_weights, .ndraws, log = TRUE, pareto_smooth = FALSE) } structure( list(), draws = x, nchains = .nchains, - log_weight = log_weight, + log_weights = .log_weights, class = get_rvar_class(x), cache = new.env(parent = emptyenv()) ) @@ -557,7 +555,7 @@ weights2_common <- function(weights_x, weights_y) { weights_x } else { stop_no_call( - "Random variables have different weights and cannot be used together:\n", + "Random variables have different log weights and cannot be used together:\n", "<", vctrs::vec_ptype_abbr(weights_x), "> ", paste(head(weights_x, 5), collapse = ", "), " ...\n", "<", vctrs::vec_ptype_abbr(weights_y), "> ", paste(head(weights_y, 5), collapse = ", "), " ..." ) @@ -596,9 +594,9 @@ conform_rvar_ndraws <- function(rvars, keep_constants = FALSE) { # broadcast to a common number of draws and the same set of weights. # If keep_constants = TRUE, constants will not be broadcast or re-weighted. .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) - log_weight = Reduce(weights2_common, lapply(rvars, attr, "log_weight")) + .log_weights = Reduce(weights2_common, lapply(rvars, log_weights)) for (i in seq_along(rvars)) { - rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants, log_weight = log_weight) + rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants, .log_weights = .log_weights) } rvars @@ -754,7 +752,7 @@ broadcast_array <- function(x, dim, broadcast_scalars = TRUE) { } # broadcast the draws dimension of an rvar to the requested size -broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, log_weight = NULL) { +broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, .log_weights = NULL) { ndraws_x = ndraws(x) if ( (ndraws_x == 1 && keep_constants) || @@ -766,7 +764,7 @@ broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, log_weight = NUL new_dim <- dim(draws) new_dim[1] <- .ndraws - new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x), weights = log_weight, log = TRUE) + new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x), .log_weights = .log_weights) } } @@ -968,7 +966,8 @@ summarise_rvar_within_draws <- function(x, .f, ..., .transpose = FALSE, .when_em } else { draws <- apply(draws, 1, .f, ...) if (.transpose) draws <- t(draws) - new_rvar(draws, .nchains = nchains(x)) + draws_of(x) <- draws + x } } @@ -999,7 +998,8 @@ summarise_rvar_within_draws_via_matrix <- function(x, .name, .f, ..., .ordered_o .draws <- .f(draws_of(x), ...) } - new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws + x } # apply vectorized function to an rvar's draws diff --git a/R/rvar-bind.R b/R/rvar-bind.R index 55c1a58..947dc8b 100755 --- a/R/rvar-bind.R +++ b/R/rvar-bind.R @@ -89,9 +89,10 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { draws_axis <- axis + 1 # because first dim is draws - # conform nchains + # conform nchains and weights # (don't need to do draws here since that's part of the broadcast below) c(x, y) %<-% conform_rvar_nchains(list(x, y)) + log_weights <- weights2_common(log_weights(x), log_weights(y)) # broadcast each array to the desired dimensions # (except along the axis we are binding along) @@ -112,7 +113,8 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { # bind along desired axis result <- new_rvar( abind(draws_x, draws_y, along = draws_axis, use.dnns = TRUE), - .nchains = nchains(x) + .nchains = nchains(x), + .log_weights = log_weights ) } diff --git a/R/rvar-cast.R b/R/rvar-cast.R index 1132894..43d6197 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -210,6 +210,7 @@ vec_proxy.rvar = function(x, ...) { #' @noRd make_rvar_proxy = function(x) { nchains <- nchains(x) + log_weights <- log_weights(x) draws <- draws_of(x) is <- seq_len(NROW(x)) names(is) <- rownames(x) @@ -217,6 +218,7 @@ make_rvar_proxy = function(x) { list( index = i, nchains = nchains, + log_weights = log_weights, draws = draws ) }) @@ -251,7 +253,7 @@ vec_restore.rvar <- function(x, ...) { groups <- split(x, draws_groups) rvars <- lapply(groups, function(x) { i <- vapply(x, `[[`, "index", FUN.VALUE = numeric(1)) - rvar <- new_rvar(x[[1]]$draws, .nchains = x[[1]]$nchains) + rvar <- new_rvar(x[[1]]$draws, .nchains = x[[1]]$nchains, .log_weights = x[[1]]$log_weights) if (length(dim(rvar)) > 1) { rvar[i, ] } else { @@ -318,6 +320,7 @@ vec_proxy_equal.rvar = function(x, ...) { make_rvar_proxy_equal = function(x) { lapply(as.list(x), function(x) list( nchains = nchains(x), + log_weights = log_weights(x), draws = draws_of(x) )) } diff --git a/R/rvar-factor.R b/R/rvar-factor.R index 1b985a0..324c2ba 100644 --- a/R/rvar-factor.R +++ b/R/rvar-factor.R @@ -61,7 +61,13 @@ #' #' @export rvar_factor <- function( - x = factor(), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE, ... + x = factor(), + dim = NULL, + dimnames = NULL, + nchains = NULL, + with_chains = FALSE, + log_weights = NULL, + ... ) { # to ensure we pick up levels already attached to x (if there are any), we @@ -71,7 +77,12 @@ rvar_factor <- function( } out <- rvar( - x, dim = dim, dimnames = dimnames, nchains = nchains, with_chains = with_chains + x, + dim = dim, + dimnames = dimnames, + nchains = nchains, + with_chains = with_chains, + log_weights = log_weights ) .rvar_to_rvar_factor(out, ...) } @@ -79,11 +90,24 @@ rvar_factor <- function( #' @rdname rvar_factor #' @export rvar_ordered <- function( - x = ordered(NULL), dim = NULL, dimnames = NULL, nchains = NULL, with_chains = FALSE, ... + x = ordered(NULL), + dim = NULL, + dimnames = NULL, + nchains = NULL, + with_chains = FALSE, + log_weights = NULL, + ... ) { rvar_factor( - x, dim = dim, dimnames = dimnames, nchains = nchains, with_chains = with_chains, ordered = TRUE, ... + x, + dim = dim, + dimnames = dimnames, + nchains = nchains, + with_chains = with_chains, + log_weights = log_weights, + ordered = TRUE, + ... ) } diff --git a/R/rvar-math.R b/R/rvar-math.R index 11a8ca3..550752c 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -15,6 +15,7 @@ Ops.rvar <- function(e1, e2) { .Ops.rvar <- function(f, e1, e2, preserve_dims = FALSE) { c(e1, e2) %<-% conform_rvar_nchains(list(e1, e2)) + .log_weights <- weights2_common(log_weights(e1), log_weights(e2)) draws_x <- draws_of(e1) draws_y <- draws_of(e2) @@ -47,7 +48,7 @@ Ops.rvar <- function(e1, e2) { draws <- while_preserving_dims(function(...) draws, dim_source) } - new_rvar(draws, .nchains = nchains(e1)) + new_rvar(draws, .nchains = nchains(e1), .log_weights = .log_weights) } #' @export @@ -95,10 +96,12 @@ Math.rvar <- function(x, ...) { if (.Generic %in% c("cumsum", "cumprod", "cummax", "cummin")) { # cumulative functions need to be handled differently # from other functions in this generic - new_rvar(t(apply(draws_of(x), 1, f)), .nchains = nchains(x)) + draws_of(x) <- t(apply(draws_of(x), 1, f)) } else { - new_rvar(f(draws_of(x), ...), .nchains = nchains(x)) + draws_of(x) <- f(draws_of(x), ...) } + + x } #' @export @@ -206,7 +209,7 @@ Math.rvar_factor <- function(x, ...) { result <- copy_dimnames(draws_of(x), 1:2, result, 1:2) result <- copy_dimnames(draws_of(y), 3, result, 3) - new_rvar(result, .nchains = nchains(x)) + new_rvar(result, .nchains = nchains(x), .log_weights = log_weights(x)) } # This generic is not exported here as matrixOps is only in R >= 4.3, so we must @@ -246,16 +249,17 @@ chol.rvar <- function(x, ...) { x_tensor <- as.tensor(aperm(draws_of(x), c(2,3,1))) # do the cholesky decomp - result <- unclass(chol.tensor(x_tensor, 1, 2, ...)) + out_draws <- unclass(chol.tensor(x_tensor, 1, 2, ...)) # move draws dimension back to the front - result <- aperm(result, c(3,1,2)) + out_draws <- aperm(out_draws, c(3,1,2)) # drop dimension names (chol.tensor screws them around) - names(dim(result)) <- NULL - dimnames(result) <- NULL + names(dim(out_draws)) <- NULL + dimnames(out_draws) <- NULL - new_rvar(result, .nchains = nchains(x)) + draws_of(x) <- out_draws + x } #' @importFrom methods setGeneric @@ -334,14 +338,15 @@ t.rvar = function(x) { .dimnames = dimnames(.draws) dim(.draws) = c(dim(.draws)[1], 1, dim(.draws)[2]) dimnames(.draws) = c(.dimnames[1], list(NULL), .dimnames[2]) - result <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else if (ndim == 3) { .draws <- while_preserving_levels(aperm, .draws, c(1, 3, 2)) - result <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else { stop_no_call("argument is not a random vector or matrix") } - result + + x } #' @export diff --git a/R/rvar-rfun.R b/R/rvar-rfun.R index 2793198..ac6512b 100755 --- a/R/rvar-rfun.R +++ b/R/rvar-rfun.R @@ -85,6 +85,7 @@ rfun <- function (.f, rvar_args = NULL, rvar_dots = TRUE, ndraws = NULL) { vapply(args, is_rvar, logical(1)) rvar_args_draws <- as_draws_rvars(args[is_rvar_arg]) .nchains <- max(1, nchains(rvar_args_draws)) + .log_weights <- log_weights(rvar_args_draws) if (length(rvar_args_draws) == 0) { # no rvar arguments, so just create a random variable by applying this function @@ -103,7 +104,7 @@ rfun <- function (.f, rvar_args = NULL, rvar_dots = TRUE, ndraws = NULL) { dim(x) <- c(1, dim(x)) x }) - new_rvar(vctrs::list_unchop(list_of_draws), .nchains = .nchains) + new_rvar(vctrs::list_unchop(list_of_draws), .nchains = .nchains, .log_weights = .log_weights) } formals(rvar_f) <- f_formals rvar_f @@ -236,11 +237,13 @@ rvar_rng <- function(.f, n, ..., ndraws = NULL) { if (length(rvar_args) < 1) { nchains <- 1 ndraws <- ndraws %||% getOption("posterior.rvar_ndraws", 4000) + log_weights <- NULL } else { # we have some arguments that are rvars. We require them to be single-dimensional # (vectors) so that R's vector recycling will work correctly. nchains <- nchains(rvar_args[[1]]) ndraws <- ndraws(rvar_args[[1]]) + log_weights <- log_weights(rvar_args[[1]]) rvar_args_ndims <- lengths(lapply(rvar_args, dim)) if (!all(rvar_args_ndims == 1)) { @@ -266,5 +269,5 @@ rvar_rng <- function(.f, n, ..., ndraws = NULL) { args <- c(n = nd, args) result <- do.call(.f, args) dim(result) <- c(ndraws, n) - new_rvar(result, .nchains = nchains) + new_rvar(result, .nchains = nchains, .log_weights = log_weights) } diff --git a/R/rvar-slice.R b/R/rvar-slice.R index 0129e37..02cc54a 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -141,17 +141,18 @@ NULL .draws <- draws_of(x)[, i, drop = FALSE] } dimnames(.draws) <- NULL - out <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else if (length(index) == length(dim(x))) { # multiple element selection => must have exactly the right number of dims .draws <- inject(draws_of(x)[, !!!index, drop = FALSE]) # must do drop manually in case the draws dimension has only 1 draw dim(.draws) <- c(ndraws(x), 1) - out <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws } else { stop_no_call("subscript out of bounds") } - out + + x } #' @rdname rvar-slice @@ -285,7 +286,7 @@ NULL rownames(.draws) <- seq_len(NROW(.draws)) } - x <- new_rvar(.draws, .nchains = nchains(x)) + draws_of(x) <- .draws if (drop) { x <- drop(x) diff --git a/R/weight_draws.R b/R/weight_draws.R index 6ae12f9..d543379 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -45,6 +45,9 @@ #' head(weights(x)) #' head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts #' +#' # log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +#' all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) +#' #' # add weights on log scale and Pareto smooth them #' x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) #' @@ -110,10 +113,19 @@ weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FA x } +#' @rdname weight_draws +#' @export +weight_draws.rvar <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + attr(x, "log_weights") <- validate_weights(weights, ndraws(x), log, pareto_smooth) + x +} + #' Extract Weights from Draws Objects #' #' Extract weights from [`draws`] objects, with one weight per draw. #' See [`weight_draws`] for details how to add weights to [`draws`] objects. +#' `log_weights(x)` is a low-level shortcut for `weights(x, log = TRUE, normalize = FALSE)`, +#' returning the internal log weights without transforming them. #' #' @param object (draws) A [`draws`] object. #' @param log (logical) Should the weights be returned on the log scale? @@ -132,10 +144,10 @@ weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FA weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { log <- as_one_logical(log) normalize <- as_one_logical(normalize) - if (!".log_weight" %in% variables(object, reserved = TRUE)) { - return(NULL) - } - out <- extract_variable(object, ".log_weight") + + out <- log_weights(object) + if (is.null(out)) return(NULL) + if (normalize) { out <- out - log_sum_exp(out) } @@ -145,6 +157,30 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { out } +#' @export +weights.rvar <- weights.draws + +#' @rdname weights.draws +#' @export +log_weights <- function(object, ...) { + UseMethod("log_weights") +} + +#' @rdname weights.draws +#' @export +log_weights.draws <- function(object, ...) { + if (!".log_weight" %in% variables(object, reserved = TRUE)) { + return(NULL) + } + extract_variable(object, ".log_weight") +} + +#' @rdname weights.draws +#' @export +log_weights.rvar <- function(object, ...) { + attr(object, "log_weights") +} + # validate weights and return log weights validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { checkmate::assert_numeric(weights) diff --git a/man/rvar.Rd b/man/rvar.Rd index 9782c3e..0c3fe56 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -10,8 +10,7 @@ rvar( dimnames = NULL, nchains = NULL, with_chains = FALSE, - weights = NULL, - log = FALSE + log_weights = NULL ) } \arguments{ diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index d866d46..0dbf869 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -7,6 +7,7 @@ \alias{weight_draws.draws_df} \alias{weight_draws.draws_list} \alias{weight_draws.draws_rvars} +\alias{weight_draws.rvar} \title{Weight \code{draws} objects} \usage{ weight_draws(x, weights, ...) @@ -20,6 +21,8 @@ weight_draws(x, weights, ...) \method{weight_draws}{draws_list}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) \method{weight_draws}{draws_rvars}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) + +\method{weight_draws}{rvar}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) } \arguments{ \item{x}{(draws) A \code{draws} object or another \R object for which the method @@ -73,6 +76,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) + # add weights on log scale and Pareto smooth them x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 1a47788..7751251 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -2,9 +2,18 @@ % Please edit documentation in R/weight_draws.R \name{weights.draws} \alias{weights.draws} +\alias{log_weights} +\alias{log_weights.draws} +\alias{log_weights.rvar} \title{Extract Weights from Draws Objects} \usage{ \method{weights}{draws}(object, log = FALSE, normalize = TRUE, ...) + +log_weights(object, ...) + +\method{log_weights}{draws}(object, ...) + +\method{log_weights}{rvar}(object, ...) } \arguments{ \item{object}{(draws) A \code{\link{draws}} object.} @@ -23,6 +32,8 @@ A vector of weights, with one weight per draw. \description{ Extract weights from \code{\link{draws}} objects, with one weight per draw. See \code{\link{weight_draws}} for details how to add weights to \code{\link{draws}} objects. +\code{log_weights(x)} is a low-level shortcut for \code{weights(x, log = TRUE, normalize = FALSE)}, +returning the internal log weights without transforming them. } \examples{ x <- example_draws() @@ -48,6 +59,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# log_weights(x) is equivalent to weights(x, log = TRUE, normalize = FALSE) +all.equal(log_weights(x), weights(x, log = TRUE, normalize = FALSE)) + # add weights on log scale and Pareto smooth them x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) From 4e768c31c9188d0fc5b15a18dbfdf35f5ea08e3e Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 19:27:48 -0600 Subject: [PATCH 04/43] use rvar weights instead of .log_weight variable in draws_rvars --- NAMESPACE | 1 + NEWS.md | 2 ++ R/as_draws_array.R | 1 + R/as_draws_df.R | 1 + R/as_draws_matrix.R | 1 + R/as_draws_rvars.R | 24 ++++++++++++++++++++++++ R/rvar-print.R | 13 +++++++++++-- R/weight_draws.R | 13 +++++++++++-- man/rvar_factor.Rd | 2 ++ man/weights.draws.Rd | 3 +++ tests/testthat/test-subset_draws.R | 2 +- tests/testthat/test-weight_draws.R | 3 ++- 12 files changed, 60 insertions(+), 6 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index fa735ef..745ee28 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -166,6 +166,7 @@ S3method(iteration_ids,rvar) S3method(length,rvar) S3method(levels,rvar) S3method(log_weights,draws) +S3method(log_weights,draws_rvars) S3method(log_weights,rvar) S3method(mad,default) S3method(mad,rvar) diff --git a/NEWS.md b/NEWS.md index a690914..135435f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * Add `pareto_smooth` option to `weight_draws`, to Pareto smooth weights before adding to a draws object. +* Add support for applying weights to individual `rvar` objects. +* Add `log_weights()` function for easy access to raw internal weights. * Matrix multiplication of `rvar`s can now be done with the base matrix multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. diff --git a/R/as_draws_array.R b/R/as_draws_array.R index 8ed9a52..84f4ec3 100644 --- a/R/as_draws_array.R +++ b/R/as_draws_array.R @@ -96,6 +96,7 @@ as_draws_array.draws_rvars <- function(x, ...) { x <- check_variables_are_numeric( x, to = "draws_array", is_non_numeric = is_rvar_factor, convert = FALSE ) + x <- promote_rvar_weights_to_variable(x) # cbind discards class information when applied to vectors, which converts # the underlying factors to numeric diff --git a/R/as_draws_df.R b/R/as_draws_df.R index 9eefcdb..28d17a6 100644 --- a/R/as_draws_df.R +++ b/R/as_draws_df.R @@ -110,6 +110,7 @@ as_draws_df.draws_rvars <- function(x, ...) { if (ndraws(x) == 0L) { return(empty_draws_df(variables(x))) } + x <- promote_rvar_weights_to_variable(x) out <- do.call(cbind, lapply(seq_along(x), function(i) { # flatten each rvar so it only has two dimensions: draws and variables # this also collapses indices into variable names in the format "var[i,j,k,...]" diff --git a/R/as_draws_matrix.R b/R/as_draws_matrix.R index 5d7a37c..03fe4ab 100644 --- a/R/as_draws_matrix.R +++ b/R/as_draws_matrix.R @@ -85,6 +85,7 @@ as_draws_matrix.draws_rvars <- function(x, ...) { x <- check_variables_are_numeric( x, to = "draws_matrix", is_non_numeric = is_rvar_factor, convert = FALSE ) + x <- promote_rvar_weights_to_variable(x) # cbind discards class information when applied to vectors, which converts # the underlying factors to numeric diff --git a/R/as_draws_rvars.R b/R/as_draws_rvars.R index 5ba8882..132dc23 100755 --- a/R/as_draws_rvars.R +++ b/R/as_draws_rvars.R @@ -226,6 +226,20 @@ as_draws_rvars.mcmc.list <- function(x, ...) { x <- conform_rvar_ndraws_nchains(x) class(x) <- class_draws_rvars() + + # move the .log_weight column into the log_weights attribute of each rvar, + # but only if there is no conflict between any existing weights on the rvars + if (".log_weight" %in% names(x)) { + existing_weights <- log_weights(x[[1]]) + .log_weight <- as.vector(draws_of(x$.log_weight)) + if (is.null(existing_weights)) { + x$.log_weight <- NULL + x <- weight_draws(x, .log_weight, log = TRUE) + } else { + weights2_common(existing_weights, .log_weight) + } + } + x } @@ -274,3 +288,13 @@ empty_draws_rvars <- function(variables = character(0), nchains = 0) { class(out) <- class_draws_rvars() out } + +# when converting draws_rvars to other formats, we must promote log weights +# to be a variable before doing the conversion +promote_rvar_weights_to_variable <- function(x) { + .log_weights <- log_weights(x) + if (!is.null(.log_weights)) { + x$.log_weight <- rvar(log_weights(x)) + } + x +} diff --git a/R/rvar-print.R b/R/rvar-print.R index 1efdbea..5d07880 100755 --- a/R/rvar-print.R +++ b/R/rvar-print.R @@ -158,7 +158,11 @@ str.rvar <- function( } } str_attr(attributes(draws_of(object)), "draws_of(*)", c("names", "dim", "dimnames", "class", "levels")) - str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache")) + str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache", "log_weights")) + if ("log_weights" %in% names(attributes(object))) { + cat0(indent.str, paste0('- log_weights(*)=')) + str_next(attr(object, "log_weights"), ...) + } } invisible(NULL) @@ -218,7 +222,12 @@ rvar_type_full <- function(x, dim1 = TRUE) { paste0(",", nchains(x)) } - paste0(rvar_class(x), "<", niterations(x), chain_str, ">", dim_str) + paste0( + if (!is.null(log_weights(x))) "weighted ", + rvar_class(x), + "<", niterations(x), chain_str, ">", + dim_str + ) } rvar_class <- function(x) { diff --git a/R/weight_draws.R b/R/weight_draws.R index d543379..b53898c 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -108,8 +108,10 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FAL #' @rdname weight_draws #' @export weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) - x$.log_weight <- rvar(log_weights) + .log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + for (i in seq_along(x)) { + attr(x[[i]], "log_weights") <- .log_weights + } x } @@ -175,6 +177,13 @@ log_weights.draws <- function(object, ...) { extract_variable(object, ".log_weight") } +#' @rdname weights.draws +#' @export +log_weights.draws_rvars <- function(object, ...) { + if (length(object) < 1) return(NULL) + attr(object[[1]], "log_weights") +} + #' @rdname weights.draws #' @export log_weights.rvar <- function(object, ...) { diff --git a/man/rvar_factor.Rd b/man/rvar_factor.Rd index 92ea056..ba72387 100644 --- a/man/rvar_factor.Rd +++ b/man/rvar_factor.Rd @@ -11,6 +11,7 @@ rvar_factor( dimnames = NULL, nchains = NULL, with_chains = FALSE, + log_weights = NULL, ... ) @@ -20,6 +21,7 @@ rvar_ordered( dimnames = NULL, nchains = NULL, with_chains = FALSE, + log_weights = NULL, ... ) } diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 7751251..8df762d 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -4,6 +4,7 @@ \alias{weights.draws} \alias{log_weights} \alias{log_weights.draws} +\alias{log_weights.draws_rvars} \alias{log_weights.rvar} \title{Extract Weights from Draws Objects} \usage{ @@ -13,6 +14,8 @@ log_weights(object, ...) \method{log_weights}{draws}(object, ...) +\method{log_weights}{draws_rvars}(object, ...) + \method{log_weights}{rvar}(object, ...) } \arguments{ diff --git a/tests/testthat/test-subset_draws.R b/tests/testthat/test-subset_draws.R index a5170ac..42560ba 100644 --- a/tests/testthat/test-subset_draws.R +++ b/tests/testthat/test-subset_draws.R @@ -94,7 +94,7 @@ test_that("subset_draws works correctly for draws_rvars objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") - expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + expect_equal(variables(x_sub, reserved = TRUE), c("mu")) }) test_that("subset_draws works correctly for rvar objects", { diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index fb6e6cc..0bce32a 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -71,7 +71,8 @@ test_that("conversion between formats preserves weights", { array = weight_draws(draws_array(x = 1:10), 1:10), df = weight_draws(draws_df(x = 1:10), 1:10), list = weight_draws(draws_list(x = 1:10), 1:10), - rvars = weight_draws(draws_rvars(x = 1:10), 1:10) + rvars = weight_draws(draws_rvars(x = 1:10), 1:10), + rvar = weight_draws(rvar(x = 1:10), 1:10) ) # chain/iteration/draw columns are placed at the end by conversion functions, From 48c17ee89e4c70839da6610780379785156510d6 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 21:33:03 -0600 Subject: [PATCH 05/43] weighting for rvar functions except rvar-dist --- R/as_draws_rvars.R | 2 +- R/rvar-print.R | 2 +- R/rvar-slice.R | 1 + R/rvar-summaries-over-draws.R | 32 +++++++++++++++++++++++--------- R/subset_draws.R | 2 ++ R/weight_draws.R | 17 +++++++++++++---- man/weights.draws.Rd | 5 ++++- 7 files changed, 45 insertions(+), 16 deletions(-) diff --git a/R/as_draws_rvars.R b/R/as_draws_rvars.R index 132dc23..038dcc7 100755 --- a/R/as_draws_rvars.R +++ b/R/as_draws_rvars.R @@ -294,7 +294,7 @@ empty_draws_rvars <- function(variables = character(0), nchains = 0) { promote_rvar_weights_to_variable <- function(x) { .log_weights <- log_weights(x) if (!is.null(.log_weights)) { - x$.log_weight <- rvar(log_weights(x)) + x$.log_weight <- rvar(log_weights(x), nchains = nchains(x)) } x } diff --git a/R/rvar-print.R b/R/rvar-print.R index 5d07880..6fdac41 100755 --- a/R/rvar-print.R +++ b/R/rvar-print.R @@ -161,7 +161,7 @@ str.rvar <- function( str_attr(attributes(object), "*", c("draws", "names", "dim", "dimnames", "class", "nchains", "cache", "log_weights")) if ("log_weights" %in% names(attributes(object))) { cat0(indent.str, paste0('- log_weights(*)=')) - str_next(attr(object, "log_weights"), ...) + str_next(log_weights(object), ...) } } diff --git a/R/rvar-slice.R b/R/rvar-slice.R index 02cc54a..e33a982 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -284,6 +284,7 @@ NULL if (!is_missing(draws_index[[1]])) { # if we subsetted draws, replace draw ids with sequential ids rownames(.draws) <- seq_len(NROW(.draws)) + log_weights_rvar(x) <- inject(log_weights(x)[!!!draws_index]) } draws_of(x) <- .draws diff --git a/R/rvar-summaries-over-draws.R b/R/rvar-summaries-over-draws.R index b0adb9f..b8cc2b3 100755 --- a/R/rvar-summaries-over-draws.R +++ b/R/rvar-summaries-over-draws.R @@ -68,7 +68,7 @@ E <- function(x, ...) { #' @export mean.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mean", matrixStats::colMeans2, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mean", matrixStats::colWeightedMeans, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... ) } @@ -101,7 +101,7 @@ Pr.rvar <- function(x, ...) { #' @export median.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "median", matrixStats::colMedians, useNames = FALSE, ... + x, "median", matrixStats::colWeightedMedians, useNames = FALSE, w = weights(x), ... ) } @@ -124,6 +124,8 @@ max.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export sum.rvar <- function(x, ...) { + .weights <- weights(x, normalize = FALSE) + if (!is.null(.weights)) x <- x * new_rvar(.weights, .nchains = nchains(x)) summarise_rvar_by_element_via_matrix( x, "sum", matrixStats::colSums2, useNames = FALSE, .ordered_okay = FALSE, ... ) @@ -132,6 +134,8 @@ sum.rvar <- function(x, ...) { #' @rdname rvar-summaries-over-draws #' @export prod.rvar <- function(x, ...) { + .weights <- weights(x, normalize = FALSE) + if (!is.null(.weights)) x <- x ^ new_rvar(.weights, .nchains = nchains(x)) summarise_rvar_by_element_via_matrix( x, "prod", matrixStats::colProds, useNames = FALSE, .ordered_okay = FALSE, ... ) @@ -172,9 +176,14 @@ distributional::variance #' @rdname rvar-summaries-over-draws #' @export variance.rvar <- function(x, ...) { - summarise_rvar_by_element_via_matrix( - x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... - ) + .weights <- weights(x) + if (is.null(.weights)) { + summarise_rvar_by_element_via_matrix( + x, "variance", matrixStats::colVars, useNames = FALSE, .ordered_okay = FALSE, ... + ) + } else { + mean((x - mean(x))^2) + } } #' @rdname rvar-summaries-over-draws @@ -196,9 +205,14 @@ sd.default <- function(x, ...) stats::sd(x, ...) #' @rdname rvar-summaries-over-draws #' @export sd.rvar <- function(x, ...) { - summarise_rvar_by_element_via_matrix( - x, "sd", matrixStats::colSds, useNames = FALSE, .ordered_okay = FALSE, ... - ) + .weights <- weights(x) + if (is.null(.weights)) { + summarise_rvar_by_element_via_matrix( + x, "sd", matrixStats::colWeightedSds, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... + ) + } else { + sqrt(variance(x)) + } } #' @rdname rvar-summaries-over-draws @@ -211,7 +225,7 @@ mad.default <- function(x, ...) stats::mad(x, ...) #' @export mad.rvar <- function(x, ...) { summarise_rvar_by_element_via_matrix( - x, "mad", matrixStats::colMads, useNames = FALSE, .ordered_okay = FALSE, ... + x, "mad", matrixStats::colWeightedMads, useNames = FALSE, .ordered_okay = FALSE, w = weights(x), ... ) } #' @rdname rvar-summaries-over-draws diff --git a/R/subset_draws.R b/R/subset_draws.R index 8465f14..d0b0cd6 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -359,6 +359,7 @@ subset_dims <- function(x, ...) { for (i in seq_along(x)) { draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index) nchains_rvar(x[[i]]) <- nchains + log_weights_rvar(x[[i]]) <- log_weights(x[[i]])[slice_index] } } if (!is.null(iteration)) { @@ -367,6 +368,7 @@ subset_dims <- function(x, ...) { (rep(chain_ids(x), each = niterations) - 1) * niterations(x) for (i in seq_along(x)) { draws_of(x[[i]]) <- vec_slice(draws_of(x[[i]]), slice_index) + log_weights_rvar(x[[i]]) <- log_weights(x[[i]])[slice_index] } } x diff --git a/R/weight_draws.R b/R/weight_draws.R index b53898c..490853a 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -110,7 +110,7 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FAL weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { .log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) for (i in seq_along(x)) { - attr(x[[i]], "log_weights") <- .log_weights + log_weights_rvar(x[[i]]) <- .log_weights } x } @@ -118,7 +118,7 @@ weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FA #' @rdname weight_draws #' @export weight_draws.rvar <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { - attr(x, "log_weights") <- validate_weights(weights, ndraws(x), log, pareto_smooth) + log_weights_rvar(x) <- validate_weights(weights, ndraws(x), log, pareto_smooth) x } @@ -129,7 +129,7 @@ weight_draws.rvar <- function(x, weights, log = FALSE, pareto_smooth = FALSE, .. #' `log_weights(x)` is a low-level shortcut for `weights(x, log = TRUE, normalize = FALSE)`, #' returning the internal log weights without transforming them. #' -#' @param object (draws) A [`draws`] object. +#' @param object (draws) A [`draws`] object or an [`rvar`]. #' @param log (logical) Should the weights be returned on the log scale? #' Defaults to `FALSE`. #' @param normalize (logical) Should the weights be normalized to sum to 1 on @@ -159,6 +159,7 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) { out } +#' @rdname weights.draws #' @export weights.rvar <- weights.draws @@ -181,7 +182,7 @@ log_weights.draws <- function(object, ...) { #' @export log_weights.draws_rvars <- function(object, ...) { if (length(object) < 1) return(NULL) - attr(object[[1]], "log_weights") + log_weights(object[[1]]) } #' @rdname weights.draws @@ -189,6 +190,14 @@ log_weights.draws_rvars <- function(object, ...) { log_weights.rvar <- function(object, ...) { attr(object, "log_weights") } +# for internal use only currently: if you are setting the log_weights +# attribute on an rvar, ALWAYS use this function so that the proxy +# cache is invalidated +`log_weights_rvar<-` <- function(x, value) { + attr(x, "log_weights") <- value + invalidate_rvar_cache(x) +} + # validate weights and return log weights validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 8df762d..2f88b40 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -2,6 +2,7 @@ % Please edit documentation in R/weight_draws.R \name{weights.draws} \alias{weights.draws} +\alias{weights.rvar} \alias{log_weights} \alias{log_weights.draws} \alias{log_weights.draws_rvars} @@ -10,6 +11,8 @@ \usage{ \method{weights}{draws}(object, log = FALSE, normalize = TRUE, ...) +\method{weights}{rvar}(object, log = FALSE, normalize = TRUE, ...) + log_weights(object, ...) \method{log_weights}{draws}(object, ...) @@ -19,7 +22,7 @@ log_weights(object, ...) \method{log_weights}{rvar}(object, ...) } \arguments{ -\item{object}{(draws) A \code{\link{draws}} object.} +\item{object}{(draws) A \code{\link{draws}} object or an \code{\link{rvar}}.} \item{log}{(logical) Should the weights be returned on the log scale? Defaults to \code{FALSE}.} From 4ba15822c1461e3d20b4375aed2eda816fe2d6ff Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 5 Jan 2024 23:49:14 -0600 Subject: [PATCH 06/43] weighted discrete summaries --- R/discrete-summaries.R | 73 ++++++++++++++-------- R/rvar-.R | 8 ++- man-roxygen/args-summaries-weights.R | 2 + man-roxygen/args-summaries-x-categorical.R | 5 ++ man/dissent.Rd | 11 +++- man/entropy.Rd | 11 +++- man/modal_category.Rd | 11 +++- man/rvar.Rd | 5 ++ man/rvar_factor.Rd | 5 ++ 9 files changed, 94 insertions(+), 37 deletions(-) create mode 100644 man-roxygen/args-summaries-weights.R create mode 100644 man-roxygen/args-summaries-x-categorical.R diff --git a/R/discrete-summaries.R b/R/discrete-summaries.R index 3c141c9..f4f4e14 100644 --- a/R/discrete-summaries.R +++ b/R/discrete-summaries.R @@ -2,11 +2,9 @@ #' #' Normalized entropy, for measuring dispersion in draws from categorical distributions. #' -#' @param x (multiple options) A vector to be interpreted as draws from -#' a categorical distribution, such as: -#' - A [factor] -#' - A [numeric] (should be [integer] or integer-like) -#' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-x-categorical +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Calculates the normalized Shannon entropy of the draws in `x`. This value is @@ -51,14 +49,14 @@ #' xy #' entropy(xy) #' @export -entropy <- function(x) { +entropy <- function(x, ...) { UseMethod("entropy") } #' @rdname entropy #' @export -entropy.default <- function(x) { +entropy.default <- function(x, weights = NULL, ...) { if (anyNA(x)) return(NA_real_) - p <- prop.table(simple_table(x)$count) + p <- prop.table(weighted_simple_table(x, weights)$count) n <- length(p) if (n == 1) { @@ -71,8 +69,8 @@ entropy.default <- function(x) { } #' @rdname entropy #' @export -entropy.rvar <- function(x) { - summarise_rvar_by_element(x, entropy) +entropy.rvar <- function(x, ...) { + summarise_rvar_by_element(x, entropy, weights = weights(x)) } @@ -85,6 +83,8 @@ entropy.rvar <- function(x) { #' - A [factor] #' - A [numeric] (should be [integer] or integer-like) #' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Calculates Tastle and Wierman's (2007) *dissention* measure: @@ -125,12 +125,12 @@ entropy.rvar <- function(x) { #' xy #' dissent(xy) #' @export -dissent <- function(x) { +dissent <- function(x, ...) { UseMethod("dissent") } #' @rdname dissent #' @export -dissent.default <- function(x) { +dissent.default <- function(x, weights = NULL, ...) { if (anyNA(x)) return(NA_real_) if (length(x) == 0) return(0) @@ -141,21 +141,22 @@ dissent.default <- function(x) { d <- diff(range(x)) } - tab <- simple_table(x) + tab <- weighted_simple_table(x, weights) p <- prop.table(tab$count) if (length(p) == 1) { out <- 0 } else { x_i <- tab$x - out <- -sum(p * log2(1 - abs(x_i - mean(x)) / d)) + mean_x <- if (is.null(weights)) mean(x) else weighted.mean(x, weights) + out <- -sum(p * log2(1 - abs(x_i - mean_x) / d)) } out } #' @rdname dissent #' @export -dissent.rvar <- function(x) { - summarise_rvar_by_element(x, dissent) +dissent.rvar <- function(x, ...) { + summarise_rvar_by_element(x, dissent, weights = weights(x)) } @@ -163,11 +164,9 @@ dissent.rvar <- function(x) { #' #' Modal category of a vector. #' -#' @param x (multiple options) A vector to be interpreted as draws from -#' a categorical distribution, such as: -#' - A [factor] -#' - A [numeric] (should be [integer] or integer-like) -#' - An [rvar], [rvar_factor], or [rvar_ordered] +#' @template args-summaries-x-categorical +#' @template args-summaries-weights +#' @template args-methods-dots #' #' @details #' Finds the modal category (i.e., most frequent value) in `x`. In the case of @@ -192,20 +191,20 @@ dissent.rvar <- function(x) { #' xy #' modal_category(xy) #' @export -modal_category <- function(x) { +modal_category <- function(x, ...) { UseMethod("modal_category") } #' @rdname modal_category #' @export -modal_category.default <- function(x) { +modal_category.default <- function(x, weights = NULL, ...) { if (anyNA(x)) return(NA) - tab <- simple_table(x) + tab <- weighted_simple_table(x, weights) tab$x[which.max(tab$count)] } #' @rdname modal_category #' @export -modal_category.rvar <- function(x) { - summarise_rvar_by_element(x, modal_category) +modal_category.rvar <- function(x, ...) { + summarise_rvar_by_element(x, modal_category, weights = weights(x)) } @@ -231,3 +230,25 @@ simple_table <- function(x) { count = tabulate(x_int, nbins = length(values)) ) } + +#' A weighted version of simple_table +#' @param x a vector (numeric, factor, character, etc) +#' @param weights weights +#' @returns a list with two components of the same length +#' - `x`: unique values from the input `x` +#' - `count`: sum of weights for each unique value of `x` +#' @noRd +weighted_simple_table <- function(x, weights) { + if (is.null(weights)) return(simple_table(x)) + stopifnot(identical(length(x), length(weights))) + + if (is.factor(x)) { + values <- levels(x) + } else { + values <- unique(x) + } + list( + x = values, + count = vapply(split(weights, factor(x, values)), sum, numeric(1), USE.NAMES = FALSE) + ) +} diff --git a/R/rvar-.R b/R/rvar-.R index 9dcf23a..ef2d864 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -26,6 +26,10 @@ #' is ignored and the second dimension of `x` is used to index chains. #' Internally, the array will be converted to a format without the chain index. #' Ignored when `x` is already an [`rvar`]. +#' @param log_weights (numeric vector) A vector of log weights of length `ndraws(x)`. +#' Weights will be internally stored on the log scale and will not be normalized, +#' but normalized (non-log) weights can be returned via the [weights.rvar()] +#' method later. #' #' @details #' @@ -556,8 +560,8 @@ weights2_common <- function(weights_x, weights_y) { } else { stop_no_call( "Random variables have different log weights and cannot be used together:\n", - "<", vctrs::vec_ptype_abbr(weights_x), "> ", paste(head(weights_x, 5), collapse = ", "), " ...\n", - "<", vctrs::vec_ptype_abbr(weights_y), "> ", paste(head(weights_y, 5), collapse = ", "), " ..." + "<", vctrs::vec_ptype_abbr(weights_x), "> ", paste(utils::head(weights_x, 5), collapse = ", "), " ...\n", + "<", vctrs::vec_ptype_abbr(weights_y), "> ", paste(utils::head(weights_y, 5), collapse = ", "), " ..." ) } } diff --git a/man-roxygen/args-summaries-weights.R b/man-roxygen/args-summaries-weights.R new file mode 100644 index 0000000..d7089d7 --- /dev/null +++ b/man-roxygen/args-summaries-weights.R @@ -0,0 +1,2 @@ +#' @param weights (numeric vector) A vector of weights of the same length as `x`, +#' or `NULL` for unweighted estimation. diff --git a/man-roxygen/args-summaries-x-categorical.R b/man-roxygen/args-summaries-x-categorical.R new file mode 100644 index 0000000..6bc363c --- /dev/null +++ b/man-roxygen/args-summaries-x-categorical.R @@ -0,0 +1,5 @@ +#' @param x (multiple options) A vector to be interpreted as draws from +#' a categorical distribution, such as: +#' - A [factor] +#' - A [numeric] (should be [integer] or integer-like) +#' - An [rvar], [rvar_factor], or [rvar_ordered] diff --git a/man/dissent.Rd b/man/dissent.Rd index 4166ee5..d16431c 100644 --- a/man/dissent.Rd +++ b/man/dissent.Rd @@ -6,11 +6,11 @@ \alias{dissent.rvar} \title{Dissention} \usage{ -dissent(x) +dissent(x, ...) -\method{dissent}{default}(x) +\method{dissent}{default}(x, weights = NULL, ...) -\method{dissent}{rvar}(x) +\method{dissent}{rvar}(x, ...) } \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ an ordinal distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 numeric vector with a value diff --git a/man/entropy.Rd b/man/entropy.Rd index 429068f..8d8657b 100644 --- a/man/entropy.Rd +++ b/man/entropy.Rd @@ -6,11 +6,11 @@ \alias{entropy.rvar} \title{Normalized entropy} \usage{ -entropy(x) +entropy(x, ...) -\method{entropy}{default}(x) +\method{entropy}{default}(x, weights = NULL, ...) -\method{entropy}{rvar}(x) +\method{entropy}{rvar}(x, ...) } \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ a categorical distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 numeric vector with a value diff --git a/man/modal_category.Rd b/man/modal_category.Rd index 8fd8300..2f8351d 100644 --- a/man/modal_category.Rd +++ b/man/modal_category.Rd @@ -6,11 +6,11 @@ \alias{modal_category.rvar} \title{Modal category} \usage{ -modal_category(x) +modal_category(x, ...) -\method{modal_category}{default}(x) +\method{modal_category}{default}(x, weights = NULL, ...) -\method{modal_category}{rvar}(x) +\method{modal_category}{rvar}(x, ...) } \arguments{ \item{x}{(multiple options) A vector to be interpreted as draws from @@ -20,6 +20,11 @@ a categorical distribution, such as: \item A \link{numeric} (should be \link{integer} or integer-like) \item An \link{rvar}, \link{rvar_factor}, or \link{rvar_ordered} }} + +\item{...}{Arguments passed to individual methods (if applicable).} + +\item{weights}{(numeric vector) A vector of weights of the same length as \code{x}, +or \code{NULL} for unweighted estimation.} } \value{ If \code{x} is a \link{factor} or \link{numeric}, returns a length-1 vector containing diff --git a/man/rvar.Rd b/man/rvar.Rd index 0c3fe56..2dc1861 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -51,6 +51,11 @@ used to determine the number of chains. If \code{TRUE}, the \code{nchains} argum is ignored and the second dimension of \code{x} is used to index chains. Internally, the array will be converted to a format without the chain index. Ignored when \code{x} is already an \code{\link{rvar}}.} + +\item{log_weights}{(numeric vector) A vector of log weights of length \code{ndraws(x)}. +Weights will be internally stored on the log scale and will not be normalized, +but normalized (non-log) weights can be returned via the \code{\link[=weights.rvar]{weights.rvar()}} +method later.} } \value{ An object of class \code{"rvar"} representing a random variable. diff --git a/man/rvar_factor.Rd b/man/rvar_factor.Rd index ba72387..315a397 100644 --- a/man/rvar_factor.Rd +++ b/man/rvar_factor.Rd @@ -64,6 +64,11 @@ is ignored and the second dimension of \code{x} is used to index chains. Internally, the array will be converted to a format without the chain index. Ignored when \code{x} is already an \code{\link{rvar}}.} +\item{log_weights}{(numeric vector) A vector of log weights of length \code{ndraws(x)}. +Weights will be internally stored on the log scale and will not be normalized, +but normalized (non-log) weights can be returned via the \code{\link[=weights.rvar]{weights.rvar()}} +method later.} + \item{...}{ Arguments passed on to \code{\link[base:factor]{base::factor}} \describe{ From 2db69fe0a7dffd24d974b98fb9e5527bc20370e6 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 6 Jan 2024 12:35:13 -0600 Subject: [PATCH 07/43] tests for weighted rvars --- R/draws-index.R | 7 ++- R/rvar-.R | 8 +-- R/rvar-print.R | 32 +++++++---- R/weight_draws.R | 7 ++- tests/testthat/test-discrete-summaries.R | 30 +++++++++- tests/testthat/test-rvar-dist.R | 4 ++ tests/testthat/test-rvar-print.R | 55 +++++++++++++++++++ .../testthat/test-rvar-summaries-over-draws.R | 20 +++++++ tests/testthat/test-weight_draws.R | 20 +++++++ 9 files changed, 163 insertions(+), 20 deletions(-) diff --git a/R/draws-index.R b/R/draws-index.R index 8cdfa99..2333a86 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -421,8 +421,11 @@ nchains.rvar <- function(x) { # attribute on an rvar, ALWAYS use this function so that the proxy # cache is invalidated `nchains_rvar<-` <- function(x, value) { - attr(x, "nchains") <- value - invalidate_rvar_cache(x) + if (attr(x, "nchains") != value) { + attr(x, "nchains") <- value + x <- invalidate_rvar_cache(x) + } + x } diff --git a/R/rvar-.R b/R/rvar-.R index ef2d864..85d1599 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -758,10 +758,10 @@ broadcast_array <- function(x, dim, broadcast_scalars = TRUE) { # broadcast the draws dimension of an rvar to the requested size broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, .log_weights = NULL) { ndraws_x = ndraws(x) - if ( - (ndraws_x == 1 && keep_constants) || - (ndraws_x == .ndraws) - ) { + if (ndraws_x == 1 && keep_constants) { + x + } else if (ndraws_x == .ndraws) { + log_weights_rvar(x) <- .log_weights x } else { draws <- draws_of(x) diff --git a/R/rvar-print.R b/R/rvar-print.R index 6fdac41..0e7b7a0 100755 --- a/R/rvar-print.R +++ b/R/rvar-print.R @@ -60,7 +60,7 @@ print.rvar <- function(x, ..., summary = NULL, digits = NULL, color = TRUE, widt digits <- digits %||% getOption("posterior.digits", 2) # \u00b1 = plus/minus sign summary_functions <- get_summary_functions(draws_of(x), summary) - plus_minus <- summary_functions[[1]] != "modal_category" + plus_minus <- !identical(summary_functions[[1]], modal_category) summary_string <- if (plus_minus) { paste0(paste(names(summary_functions), collapse = " \u00b1 "), ":") } else { @@ -89,7 +89,7 @@ print.rvar <- function(x, ..., summary = NULL, digits = NULL, color = TRUE, widt #' @export format.rvar <- function(x, ..., summary = NULL, digits = NULL, color = FALSE) { digits <- digits %||% getOption("posterior.digits", 2) - format_rvar_draws(draws_of(x), ..., summary = summary, digits = digits, color = color) + format_rvar_draws(draws_of(x), weights(x), ..., summary = summary, digits = digits, color = color) } #' @rdname print.rvar @@ -126,7 +126,7 @@ str.rvar <- function( } cat0(" ", rvar_type_full(object), " ", - paste(format_rvar_draws(.draws, summary = summary, trim = TRUE), collapse = " "), + paste(format_rvar_draws(.draws, weights(object), summary = summary, trim = TRUE), collapse = " "), ellipsis, "\n" ) @@ -244,19 +244,19 @@ rvar_class <- function(x) { # formats a draws array for display as individual "variables" (i.e. maintaining # its dimensions except for the dimension representing draws) format_rvar_draws <- function( - draws, ..., pad_left = "", pad_right = "", summary = NULL, digits = 2, color = FALSE, trim = FALSE + draws, weights, ..., pad_left = "", pad_right = "", summary = NULL, digits = 2, color = FALSE, trim = FALSE ) { if (length(draws) == 0) { return(character()) } summary_functions <- get_summary_functions(draws, summary) - plus_minus <- summary_functions[[1]] != "modal_category" + plus_minus <- !identical(summary_functions[[1]], modal_category) summary_dimensions <- seq_len(length(dim(draws)) - 1) + 1 # these will be mean/sd, median/mad, mode/entropy, mode/dissent depending on `summary` - .mean <- .apply_factor(draws, summary_dimensions, summary_functions[[1]]) - .sd <- .apply_factor(draws, summary_dimensions, summary_functions[[2]]) + .mean <- .apply_factor(draws, summary_dimensions, function(x) summary_functions[[1]](x, weights)) + .sd <- .apply_factor(draws, summary_dimensions, function(x) summary_functions[[2]](x, weights)) out <- paste0( pad_left, @@ -322,6 +322,16 @@ format_levels <- function(levels, ordered = FALSE, max_level = NULL, width = get ) } +# matrixStats::weighted_sd assumes we know the sample size, so use +# this instead +weighted_sd <- function(x, w = NULL) { + if (is.null(w)) { + sd(x) + } else { + sqrt(weighted.mean((x - weighted.mean(x, w))^2, w) ) + } +} + # check that summary is a valid name of the type of summary to do and # return a vector of two elements, where the first is the point summary function # (mean, median, mode) and the second is the uncertainty function () @@ -334,10 +344,10 @@ get_summary_functions <- function(draws, summary = NULL) { if (is.null(summary)) summary <- getOption("posterior.rvar_summary", "mean_sd") switch(summary, - mean_sd = list(mean = "mean", sd = "sd"), - median_mad = list(median = "median", mad = "mad"), - mode_entropy = list(mode = "modal_category", entropy = "entropy"), - mode_dissent = list(mode = "modal_category", dissent = "dissent"), + mean_sd = list(mean = matrixStats::weightedMean, sd = weighted_sd), + median_mad = list(median = matrixStats::weightedMedian, mad = matrixStats::weightedMad), + mode_entropy = list(mode = modal_category, entropy = entropy), + mode_dissent = list(mode = modal_category, dissent = dissent), stop_no_call('`summary` must be one of "mean_sd" or "median_mad"') ) } diff --git a/R/weight_draws.R b/R/weight_draws.R index 490853a..19d5443 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -194,8 +194,11 @@ log_weights.rvar <- function(object, ...) { # attribute on an rvar, ALWAYS use this function so that the proxy # cache is invalidated `log_weights_rvar<-` <- function(x, value) { - attr(x, "log_weights") <- value - invalidate_rvar_cache(x) + if (!identical(attr(x, "log_weights"), value)) { + attr(x, "log_weights") <- value + x <- invalidate_rvar_cache(x) + } + x } diff --git a/tests/testthat/test-discrete-summaries.R b/tests/testthat/test-discrete-summaries.R index 78ce902..5352aa3 100644 --- a/tests/testthat/test-discrete-summaries.R +++ b/tests/testthat/test-discrete-summaries.R @@ -24,7 +24,6 @@ test_that("modal_category works on rvars", { expect_equal(modal_category(c(rvar(c("a","b","b","c","c")), rvar("c"))), c("b","c")) }) - # entropy ----------------------------------------------------------------- test_that("entropy works on vectors", { @@ -83,3 +82,32 @@ test_that("dissent works on rvars", { # know about the missing level at the end expect_equal(dissent(as_rvar_numeric(x)), c(-sum(p * log2(1 - abs(1:3 - 1.75) / 2)), 0, 1)) }) + + +# weighted summaries ------------------------------------------------------ + +test_that("weighted discrete summaries work", { + x <- c(0, 0, 0, 3, 3, 1) + levs <- c("h","e","f","g") + x_factor <- factor(c("h","h","h","g","g","e"), levels = levs) + x_ordered <- ordered(c("h","h","h","g","g","e"), levels = levs) + xw <- c(1, 2, 0, 3, 0) + xw_factor <- factor(c("e","f","h","g","h"), levels = levs) + xw_ordered <- ordered(c("e","f","h","g","h"), levels = levs) + w <- c(1, 0, 1.25, 2, 1.75) + + expect_equal(modal_category(xw, w), modal_category(x)) + expect_equal(modal_category(xw_factor, w), modal_category(x_factor)) + expect_equal(modal_category(xw_ordered, w), modal_category(x_ordered)) + + # entropy(xw, w) is equal to entropy(x_factor) because entropy(x_factor) + # accounts for the missing level just as entropy(xw, w) accounts for the + # element with 0 weight. entropy(x) cannot do this. + expect_equal(entropy(xw, w), entropy(x_factor)) + expect_equal(entropy(xw_factor, w), entropy(x_factor)) + expect_equal(entropy(xw_ordered, w), entropy(x_ordered)) + + expect_equal(dissent(xw, w), dissent(x)) + expect_equal(dissent(xw_factor, w), dissent(x_factor)) + expect_equal(dissent(xw_ordered, w), dissent(x_ordered)) +}) diff --git a/tests/testthat/test-rvar-dist.R b/tests/testthat/test-rvar-dist.R index 09c2460..b4d7a65 100755 --- a/tests/testthat/test-rvar-dist.R +++ b/tests/testthat/test-rvar-dist.R @@ -41,12 +41,16 @@ test_that("distributional functions work on an rvar_factor", { x_values <- c(2,2,2,4,4,4,4,3,5,3) x_letters <- letters[x_values] x <- rvar_factor(x_letters, levels = letters[1:5]) + x2 <- c(rvar_factor(letters), rvar_factor(letters)) expect_equal(density(x, letters[1:6]), c(0, .3, .2, .4, .1, NA)) + expect_equal(density(x2, letters[1:3]), array(rep(1/26, 6), dim = c(3,2))) expect_equal(cdf(x, letters[1:5]), c(NA, NA, NA, NA, NA)) + expect_equal(cdf(x2, letters[1:3]), array(rep(NA, 6), dim = c(3,2))) expect_equal(quantile(x, 1:4/4), c(NA, NA, NA, NA)) + expect_equal(quantile(x2, 1:3/3), array(rep(NA, 6), dim = c(3,2))) }) test_that("distributional functions work on an rvar_ordered", { diff --git a/tests/testthat/test-rvar-print.R b/tests/testthat/test-rvar-print.R index fe28a6a..d2a8c49 100755 --- a/tests/testthat/test-rvar-print.R +++ b/tests/testthat/test-rvar-print.R @@ -108,6 +108,61 @@ test_that("print() works", { ) }) +test_that("printing weighted rvars works", { + w <- c(1, 0, 1.25, 2, 1.75) + levs <- c("h","e","f","g") + xw <- weight_draws(rvar(c(1, 2, 0, 3, 0)), w) + xw_factor <- weight_draws(rvar_factor(c("e","f","h","g","h"), levels = levs), w) + xw_ordered <- weight_draws(rvar_ordered(c("e","f","h","g","h"), levels = levs), w) + + out <- capture.output(print(xw, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar<5>\\[1\\] mean . sd:", + all = FALSE + ) + expect_match( + out, + regexp = "1.2 . 1.3", + all = FALSE + ) + + out <- capture.output(print(xw, summary = "median_mad", color = FALSE)) + expect_match( + out, + regexp = "weighted rvar<5>\\[1\\] median . mad:", + all = FALSE + ) + expect_match( + out, + regexp = "0.64 . 0.94", + all = FALSE + ) + + out <- capture.output(print(xw_factor, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar_factor<5>\\[1\\] mode :", + all = FALSE + ) + expect_match( + out, + regexp = "h <0.73>", + all = FALSE + ) + + out <- capture.output(print(xw_ordered, color = FALSE)) + expect_match( + out, + regexp = "weighted rvar_ordered<5>\\[1\\] mode :", + all = FALSE + ) + expect_match( + out, + regexp = "h <0.82>", + all = FALSE + ) +}) # str --------------------------------------------------------------------- diff --git a/tests/testthat/test-rvar-summaries-over-draws.R b/tests/testthat/test-rvar-summaries-over-draws.R index 5d1cd62..86713ab 100755 --- a/tests/testthat/test-rvar-summaries-over-draws.R +++ b/tests/testthat/test-rvar-summaries-over-draws.R @@ -193,3 +193,23 @@ test_that("anyNA works", { x_ord[2,1] <- NA expect_equal(anyNA(x_ord), TRUE) }) + + +# weighted summaries ------------------------------------------------------ + +test_that("weighted summaries work", { + x <- rvar(c(1,1,2,2,2,3,3,3,3)) + n <- ndraws(x) + w <- c(2,3,4,0) + xw <- weight_draws(rvar(c(1,2,3,4)), w) + + expect_equal(sum(xw), sum(x)) + expect_equal(prod(xw), prod(x)) + expect_equal(mean(xw), mean(x)) + expect_equal(median(xw), matrixStats::weightedMedian(draws_of(xw), w)) + expect_equal(mad(xw), matrixStats::weightedMad(draws_of(xw), w)) + # weighted var and sd don't use sample correction because it depends on + # knowing the sample size + expect_equal(var(xw), var(x)*(n-1)/n) + expect_equal(sd(xw), sqrt(var(x)*(n-1)/n)) +}) diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 0bce32a..0230a83 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -63,6 +63,26 @@ test_that("weight_draws works on draws_rvars", { expect_equal(weights2, weights) }) +test_that("weights are propagated to variables in draws_rvars", { + d <- draws_rvars(x = rvar(1:10, log_weights = 2:11), y = 3:12) + expect_equal(log_weights(d$x), 2:11) + expect_equal(log_weights(d$y), 2:11) + + d <- draws_rvars(x = 1:10, y = 3:12, .log_weight = 2:11) + expect_equal(log_weights(d$x), 2:11) + expect_equal(log_weights(d$y), 2:11) + + expect_error( + draws_rvars(x = rvar(1:10, log_weights = 1:10), y = rvar(3:12, log_weights = 2:11)), + "different log weights" + ) + + expect_error( + draws_rvars(x = rvar(1:10, log_weights = 1:10), .log_weight = 2:11), + "different log weights" + ) +}) + # conversion preserves weights -------------------------------------------- test_that("conversion between formats preserves weights", { From 6c87a3bb77423c48a4c90b6f91fdfb1c641d05e1 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 6 Jan 2024 15:21:32 -0600 Subject: [PATCH 08/43] allow weight_draws(x, NULL) to remove weights --- R/resample_draws.R | 2 +- R/rvar-.R | 4 +--- R/weight_draws.R | 19 ++++++++++++++----- man/weight_draws.Rd | 9 +++++---- man/weights.draws.Rd | 3 ++- tests/testthat/test-resample_draws.R | 6 ++++++ 6 files changed, 29 insertions(+), 14 deletions(-) diff --git a/R/resample_draws.R b/R/resample_draws.R index f14bbc8..5784b2d 100644 --- a/R/resample_draws.R +++ b/R/resample_draws.R @@ -72,7 +72,7 @@ resample_draws.draws <- function(x, weights = NULL, method = "stratified", weights <- rep.int(1/ndraws_total, ndraws_total) } # resampling invalidates stored weights - x <- remove_variables(x, ".log_weight") + x <- weight_draws(x, NULL) } else { weights <- weights / sum(weights) } diff --git a/R/rvar-.R b/R/rvar-.R index 85d1599..da4398c 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -136,9 +136,7 @@ new_rvar <- function(x = double(), .nchains = 1L, .log_weights = NULL) { .ndraws <- dim(x)[[1]] .nchains <- as_one_integer(.nchains) check_nchains_compat_with_ndraws(.nchains, .ndraws) - if (!is.null(.log_weights)) { - .log_weights <- validate_weights(.log_weights, .ndraws, log = TRUE, pareto_smooth = FALSE) - } + .log_weights <- validate_weights(.log_weights, .ndraws, log = TRUE, pareto_smooth = FALSE) structure( list(), diff --git a/R/weight_draws.R b/R/weight_draws.R index 19d5443..ecaf2ea 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -7,10 +7,11 @@ #' `draws` objects. #' #' @template args-methods-x -#' @param weights (numeric vector) A vector of weights of length `ndraws(x)`. -#' Weights will be internally stored on the log scale (in a variable called -#' `.log_weight`) and will not be normalized, but normalized (non-log) weights -#' can be returned via the [weights.draws()] method later. +#' @param weights (numeric vector) A vector of weights of length `ndraws(x)`, +#' or `NULL` to remove weights. Weights will be internally stored on the log +#' scale and will not be normalized. Normalized (non-log) weights can be +#' returned via the [weights.draws()] method, and the unnormalized +#' log weights can be accessed via [log_weights()]. #' @param log (logical) Are the weights passed already on the log scale? The #' default is `FALSE`, that is, expecting `weights` to be on the standard #' (non-log) scale. @@ -60,6 +61,8 @@ weight_draws <- function(x, weights, ...) { #' @export weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(weights)) return(remove_variables(x, ".log_weight")) + if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, ".log_weight"] <- log_weights @@ -75,6 +78,8 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = F #' @export weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(weights)) return(remove_variables(x, ".log_weight")) + if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, , ".log_weight"] <- log_weights @@ -97,6 +102,8 @@ weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE #' @export weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { log_weights <- validate_weights(weights, ndraws(x), log, pareto_smooth) + if (is.null(log_weights)) return(remove_variables(x, ".log_weight")) + niterations <- niterations(x) for (i in seq_len(nchains(x))) { sel <- (1 + (i - 1) * niterations):(i * niterations) @@ -136,7 +143,8 @@ weight_draws.rvar <- function(x, weights, log = FALSE, pareto_smooth = FALSE, .. #' the standard scale? Defaults to `TRUE`. #' @template args-methods-dots #' -#' @return A vector of weights, with one weight per draw. +#' @return A vector of weights, with one weight per draw, or `NULL` if this +#' object does not contain weights. #' #' @seealso [`weight_draws`], [`resample_draws`] #' @@ -204,6 +212,7 @@ log_weights.rvar <- function(object, ...) { # validate weights and return log weights validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { + if (is.null(weights)) return(NULL) checkmate::assert_numeric(weights) checkmate::assert_flag(log) checkmate::assert_flag(pareto_smooth) diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index 0dbf869..673223f 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -28,10 +28,11 @@ weight_draws(x, weights, ...) \item{x}{(draws) A \code{draws} object or another \R object for which the method is defined.} -\item{weights}{(numeric vector) A vector of weights of length \code{ndraws(x)}. -Weights will be internally stored on the log scale (in a variable called -\code{.log_weight}) and will not be normalized, but normalized (non-log) weights -can be returned via the \code{\link[=weights.draws]{weights.draws()}} method later.} +\item{weights}{(numeric vector) A vector of weights of length \code{ndraws(x)}, +or \code{NULL} to remove weights. Weights will be internally stored on the log +scale and will not be normalized. Normalized (non-log) weights can be +returned via the \code{\link[=weights.draws]{weights.draws()}} method, and the unnormalized +log weights can be accessed via \code{\link[=log_weights]{log_weights()}}.} \item{...}{Arguments passed to individual methods (if applicable).} diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 2f88b40..7b4d865 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -33,7 +33,8 @@ the standard scale? Defaults to \code{TRUE}.} \item{...}{Arguments passed to individual methods (if applicable).} } \value{ -A vector of weights, with one weight per draw. +A vector of weights, with one weight per draw, or \code{NULL} if this +object does not contain weights. } \description{ Extract weights from \code{\link{draws}} objects, with one weight per draw. diff --git a/tests/testthat/test-resample_draws.R b/tests/testthat/test-resample_draws.R index bea4980..42abba8 100644 --- a/tests/testthat/test-resample_draws.R +++ b/tests/testthat/test-resample_draws.R @@ -79,6 +79,12 @@ test_that("resample_draws works on rvars", { expect_true(mean_rs > 6660 && mean_rs < 6670) expect_true(is_rvar(x_rs)) + x_rs <- resample_draws(weight_draws(x, w), method = "stratified") + mean_rs <- mean(x_rs) + expect_true(mean_rs > 6660 && mean_rs < 6670) + expect_true(is_rvar(x_rs)) + expect_null(log_weights(x_rs)) + # without weights x_rs <- resample_draws(x, method = "stratified") mean_rs <- mean(x_rs) From 6ae97f8de21d2919da0f8a41ca6c9b140c213182 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 6 Jan 2024 16:40:50 -0600 Subject: [PATCH 09/43] test fixes --- tests/testthat/test-print.R | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R index 2b79fd3..155be0f 100644 --- a/tests/testthat/test-print.R +++ b/tests/testthat/test-print.R @@ -53,7 +53,6 @@ test_that("print.draws_list runs without errors", { test_that("print.draws_rvars runs without errors", { skip_on_cran() - skip_on_os("windows") x <- as_draws_rvars(example_draws()) out <- capture.output(print(x)) expect_match( @@ -65,7 +64,7 @@ test_that("print.draws_rvars runs without errors", { x <- weight_draws(x, rep(1, ndraws(x))) expect_output( print(x), - "hidden reserved variables ..\\.log_weight.." + "weighted rvar" ) }) @@ -112,7 +111,6 @@ test_that("print.draws_list handles reserved variables correctly", { test_that("print.draws_rvars handles reserved variables correctly", { skip_on_cran() - skip_on_os("windows") x <- as_draws_rvars(example_draws()) variables(x)[1] <- ".log_weight" # reserved name expect_output(print(x, max_variables = 1), "tau") From ecfbb8438ba8b256bb6b8e7c3a0cbcc95b6781cd Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sat, 6 Jan 2024 17:02:41 -0600 Subject: [PATCH 10/43] make test reliable --- tests/testthat/test-rstar.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-rstar.R b/tests/testthat/test-rstar.R index 6625b7a..65a4751 100644 --- a/tests/testthat/test-rstar.R +++ b/tests/testthat/test-rstar.R @@ -89,6 +89,7 @@ test_that("rstar accepts different hyperparameters", { test_that("rstar accepts different training proportion", { skip_if_not_installed("caret") x <- example_draws() + set.seed(12345) val1 <- rstar(x, method = "knn") val2 <- rstar(x, method = "knn", training_proportion = 0.1) expect_true(val1 > val2) From ca67b6ec0e12992d83279cf5db1680b3ef42bd70 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 7 Jan 2024 16:56:54 -0600 Subject: [PATCH 11/43] add documentation of rvar internals --- R/rvar-.R | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++ man/rvar.Rd | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/R/rvar-.R b/R/rvar-.R index da4398c..10f09d4 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -58,6 +58,62 @@ #' on the underlying array using the [draws_of()] function. To re-use existing #' random number generator functions to efficiently create `rvar`s, use [rvar_rng()]. #' +#' @section `rvar` Internals: +#' +#' The `rvar` datatype is not intended to be modified directly; rather, you should +#' only use exported functions from \pkg{posterior}, such as [rvar()], [draws_of()], +#' [log_weights()], and [weight_draws()] to create and manipulate `rvar`s. +#' For completeness, and to aid internal development, this section documents the +#' internal structure of the `rvar` datatype. While the public-facing API is +#' intended to be stable, **this internal structure is subject to change without +#' notice**. +#' +#' An `rvar` `x` consists of: +#' +#' - A zero-length `list()` with class `c("rvar", "vctrs_vctr")`. If `draws_of(x)` +#' is a [`factor`], the class will be `c("rvar_factor", "rvar", "vctrs_vctr")`, +#' and if `draws_of(x)` is an [`ordered`], the class will be +#' `c("rvar_ordered", "rvar_factor", "rvar", "vctrs_vctr")`. These classes are +#' set automatically if the underlying draws are modified. +#' +#' The list has these attributes: +#' +#' - `draws`: An [`array`] containing the draws, where the first dimension +#' indexes draws. **Always** get this attribute using [draws_of()] and set it +#' using `draws_of(x) <- value`. To simplify programming, `dim(draws_of(x))` +#' is guaranteed to always be greater than or equal to 2. Zero-length `rvar`s +#' have `dim(draws_of(x)) = c(1,0)`. The draws may be a [`numeric`], +#' [`integer`], [`logical`], [`factor`], or [`ordered`] array. +#' +#' The dimensions after the first are reported as the dimensions of `x`; i.e. +#' `dim(x) == dim(draws_of(x))[-1]` and `dimnames(x) = dimnames(draws_of(x))[-1]`. +#' Because `rvar`s *always* have dimensions (unlike base R datatypes, where +#' there is a distinction between a length-*n* vector with no dimensions and +#' a length-*n* array with only 1 dimension), `names(x) = dimnames(x)[[1]]`; +#' i.e., `names()` refers to the names along the first dimension only. +#' +#' - `nchains`: A scalar [`numeric`] giving the number of chains in this `rvar`. +#' **Always** get this attribute using [nchains()]. It cannot be set using the +#' public (exported) API, but can be modified through other functions (e.g. +#' [merge_chains()] or creating a new [rvar()]). In internal code, **always** +#' set it using `nchains_rvar(x) <- value`. +#' +#' - `log_weights`: A vector [`numeric`] with length `ndraws(x)` giving the +#' log weight on each draw of this `rvar`, or `NULL` if the `rvar` is not +#' weighted. **Always** get this attribute using [weights()] or [log_weights()], +#' and set this attributes using [weight_draws()]. In internal code, it may +#' also be modified directly using `log_weights_rvar(x) <- value`. +#' +#' - `cache`: An [`environment`] that may contain cached output of the \pkg{vctrs} +#' proxy functions on `x` to improve performance of code that makes multiple +#' calls to these functions. The cache is updated automatically and invalidated +#' when necessary so long as the `rvar` is only modified using the functions +#' described in this section (or other functions in the publicly-exported +#' `rvar` API). The environment may contain these variables: +#' +#' - `vec_proxy`: cached output of [vctrs::vec_proxy()]. +#' - `vec_proxy_equal`: cached output of [vctrs::vec_proxy_equal()]. +#' #' @seealso [as_rvar()] to convert objects to `rvar`s. See [rdo()], [rfun()], and #' [rvar_rng()] for higher-level interfaces for creating `rvar`s. #' diff --git a/man/rvar.Rd b/man/rvar.Rd index 2dc1861..30a0c4d 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -89,6 +89,64 @@ As \code{\link[=rfun]{rfun()}} and \code{\link[=rdo]{rdo()}} incur some performa on the underlying array using the \code{\link[=draws_of]{draws_of()}} function. To re-use existing random number generator functions to efficiently create \code{rvar}s, use \code{\link[=rvar_rng]{rvar_rng()}}. } +\section{\code{rvar} Internals}{ + + +The \code{rvar} datatype is not intended to be modified directly; rather, you should +only use exported functions from \pkg{posterior}, such as \code{\link[=rvar]{rvar()}}, \code{\link[=draws_of]{draws_of()}}, +\code{\link[=log_weights]{log_weights()}}, and \code{\link[=weight_draws]{weight_draws()}} to create and manipulate \code{rvar}s. +For completeness, and to aid internal development, this section documents the +internal structure of the \code{rvar} datatype. While the public-facing API is +intended to be stable, \strong{this internal structure is subject to change without +notice}. + +An \code{rvar} \code{x} consists of: +\itemize{ +\item A zero-length \code{list()} with class \code{c("rvar", "vctrs_vctr")}. If \code{draws_of(x)} +is a \code{\link{factor}}, the class will be \code{c("rvar_factor", "rvar", "vctrs_vctr")}, +and if \code{draws_of(x)} is an \code{\link{ordered}}, the class will be +\code{c("rvar_ordered", "rvar_factor", "rvar", "vctrs_vctr")}. These classes are +set automatically if the underlying draws are modified. + +The list has these attributes: +\itemize{ +\item \code{draws}: An \code{\link{array}} containing the draws, where the first dimension +indexes draws. \strong{Always} get this attribute using \code{\link[=draws_of]{draws_of()}} and set it +using \code{draws_of(x) <- value}. To simplify programming, \code{dim(draws_of(x))} +is guaranteed to always be greater than or equal to 2. Zero-length \code{rvar}s +have \code{dim(draws_of(x)) = c(1,0)}. The draws may be a \code{\link{numeric}}, +\code{\link{integer}}, \code{\link{logical}}, \code{\link{factor}}, or \code{\link{ordered}} array. + +The dimensions after the first are reported as the dimensions of \code{x}; i.e. +\code{dim(x) == dim(draws_of(x))[-1]} and \code{dimnames(x) = dimnames(draws_of(x))[-1]}. +Because \code{rvar}s \emph{always} have dimensions (unlike base R datatypes, where +there is a distinction between a length-\emph{n} vector with no dimensions and +a length-\emph{n} array with only 1 dimension), \code{names(x) = dimnames(x)[[1]]}; +i.e., \code{names()} refers to the names along the first dimension only. +\item \code{nchains}: A scalar \code{\link{numeric}} giving the number of chains in this \code{rvar}. +\strong{Always} get this attribute using \code{\link[=nchains]{nchains()}}. It cannot be set using the +public (exported) API, but can be modified through other functions (e.g. +\code{\link[=merge_chains]{merge_chains()}} or creating a new \code{\link[=rvar]{rvar()}}). In internal code, \strong{always} +set it using \code{nchains_rvar(x) <- value}. +\item \code{log_weights}: A vector \code{\link{numeric}} with length \code{ndraws(x)} giving the +log weight on each draw of this \code{rvar}, or \code{NULL} if the \code{rvar} is not +weighted. \strong{Always} get this attribute using \code{\link[=weights]{weights()}} or \code{\link[=log_weights]{log_weights()}}, +and set this attributes using \code{\link[=weight_draws]{weight_draws()}}. In internal code, it may +also be modified directly using \code{log_weights_rvar(x) <- value}. +\item \code{cache}: An \code{\link{environment}} that may contain cached output of the \pkg{vctrs} +proxy functions on \code{x} to improve performance of code that makes multiple +calls to these functions. The cache is updated automatically and invalidated +when necessary so long as the \code{rvar} is only modified using the functions +described in this section (or other functions in the publicly-exported +\code{rvar} API). The environment may contain these variables: +\itemize{ +\item \code{vec_proxy}: cached output of \code{\link[vctrs:vec_proxy]{vctrs::vec_proxy()}}. +\item \code{vec_proxy_equal}: cached output of \code{\link[vctrs:vec_proxy_equal]{vctrs::vec_proxy_equal()}}. +} +} +} +} + \examples{ set.seed(1234) From 739a5b4e151c8978d9569f8f38efe2a2cfeb1251 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 8 Jan 2024 00:21:00 -0600 Subject: [PATCH 12/43] minor edits to docs --- R/rvar-.R | 8 ++++---- man/rvar.Rd | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 10f09d4..f2289b0 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -80,13 +80,13 @@ #' #' - `draws`: An [`array`] containing the draws, where the first dimension #' indexes draws. **Always** get this attribute using [draws_of()] and set it -#' using `draws_of(x) <- value`. To simplify programming, `dim(draws_of(x))` +#' using `draws_of(x) <- value`. To simplify programming, `length(dim(draws_of(x)))` #' is guaranteed to always be greater than or equal to 2. Zero-length `rvar`s #' have `dim(draws_of(x)) = c(1,0)`. The draws may be a [`numeric`], #' [`integer`], [`logical`], [`factor`], or [`ordered`] array. #' #' The dimensions after the first are reported as the dimensions of `x`; i.e. -#' `dim(x) == dim(draws_of(x))[-1]` and `dimnames(x) = dimnames(draws_of(x))[-1]`. +#' `dim(x) = dim(draws_of(x))[-1]` and `dimnames(x) = dimnames(draws_of(x))[-1]`. #' Because `rvar`s *always* have dimensions (unlike base R datatypes, where #' there is a distinction between a length-*n* vector with no dimensions and #' a length-*n* array with only 1 dimension), `names(x) = dimnames(x)[[1]]`; @@ -95,7 +95,7 @@ #' - `nchains`: A scalar [`numeric`] giving the number of chains in this `rvar`. #' **Always** get this attribute using [nchains()]. It cannot be set using the #' public (exported) API, but can be modified through other functions (e.g. -#' [merge_chains()] or creating a new [rvar()]). In internal code, **always** +#' [merge_chains()] or by creating a new [rvar()]). In internal code, **always** #' set it using `nchains_rvar(x) <- value`. #' #' - `log_weights`: A vector [`numeric`] with length `ndraws(x)` giving the @@ -106,7 +106,7 @@ #' #' - `cache`: An [`environment`] that may contain cached output of the \pkg{vctrs} #' proxy functions on `x` to improve performance of code that makes multiple -#' calls to these functions. The cache is updated automatically and invalidated +#' calls to those functions. The cache is updated automatically and invalidated #' when necessary so long as the `rvar` is only modified using the functions #' described in this section (or other functions in the publicly-exported #' `rvar` API). The environment may contain these variables: diff --git a/man/rvar.Rd b/man/rvar.Rd index 30a0c4d..24c745e 100755 --- a/man/rvar.Rd +++ b/man/rvar.Rd @@ -112,13 +112,13 @@ The list has these attributes: \itemize{ \item \code{draws}: An \code{\link{array}} containing the draws, where the first dimension indexes draws. \strong{Always} get this attribute using \code{\link[=draws_of]{draws_of()}} and set it -using \code{draws_of(x) <- value}. To simplify programming, \code{dim(draws_of(x))} +using \code{draws_of(x) <- value}. To simplify programming, \code{length(dim(draws_of(x)))} is guaranteed to always be greater than or equal to 2. Zero-length \code{rvar}s have \code{dim(draws_of(x)) = c(1,0)}. The draws may be a \code{\link{numeric}}, \code{\link{integer}}, \code{\link{logical}}, \code{\link{factor}}, or \code{\link{ordered}} array. The dimensions after the first are reported as the dimensions of \code{x}; i.e. -\code{dim(x) == dim(draws_of(x))[-1]} and \code{dimnames(x) = dimnames(draws_of(x))[-1]}. +\code{dim(x) = dim(draws_of(x))[-1]} and \code{dimnames(x) = dimnames(draws_of(x))[-1]}. Because \code{rvar}s \emph{always} have dimensions (unlike base R datatypes, where there is a distinction between a length-\emph{n} vector with no dimensions and a length-\emph{n} array with only 1 dimension), \code{names(x) = dimnames(x)[[1]]}; @@ -126,7 +126,7 @@ i.e., \code{names()} refers to the names along the first dimension only. \item \code{nchains}: A scalar \code{\link{numeric}} giving the number of chains in this \code{rvar}. \strong{Always} get this attribute using \code{\link[=nchains]{nchains()}}. It cannot be set using the public (exported) API, but can be modified through other functions (e.g. -\code{\link[=merge_chains]{merge_chains()}} or creating a new \code{\link[=rvar]{rvar()}}). In internal code, \strong{always} +\code{\link[=merge_chains]{merge_chains()}} or by creating a new \code{\link[=rvar]{rvar()}}). In internal code, \strong{always} set it using \code{nchains_rvar(x) <- value}. \item \code{log_weights}: A vector \code{\link{numeric}} with length \code{ndraws(x)} giving the log weight on each draw of this \code{rvar}, or \code{NULL} if the \code{rvar} is not @@ -135,7 +135,7 @@ and set this attributes using \code{\link[=weight_draws]{weight_draws()}}. In in also be modified directly using \code{log_weights_rvar(x) <- value}. \item \code{cache}: An \code{\link{environment}} that may contain cached output of the \pkg{vctrs} proxy functions on \code{x} to improve performance of code that makes multiple -calls to these functions. The cache is updated automatically and invalidated +calls to those functions. The cache is updated automatically and invalidated when necessary so long as the \code{rvar} is only modified using the functions described in this section (or other functions in the publicly-exported \code{rvar} API). The environment may contain these variables: From 19d2ff7c4ede60fe07cd08d754cd4f6965b91c57 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 17 Jan 2024 17:28:32 +0200 Subject: [PATCH 13/43] updating pareto functions for weighted rvars --- R/pareto_smooth.R | 109 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 21 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 7da95a7..295b2ee 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,57 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) { - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + return_k = TRUE, + smooth_draws = FALSE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_khat, + ... + ) + + print(weights_diags) + print(product_diags) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -138,7 +174,7 @@ pareto_diags.default <- function(x, extra_diags = TRUE, verbose = verbose, smooth_draws = FALSE, - are_log_weights = FALSE, + are_log_weights = are_log_weights, ...) return(smoothed$diagnostics) @@ -149,6 +185,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +205,35 @@ pareto_diags.rvar <- function(x, ...) { khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -279,7 +346,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +397,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -443,7 +510,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -457,10 +524,10 @@ pareto_convergence_rate.rvar <- function(x, ...) { draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -616,7 +683,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -629,7 +696,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) From 53a0b785df7f9ee89505fd3db5d25fc34dfca605 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Wed, 17 Jan 2024 23:14:02 -0600 Subject: [PATCH 14/43] cleanup rvar conform functions and drop unused keep_constants arg --- R/as_draws_rvars.R | 6 +++- R/mutate_variables.R | 2 +- R/rvar-.R | 70 ++++++++++++++++++++++++++++---------------- R/rvar-bind.R | 4 +-- R/rvar-math.R | 6 ++-- R/rvar-rfun.R | 2 +- R/rvar-slice.R | 10 +++---- 7 files changed, 62 insertions(+), 38 deletions(-) diff --git a/R/as_draws_rvars.R b/R/as_draws_rvars.R index 038dcc7..307c0a5 100755 --- a/R/as_draws_rvars.R +++ b/R/as_draws_rvars.R @@ -223,7 +223,7 @@ as_draws_rvars.mcmc.list <- function(x, ...) { check_new_variables(names(x)) - x <- conform_rvar_ndraws_nchains(x) + x <- conform_rvar_nchains_ndraws_weights(x) class(x) <- class_draws_rvars() @@ -236,6 +236,10 @@ as_draws_rvars.mcmc.list <- function(x, ...) { x$.log_weight <- NULL x <- weight_draws(x, .log_weight, log = TRUE) } else { + # if we reach this point either existing_weights and .log_weight + # are identical (so we don't have to do anything) or they aren't + # and weights2_common will throw the appropriate error --- thus + # we don't need to do anything with its output weights2_common(existing_weights, .log_weight) } } diff --git a/R/mutate_variables.R b/R/mutate_variables.R index 4ead827..fd42dd1 100644 --- a/R/mutate_variables.R +++ b/R/mutate_variables.R @@ -86,7 +86,7 @@ mutate_variables.draws_rvars <- function(.x, ...) { for (var in names(dots)) { .x[[var]] <- as_rvar(eval_tidy(dots[[var]], .x, env)) } - conform_rvar_ndraws_nchains(.x) + conform_rvar_nchains_ndraws_weights(.x) } # evaluate an expression passed to 'mutate_variables' and check its validity diff --git a/R/rvar-.R b/R/rvar-.R index f2289b0..fc7b50f 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -487,7 +487,7 @@ rvar_ifelse = function(test, yes, no) { stop_no_call("`rvar_ifelse(test, yes, no)` requires `test` to be a logical rvar, or castable to one.") } c(yes, no) %<-% vec_cast_common(yes, no) - c(test, yes, no) %<-% conform_array_dims(conform_rvar_ndraws(list(test, yes, no))) + c(test, yes, no) %<-% conform_array_dims(conform_rvar_ndraws_weights(list(test, yes, no))) test_draws <- draws_of(test) false_draws <- test_draws %in% FALSE @@ -631,8 +631,12 @@ check_nchains_compat_with_ndraws <- function(nchains, ndraws) { } } -# given a list of rvars, conform their number of chains -# so they can be used together (or throw an error if they can't be) +#' given a list of rvars, conform their number of chains +#' so they can be used together (or throw an error if they can't be). Constants +#' are treated as having any number of draws +#' @param rvars a list of rvars +#' @returns modified list of rvars all having the same number of chains +#' @noRd conform_rvar_nchains <- function(rvars) { # find the number of chains to use, treating constants as having any number of chains nchains_or_null <- lapply(rvars, function(x) if (ndraws(x) == 1) NULL else nchains(x)) @@ -645,27 +649,46 @@ conform_rvar_nchains <- function(rvars) { rvars } -# given a list of rvars, conform their number of draws -# so they can be used together (or throw an error if they can't be) -# @param keep_constants keep constants as 1-draw rvars -conform_rvar_ndraws <- function(rvars, keep_constants = FALSE) { - # broadcast to a common number of draws and the same set of weights. - # If keep_constants = TRUE, constants will not be broadcast or re-weighted. - .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) +#' given a list of rvars, conform their their weights +#' so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @returns modified list of rvars all having the same weights. +#' @noRd +conform_rvar_weights <- function(rvars) { .log_weights = Reduce(weights2_common, lapply(rvars, log_weights)) + for (i in seq_along(rvars)) { - rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws, keep_constants, .log_weights = .log_weights) + log_weights_rvar(rvars[[i]]) <- .log_weights } rvars } -# given multiple rvars, conform their number of draws and chains -# so they can be used together (or throw an error if they can't be) -# @param keep_constants keep constants as 1-draw rvars -conform_rvar_ndraws_nchains <- function(rvars, keep_constants = FALSE) { +#' given a list of rvars, conform their number of draws and their weights +#' so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @returns modified list of rvars all having the same number of draws and the +#' same weights. +#' @noRd +conform_rvar_ndraws_weights <- function(rvars) { + .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) + + for (i in seq_along(rvars)) { + rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws) + } + + conform_rvar_weights(rvars) +} + +#' given a list of rvars, conform their number of draws, number of chains, and +#' their weights so they can be used together (or throw an error if they can't be) +#' @param rvars a list of rvars +#' @returns modified list of rvars all having the same number of chains, same +#' number of draws, and the same weights. +#' @noRd +conform_rvar_nchains_ndraws_weights <- function(rvars) { rvars <- conform_rvar_nchains(rvars) - rvars <- conform_rvar_ndraws(rvars) + rvars <- conform_rvar_ndraws_weights(rvars) rvars } @@ -810,20 +833,17 @@ broadcast_array <- function(x, dim, broadcast_scalars = TRUE) { } # broadcast the draws dimension of an rvar to the requested size -broadcast_draws <- function(x, .ndraws, keep_constants = FALSE, .log_weights = NULL) { +broadcast_draws <- function(x, .ndraws) { ndraws_x = ndraws(x) - if (ndraws_x == 1 && keep_constants) { - x - } else if (ndraws_x == .ndraws) { - log_weights_rvar(x) <- .log_weights - x - } else { + + if (ndraws_x != .ndraws) { draws <- draws_of(x) new_dim <- dim(draws) new_dim[1] <- .ndraws - - new_rvar(broadcast_array(draws, new_dim), .nchains = nchains(x), .log_weights = .log_weights) + draws_of(x) <- broadcast_array(draws, new_dim) } + + x } # flatten dimensions and names of an array diff --git a/R/rvar-bind.R b/R/rvar-bind.R index 947dc8b..906a1b4 100755 --- a/R/rvar-bind.R +++ b/R/rvar-bind.R @@ -92,7 +92,7 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { # conform nchains and weights # (don't need to do draws here since that's part of the broadcast below) c(x, y) %<-% conform_rvar_nchains(list(x, y)) - log_weights <- weights2_common(log_weights(x), log_weights(y)) + c(x, y) %<-% conform_rvar_weights(list(x, y)) # broadcast each array to the desired dimensions # (except along the axis we are binding along) @@ -114,7 +114,7 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { result <- new_rvar( abind(draws_x, draws_y, along = draws_axis, use.dnns = TRUE), .nchains = nchains(x), - .log_weights = log_weights + .log_weights = log_weights(x) ) } diff --git a/R/rvar-math.R b/R/rvar-math.R index 550752c..969bfe5 100755 --- a/R/rvar-math.R +++ b/R/rvar-math.R @@ -15,7 +15,7 @@ Ops.rvar <- function(e1, e2) { .Ops.rvar <- function(f, e1, e2, preserve_dims = FALSE) { c(e1, e2) %<-% conform_rvar_nchains(list(e1, e2)) - .log_weights <- weights2_common(log_weights(e1), log_weights(e2)) + c(e1, e2) %<-% conform_rvar_weights(list(e1, e2)) draws_x <- draws_of(e1) draws_y <- draws_of(e2) @@ -48,7 +48,7 @@ Ops.rvar <- function(e1, e2) { draws <- while_preserving_dims(function(...) draws, dim_source) } - new_rvar(draws, .nchains = nchains(e1), .log_weights = .log_weights) + new_rvar(draws, .nchains = nchains(e1), .log_weights = log_weights(e1)) } #' @export @@ -189,7 +189,7 @@ Math.rvar_factor <- function(x, ...) { } # conform the draws dimension in both variables - c(x, y) %<-% conform_rvar_ndraws_nchains(list(x, y)) + c(x, y) %<-% conform_rvar_nchains_ndraws_weights(list(x, y)) # drop the names of the dimensions (mul.tensor gets uppity if dimension names # are duplicated, but we don't care about that) diff --git a/R/rvar-rfun.R b/R/rvar-rfun.R index ac6512b..79210e6 100755 --- a/R/rvar-rfun.R +++ b/R/rvar-rfun.R @@ -232,7 +232,7 @@ rvar_rng <- function(.f, n, ..., ndraws = NULL) { args <- list(...) is_rvar_arg <- vapply(args, is_rvar, logical(1)) - rvar_args <- conform_rvar_ndraws_nchains(args[is_rvar_arg]) + rvar_args <- conform_rvar_nchains_ndraws_weights(args[is_rvar_arg]) if (length(rvar_args) < 1) { nchains <- 1 diff --git a/R/rvar-slice.R b/R/rvar-slice.R index e33a982..fcb8bd9 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -159,7 +159,7 @@ NULL #' @export `[[<-.rvar` <- function(x, i, ..., value) { value <- vec_cast(value, x) - c(x, value) %<-% conform_rvar_ndraws_nchains(list(x, value)) + c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) value <- check_rvar_dims_first(value, new_rvar(0)) index <- check_rvar_yank_index(x, i, ...) @@ -220,7 +220,7 @@ NULL # this kind of indexing must ignore chains nchains_rvar(x) <- 1L nchains_rvar(i) <- 1L - c(x, i) %<-% conform_rvar_ndraws(list(x, i)) + c(x, i) %<-% conform_rvar_ndraws_weights(list(x, i)) index <- list() draws_index <- list(draws_of(i)) } else { @@ -316,7 +316,7 @@ NULL # for the purposes of this kind of assignment, we check draws only, not chains, # as chain information is irrelevant when subsetting by draw - c(x, i) %<-% conform_rvar_ndraws(list(x, i)) + c(x, i) %<-% conform_rvar_ndraws_weights(list(x, i)) draws_index <- draws_of(i) # necessary number of draws in `value` is determined by whether or not @@ -325,7 +325,7 @@ NULL draws_of(value) <- broadcast_array(draws_of(value), c(value_ndraws, dim(x)), broadcast_scalars = FALSE) i <- missing_arg() } else { - c(x, value) %<-% conform_rvar_ndraws_nchains(list(x, value)) + c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) draws_index <- missing_arg() } @@ -380,7 +380,7 @@ scalar_numeric_rvar_to_index <- function(i_rvar, x, ...) { if (!is.numeric(draws_of(i_rvar)) || length(i_rvar) != 1) { stop_no_call("`x[[i]]` for rvars `x` and `i` is only supported when `i` is a scalar numeric rvar.") } - out <- conform_rvar_ndraws_nchains(list(i_rvar, x, ...)) + out <- conform_rvar_nchains_ndraws_weights(list(i_rvar, x, ...)) c(i_rvar, x) %<-% out[1:2] out[[1]] <- matrix_to_index(cbind(seq_len(ndraws(x)), draws_of(i_rvar)), c(ndraws(x), length(x))) out From 476cdc24db5cb538fda79e6ba960093466adecd9 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Thu, 18 Jan 2024 00:17:48 -0600 Subject: [PATCH 15/43] prevent binding weighted and unweighted non-constant rvars --- R/rvar-.R | 44 +++++++++++++++++++++++++-------- R/rvar-bind.R | 2 +- tests/testthat/test-rvar-bind.R | 9 +++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index fc7b50f..0a16d33 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -604,10 +604,14 @@ nchains2_common <- function(nchains_x, nchains_y) { } # find common weights for two rvars -weights2_common <- function(weights_x, weights_y) { - if (is.null(weights_x)) { +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. +#' @noRd +weights2_common <- function(weights_x, weights_y, promote_unweighted = TRUE) { + if (promote_unweighted && is.null(weights_x)) { weights_y - } else if (is.null(weights_y)) { + } else if (promote_unweighted && is.null(weights_y)) { weights_x } else if (identical(weights_x, weights_y)) { weights_x @@ -652,10 +656,20 @@ conform_rvar_nchains <- function(rvars) { #' given a list of rvars, conform their their weights #' so they can be used together (or throw an error if they can't be) #' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. #' @returns modified list of rvars all having the same weights. #' @noRd -conform_rvar_weights <- function(rvars) { - .log_weights = Reduce(weights2_common, lapply(rvars, log_weights)) +conform_rvar_weights <- function(rvars, promote_unweighted = TRUE) { + # only check rvars that are not constants --- constant rvars can + # always take on the weights of others + not_constant <- vapply(rvars, ndraws, numeric(1)) > 1 + weights_list <- lapply(rvars[not_constant], log_weights) + .log_weights <- Reduce( + function(...) weights2_common(..., promote_unweighted = promote_unweighted), + weights_list + ) for (i in seq_along(rvars)) { log_weights_rvar(rvars[[i]]) <- .log_weights @@ -667,28 +681,38 @@ conform_rvar_weights <- function(rvars) { #' given a list of rvars, conform their number of draws and their weights #' so they can be used together (or throw an error if they can't be) #' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. #' @returns modified list of rvars all having the same number of draws and the #' same weights. #' @noRd -conform_rvar_ndraws_weights <- function(rvars) { - .ndraws = Reduce(ndraws2_common, lapply(rvars, ndraws)) +conform_rvar_ndraws_weights <- function(rvars, promote_unweighted = TRUE) { + # must conform weights before ndraws so that constants are handled properly + rvars <- conform_rvar_weights(rvars, promote_unweighted = promote_unweighted) + + .ndraws <- Reduce(ndraws2_common, lapply(rvars, ndraws)) for (i in seq_along(rvars)) { rvars[[i]] <- broadcast_draws(rvars[[i]], .ndraws) } - conform_rvar_weights(rvars) + rvars } #' given a list of rvars, conform their number of draws, number of chains, and #' their weights so they can be used together (or throw an error if they can't be) #' @param rvars a list of rvars +#' @param promote_unweighted should unweighted rvars be promoted to have the +#' weights of weighted rvars they are combined with? typically `FALSE` for +#' binding operations and `TRUE` for math operations. #' @returns modified list of rvars all having the same number of chains, same #' number of draws, and the same weights. #' @noRd -conform_rvar_nchains_ndraws_weights <- function(rvars) { +conform_rvar_nchains_ndraws_weights <- function(rvars, promote_unweighted = TRUE) { + # must conform nchains before ndraws so that constants are handled properly rvars <- conform_rvar_nchains(rvars) - rvars <- conform_rvar_ndraws_weights(rvars) + rvars <- conform_rvar_ndraws_weights(rvars, promote_unweighted = promote_unweighted) rvars } diff --git a/R/rvar-bind.R b/R/rvar-bind.R index 906a1b4..eece5d9 100755 --- a/R/rvar-bind.R +++ b/R/rvar-bind.R @@ -92,7 +92,7 @@ broadcast_and_bind_rvars.rvar <- function(x, y, axis = 1) { # conform nchains and weights # (don't need to do draws here since that's part of the broadcast below) c(x, y) %<-% conform_rvar_nchains(list(x, y)) - c(x, y) %<-% conform_rvar_weights(list(x, y)) + c(x, y) %<-% conform_rvar_weights(list(x, y), promote_unweighted = FALSE) # broadcast each array to the desired dimensions # (except along the axis we are binding along) diff --git a/tests/testthat/test-rvar-bind.R b/tests/testthat/test-rvar-bind.R index 6f580ed..b10cc81 100755 --- a/tests/testthat/test-rvar-bind.R +++ b/tests/testthat/test-rvar-bind.R @@ -143,6 +143,15 @@ test_that("c works on rvar_ordered", { expect_equal(c(x_col, y), x_y) }) +test_that("binding weighted and unweighted rvars works", { + x = rvar(1:10) + xw = rvar(1:10, log_weights = 1:10) + + # binding weighted to unweighted constant is okay + expect_equal(c(xw, 1), rvar(cbind(1:10, 1), log_weights = 1:10)) + # but binding weights to unweighted non-constant is not okay + expect_error(c(xw, x), "different log weights") +}) # cbind.rvar -------------------------------------------------------------- From 534d7fb4d8cc10dff2e17ca9a582043aa7c6b022 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Thu, 18 Jan 2024 13:45:56 -0600 Subject: [PATCH 16/43] test coverage improvements --- R/rvar-.R | 24 ++++++---------- R/rvar-cast.R | 4 ++- R/rvar-slice.R | 2 +- tests/testthat/test-rvar-cast.R | 12 ++++++++ tests/testthat/test-rvar-print.R | 39 ++++++++++++++++++++++++++ tests/testthat/test-weight_draws.R | 45 ++++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 0a16d33..444d6b7 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -716,24 +716,16 @@ conform_rvar_nchains_ndraws_weights <- function(rvars, promote_unweighted = TRUE rvars } -# Check that the first rvar can be conformed to the dimensions of the second, -# ignoring 1s -check_rvar_dims_first <- function(x, y) { - x_dim <- dim(x) - x_dim_dropped <- as.integer(x_dim[x_dim != 1]) - y_dim <- dim(y) - y_dim_dropped <- as.integer(y_dim[y_dim != 1]) - - if (length(x_dim_dropped) == 0) { - # x can be treated as scalar, do so - dim(x) <- rep(1, length(dim(y))) - } else if (identical(x_dim_dropped, y_dim_dropped)) { - dim(x) <- dim(y) - } else { - stop_no_call("Cannot assign an rvar with dimension ", paste0(x_dim, collapse = ","), - " to an rvar with dimension ", paste0(y_dim, collapse = ",")) +#' Check that an rvar is a scalar (length 1) +#' @param x rvar to check +#' @returns x with `dim(x) == 1`, or throws an error if `x` is not scalar. +#' @noRd +check_rvar_is_scalar <- function(x) { + if (length(x) != 1) { + stop_no_call("Cannot insert an rvar with length != 1 into another rvar using `[[`") } + dim(x) <- 1 x } diff --git a/R/rvar-cast.R b/R/rvar-cast.R index 43d6197..8a123d2 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -245,7 +245,9 @@ vec_restore.rvar <- function(x, ...) { # find runs where the same underlying draws are in the proxy different_draws_from_previous <- vapply(seq_along(x)[-1], FUN.VALUE = logical(1), function(i) { - !identical(x[[i]]$draws, x[[i - 1]]$draws) || !identical(x[[i]]$nchains, x[[i - 1]]$nchains) + !identical(x[[i]]$draws, x[[i - 1]]$draws) || + !identical(x[[i]]$nchains, x[[i - 1]]$nchains) || + !identical(x[[i]]$log_weights, x[[i - 1]]$log_weights) }) draws_groups <- cumsum(c(TRUE, different_draws_from_previous)) diff --git a/R/rvar-slice.R b/R/rvar-slice.R index fcb8bd9..4ea47bf 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -160,7 +160,7 @@ NULL `[[<-.rvar` <- function(x, i, ..., value) { value <- vec_cast(value, x) c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) - value <- check_rvar_dims_first(value, new_rvar(0)) + value <- check_rvar_is_scalar(value) index <- check_rvar_yank_index(x, i, ...) if (length(index) == 1) { diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index f411c56..701f295 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -202,6 +202,18 @@ test_that("casting to/from rvar/distribution objects works", { expect_error(vctrs::vec_cast(x_mv, null_dist)) }) +test_that("vec_c works with rvar and distributions", { + x_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4)) + y_dist <- distributional::dist_sample(list(c = 5:6, d = 7:8)) + xy_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4, c = 5:6, d = 7:8)) + x_rvar <- rvar(matrix(c(1:4), ncol = 2, dimnames = list(NULL, c("a","b")))) + y_rvar <- rvar(matrix(c(5:8), ncol = 2, dimnames = list(NULL, c("c","d")))) + xy_rvar <- rvar(matrix(c(1:8), ncol = 4, dimnames = list(NULL, c("a","b","c","d")))) + + expect_equal(vctrs::vec_c(x_dist, y_rvar), xy_dist) + expect_equal(vctrs::vec_c(x_rvar, y_dist), xy_rvar) +}) + # type predicates --------------------------------------------------------- diff --git a/tests/testthat/test-rvar-print.R b/tests/testthat/test-rvar-print.R index d2a8c49..63bad5b 100755 --- a/tests/testthat/test-rvar-print.R +++ b/tests/testthat/test-rvar-print.R @@ -80,6 +80,14 @@ test_that("print() works", { regexp = "12 levels: a b c d e f g h i j k l", all = FALSE ) + + x_long <- rvar_factor(combn(letters, 2, paste, collapse = "")) + out <- capture.output(print(x_long, color = FALSE, width = 50)) + expect_match( + out, + regexp = "325 levels: ab ac ad ae af ag ah ai aj \\.\\.\\. yz", + all = FALSE + ) }) test_that("print() works", { @@ -255,9 +263,40 @@ test_that("str() works", { ) }) +test_that("str() works", { + x <- rvar(1:100, log_weights = 2:101) + + expect_output(str(weight_draws(rvar(), 1)), + " weighted rvar<1>\\[0\\] " + ) + out <- capture.output(str(x)) + expect_match( + out, + regexp = " weighted rvar<100>\\[1\\] 99 . 0.96", + all = FALSE + ) + expect_match( + out, + regexp = " - log_weights\\(\\*\\)= int \\[1:100\\] 2 3 4 5", + all = FALSE + ) +}) + # other ------------------------------------------------------------------- +test_that("tibble printing works", { + skip_on_cran() + + x <- rvar(1:10) + out <- capture.output(print(tibble::tibble(x))) + expect_match( + out, + regexp = " 5.5 . 3", + all = FALSE + ) +}) + test_that("glimpse on rvar works", { skip_on_cran() x_vec <- rvar(array(1:24, dim = c(6,4))) diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 0230a83..1602396 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -9,6 +9,9 @@ test_that("weight_draws works on draws_matrix", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_array", { @@ -22,6 +25,9 @@ test_that("weight_draws works on draws_array", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_df", { @@ -35,6 +41,9 @@ test_that("weight_draws works on draws_df", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_list", { @@ -48,6 +57,9 @@ test_that("weight_draws works on draws_list", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_rvars", { @@ -61,6 +73,9 @@ test_that("weight_draws works on draws_rvars", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weights are propagated to variables in draws_rvars", { @@ -83,6 +98,26 @@ test_that("weights are propagated to variables in draws_rvars", { ) }) +# removing weights works -------------------------------------------------- + +test_that("weights can be removed", { + x <- list( + matrix = as_draws_matrix(example_draws()), + array = as_draws_array(example_draws()), + df = as_draws_df(example_draws()), + list = as_draws_list(example_draws()), + rvars = as_draws_rvars(example_draws()), + rvar = as_draws_rvars(example_draws())$mu + ) + + weights <- rexp(ndraws(example_draws())) + x_weighted <- lapply(x, weight_draws, weights) + + for (type in names(x)) { + expect_equal(weight_draws(x_weighted[[!!type]], NULL), x[[!!type]]) + } +}) + # conversion preserves weights -------------------------------------------- test_that("conversion between formats preserves weights", { @@ -118,3 +153,13 @@ test_that("pareto smoothing smooths weights in weight_draws", { smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE) expect_false(all(weights(weighted) == weights(smoothed))) }) + +# weights must match draws ------------------------------------------------ + +test_that("weights must match draws", { + x <- example_draws() + types <- list(as_draws_matrix, as_draws_array, as_draws_df, as_draws_list, as_draws_rvars) + for (type in types) { + expect_error(weight_draws((!!type)(x), 1), "weights must match .* draws") + } +}) From b0a778a71d6693122692b88a543870a6d4c54919 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Thu, 18 Jan 2024 20:27:00 -0600 Subject: [PATCH 17/43] density, cdf, quantiles for weighted rvar --- R/rvar-dist.R | 25 ++++--- R/weighted.R | 118 ++++++++++++++++++++++++++++++++ tests/testthat/test-rvar-dist.R | 52 ++++++++++++-- 3 files changed, 176 insertions(+), 19 deletions(-) create mode 100644 R/weighted.R diff --git a/R/rvar-dist.R b/R/rvar-dist.R index ee57c16..39bf2f7 100755 --- a/R/rvar-dist.R +++ b/R/rvar-dist.R @@ -40,8 +40,9 @@ #' @name rvar-dist #' @export density.rvar <- function(x, at, ...) { + weights <- weights(x) summarise_rvar_by_element(x, function(draws) { - d <- density(draws, cut = 0, ...) + d <- density(draws, weights = weights, cut = 0, ...) f <- approxfun(d$x, d$y, yleft = 0, yright = 0) f(at) }) @@ -50,11 +51,12 @@ density.rvar <- function(x, at, ...) { #' @rdname rvar-dist #' @export density.rvar_factor <- function(x, at, ...) { + weights <- weights(x) at <- as.numeric(factor(at, levels = levels(x))) - nbins <- nlevels(x) summarise_rvar_by_element(x, function(draws) { - props <- prop.table(tabulate(draws, nbins = nbins))[at] + tab <- weighted_simple_table(draws, weights) + props <- prop.table(tab$count)[at] props }) } @@ -66,8 +68,9 @@ distributional::cdf #' @rdname rvar-dist #' @export cdf.rvar <- function(x, q, ...) { + weights <- weights(x) summarise_rvar_by_element(x, function(draws) { - ecdf(draws)(q) + weighted_ecdf(draws, weights)(q) }) } @@ -76,7 +79,7 @@ cdf.rvar <- function(x, q, ...) { cdf.rvar_factor <- function(x, q, ...) { # CDF is not defined for unordered distributions # generate an all-NA array of the appropriate shape - out <- rep_len(NA, length(x) * length(q)) + out <- rep_len(NA_real_, length(x) * length(q)) if (length(x) > 1) dim(out) <- c(length(q), dim(x)) out } @@ -91,14 +94,10 @@ cdf.rvar_ordered <- function(x, q, ...) { #' @rdname rvar-dist #' @export quantile.rvar <- function(x, probs, ...) { - summarise_rvar_by_element_via_matrix(x, - "quantile", - function(draws) { - t(matrixStats::colQuantiles(draws, probs = probs, useNames = TRUE, ...)) - }, - .extra_dim = length(probs), - .extra_dimnames = list(NULL) - ) + weights <- weights(x) + summarise_rvar_by_element(x, function(draws) { + weighted_quantile(draws, probs = probs, weights = weights, ...) + }) } #' @rdname rvar-dist diff --git a/R/weighted.R b/R/weighted.R new file mode 100644 index 0000000..ec16ade --- /dev/null +++ b/R/weighted.R @@ -0,0 +1,118 @@ +# weighted distribution functions -------------------------------------------- + +#' Weighted version of [stats::ecdf()]. +#' Based on ggdist::weighted_ecdf(). +#' @noRd +weighted_ecdf = function(x, weights = NULL) { + n = length(x) + if (n < 1) stop("Need at least 1 or more values to calculate an ECDF") + + weights = if (is.null(weights)) rep(1, n) else weights + + #sort only if necessary + if (is.unsorted(x)) { + sort_order = order(x) + x = x[sort_order] + weights = weights[sort_order] + } + + # calculate weighted cumulative probabilities + p = cumsum(weights) + p = p/p[n] + + approxfun(x, p, yleft = 0, yright = 1, ties = "ordered", method = "constant") +} + +#' Weighted version of [stats::quantile()]. +#' Based on ggdist::weighted_quantile(). +#' @noRd +weighted_quantile = function(x, + probs = seq(0, 1, 0.25), + weights = NULL, + na.rm = FALSE, + type = 7 +) { + weighted_quantile_fun( + x, + weights = weights, + na.rm = na.rm, + type = type + )(probs) +} + +#' @rdname weighted_quantile +#' @export +weighted_quantile_fun = function(x, weights = NULL, na.rm = FALSE, type = 7) { + na.rm <- as_one_logical(na.rm) + if (!isTRUE(type %in% 1:9)) { + stop0("Quantile type `", deparse0(type), "` is invalid. It must be in 1:9.") + } + + if (na.rm) { + keep = !is.na(x) & !is.na(weights) + x = x[keep] + weights = weights[keep] + } + + # determine weights + weights = weights %||% rep(1, length(x)) + non_zero = weights != 0 + x = x[non_zero] + weights = weights[non_zero] + weights = weights / sum(weights) + + # if there is only 0 or 1 x values, we don't need the weighted version (and + # we couldn't calculate it anyway as we need > 2 points for the interpolation) + if (length(x) <= 1) { + return(function(p) quantile(x, p, names = FALSE)) + } + + # sort values if necessary + if (is.unsorted(x)) { + x_order = order(x) + x = x[x_order] + weights = weights[x_order] + } + + # calculate the weighted CDF + F_k = cumsum(weights) + + # generate the function for the approximate inverse CDF + if (1 <= type && type <= 3) { + # discontinuous quantiles + switch(type, + # type 1 + stepfun(F_k, c(x, x[length(x)]), right = TRUE), + # type 2 + { + x_over_2 = c(x, x[length(x)])/2 + inverse_cdf_type2_left = stepfun(F_k, x_over_2, right = FALSE) + inverse_cdf_type2_right = stepfun(F_k, x_over_2, right = TRUE) + function(x) inverse_cdf_type2_left(x) + inverse_cdf_type2_right(x) + }, + # type 3 + stepfun(F_k - weights/2, c(x[[1]], x), right = TRUE) + ) + } else { + # Continuous quantiles. These are based on the definition of p_k as described + # in the documentation of `quantile()`. The trick to re-writing those formulas + # (which use `n` and `k`) for the weighted case is that `k` = `F_k * n` and + # `1/n` = `weight_k`. Using these two facts, we can express the formulas for + # `p_k` without using `n` or `k`, which don't really apply in the weighted case. + p_k = switch(type - 3, + # type 4 + F_k, + # type 5 + F_k - weights/2, + # type 6 + F_k / (1 + weights), + # type 7 + (F_k - weights) / (1 - weights), + # type 8 + (F_k - weights/3) / (1 + weights/3), + # type 9 + (F_k - weights*3/8) / (1 + weights/4) + ) + approxfun(p_k, x, rule = 2, ties = "ordered") + } +} diff --git a/tests/testthat/test-rvar-dist.R b/tests/testthat/test-rvar-dist.R index b4d7a65..e3d44b0 100755 --- a/tests/testthat/test-rvar-dist.R +++ b/tests/testthat/test-rvar-dist.R @@ -33,7 +33,7 @@ test_that("distributional functions work on an rvar array", { q21 <- quantile(4:6, p) q12 <- quantile(7:9, p) q22 <- quantile(10:12, p) - x_quantiles <- array(c(q11, q21, q12, q22), dim = c(9, 2, 2), dimnames = list(NULL)) + x_quantiles <- array(c(q11, q21, q12, q22), dim = c(9, 2, 2)) expect_equal(quantile(x, p), x_quantiles) }) @@ -43,14 +43,14 @@ test_that("distributional functions work on an rvar_factor", { x <- rvar_factor(x_letters, levels = letters[1:5]) x2 <- c(rvar_factor(letters), rvar_factor(letters)) - expect_equal(density(x, letters[1:6]), c(0, .3, .2, .4, .1, NA)) + expect_equal(density(x, letters[1:6]), c(0, .3, .2, .4, .1, NA_real_)) expect_equal(density(x2, letters[1:3]), array(rep(1/26, 6), dim = c(3,2))) - expect_equal(cdf(x, letters[1:5]), c(NA, NA, NA, NA, NA)) - expect_equal(cdf(x2, letters[1:3]), array(rep(NA, 6), dim = c(3,2))) + expect_equal(cdf(x, letters[1:5]), c(NA_real_, NA_real_, NA_real_, NA_real_, NA_real_)) + expect_equal(cdf(x2, letters[1:3]), array(rep(NA_real_, 6), dim = c(3,2))) - expect_equal(quantile(x, 1:4/4), c(NA, NA, NA, NA)) - expect_equal(quantile(x2, 1:3/3), array(rep(NA, 6), dim = c(3,2))) + expect_equal(quantile(x, 1:4/4), c(NA_real_, NA_real_, NA_real_, NA_real_)) + expect_equal(quantile(x2, 1:3/3), array(rep(NA_real_, 6), dim = c(3,2))) }) test_that("distributional functions work on an rvar_ordered", { @@ -64,3 +64,43 @@ test_that("distributional functions work on an rvar_ordered", { expect_equal(quantile(x, c(.3, .5, .9, 1)), letters[2:5]) }) + +# weighted rvar ----------------------------------------------------------- + +test_that("weighted rvar works", { + x1_draws = qnorm(ppoints(10)) + x2_draws = qnorm(ppoints(10), 5) + w1 = rep(1, 10) + w2 = rep(2, 10) + w3 = rep(0, 10) + x = rvar(c(x1_draws, x2_draws, rep(10, 10)), log_weights = log(c(w1, w2, w3))) + + expect_equal( + density(x, 0:9, bw = 2.25), + density(draws_of(x), weights = weights(x), bw = 2.25, from = 0, to = 9, n = 10)$y, + tolerance = 1e-4 + ) + expect_equal(cdf(x, 0:9), ecdf(x1_draws)(0:9)/3 + ecdf(x2_draws)(0:9)*2/3) + expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 1), c(x1_draws, x2_draws)) + expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 4), c(x1_draws, x2_draws)) +}) + +test_that("weighted rvar_factor works", { + x = rvar_factor(c("b", "g", "f", "g"), levels = letters, log_weights = log(c(1/2, 1/6, 1/6, 1/6))) + + expect_equal(density(x, letters), c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19))) + expect_equal(cdf(x, letters), rep(NA_real_, 26)) + expect_equal(quantile(x, c(0.2, 0.8)), rep(NA_real_, 2)) +}) + +test_that("weighted rvar_ordered works", { + x = rvar_ordered(c("b", "g", "f", "g"), levels = letters, log_weights = log(c(1/2, 1/6, 1/6, 1/6))) + + expect_equal(density(x, letters), c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19))) + expect_equal(cdf(x, letters), cumsum(c(0, 1/2, 0, 0, 0, 1/6, 1/3, rep(0, 19)))) + expect_equal(quantile(x, c(0.2, 0.6, 0.8)), c("b", "f", "g")) + + xl = weight_draws(rvar_ordered(letters), 1:26) + expect_equal(quantile(xl, cdf(xl, letters) - .Machine$double.eps), letters) +}) + From f282f42e7ebd238fa8c368d2a66d83d143fd3811 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Sun, 21 Jan 2024 22:42:32 -0600 Subject: [PATCH 18/43] use toString instead of paste(collapse = ", ") --- R/rvar-.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 444d6b7..5590a2c 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -618,8 +618,8 @@ weights2_common <- function(weights_x, weights_y, promote_unweighted = TRUE) { } else { stop_no_call( "Random variables have different log weights and cannot be used together:\n", - "<", vctrs::vec_ptype_abbr(weights_x), "> ", paste(utils::head(weights_x, 5), collapse = ", "), " ...\n", - "<", vctrs::vec_ptype_abbr(weights_y), "> ", paste(utils::head(weights_y, 5), collapse = ", "), " ..." + "<", vctrs::vec_ptype_abbr(weights_x), "> ", toString(weights_x, width = 60), "\n", + "<", vctrs::vec_ptype_abbr(weights_y), "> ", toString(weights_y, width = 60) ) } } From 83439f4f84a2b6a20146a215c88b55cf8d1bde93 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 2 Feb 2024 18:22:24 -0600 Subject: [PATCH 19/43] add weights to rvar vignette --- vignettes/rvar.Rmd | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/vignettes/rvar.Rmd b/vignettes/rvar.Rmd index fe61835..fbf3abb 100755 --- a/vignettes/rvar.Rmd +++ b/vignettes/rvar.Rmd @@ -566,6 +566,35 @@ x This approach is also nice because it generalizes easily to more than two components. +## Weights + +Weighted `rvar`s can be created by passing log weights to the `log_weights` +parameter of `rvar()`, by using the `weight_draws()` function (as with `draws` +objects), or by converting a weighted `draws` object to a `draws_rvars` object. +Functions of `rvar`s, such as `mean()`, `sd()`, etc, support weights as +appropriate. + +For example, we can create an `rvar` that is a mixture of draws from +Normal(0,1) and Normal(5,1) distributions: + +```{r rvar_weighted} +x <- rvar(c(rnorm(10000, mean = c(0,5)))) +x +``` + +By default the mean is about 2.5, as the components with mean 0 and mean 5 +are weighted equally. However, if we weight the component with mean 5 twice +as much, then the summary display will show the appropriate weighted mean: + +```{r weighted_mean} +x <- weight_draws(x, rep(c(1,2), 5000)) +x +``` + +The latest version of [ggdist](https://mjskay.github.io/ggdist/) also supports +weighted `rvar`s, and will calculate histograms, densities, point summaries, and +intervals of `rvar`s correctly, accounting for weights. + ## Applying functions over `rvar`s The `rvar` data type supplies an implementation of `as.list()`, which should give From 22b9691d1cb93f3b42e8623b67bb8077b116f847 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Fri, 2 Feb 2024 19:37:16 -0600 Subject: [PATCH 20/43] test coverage improvements --- R/convergence.R | 16 +++++++++------- R/weighted.R | 15 +++++++++------ tests/testthat/test-convergence.R | 4 ++++ tests/testthat/test-rvar-dist.R | 16 ++++++++++++++++ tests/testthat/test-summarise_draws.R | 8 ++++---- 5 files changed, 42 insertions(+), 17 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index cf895f0..bbb8b12 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -541,12 +541,9 @@ quantile2.default <- function( ) { names <- as_one_logical(names) na.rm <- as_one_logical(na.rm) - if (!na.rm && anyNA(x)) { - # quantile itself doesn't handle this case (#110) - out <- rep(NA_real_, length(probs)) - } else { - out <- quantile(x, probs = probs, na.rm = na.rm, ...) - } + + out <- weighted_quantile(x, probs = probs, na.rm = na.rm, ...) + if (names) { names(out) <- paste0("q", probs * 100) } else { @@ -560,7 +557,12 @@ quantile2.default <- function( quantile2.rvar <- function( x, probs = c(0.05, 0.95), na.rm = FALSE, names = TRUE, ... ) { - summarise_rvar_by_element_with_chains(x, quantile2, probs, na.rm, names, ...) + weights <- weights(x) + summarise_rvar_by_element(x, function(draws) { + quantile2( + draws, probs = probs, weights = weights, na.rm = na.rm, names = names, ... + ) + }) } # internal ---------------------------------------------------------------- diff --git a/R/weighted.R b/R/weighted.R index ec16ade..f444ef7 100644 --- a/R/weighted.R +++ b/R/weighted.R @@ -30,28 +30,31 @@ weighted_quantile = function(x, probs = seq(0, 1, 0.25), weights = NULL, na.rm = FALSE, - type = 7 + type = 7, + ... ) { weighted_quantile_fun( x, weights = weights, na.rm = na.rm, - type = type + type = type, + ... )(probs) } #' @rdname weighted_quantile #' @export -weighted_quantile_fun = function(x, weights = NULL, na.rm = FALSE, type = 7) { +weighted_quantile_fun = function(x, weights = NULL, na.rm = FALSE, type = 7, ...) { na.rm <- as_one_logical(na.rm) - if (!isTRUE(type %in% 1:9)) { - stop0("Quantile type `", deparse0(type), "` is invalid. It must be in 1:9.") - } + assert_number(type, lower = 1, upper = 9) if (na.rm) { keep = !is.na(x) & !is.na(weights) x = x[keep] weights = weights[keep] + } else if (anyNA(x)) { + # quantile itself doesn't handle this case (#110) + return(function(p) rep(NA_real_, length(p))) } # determine weights diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 96194d8..4f9621a 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -145,3 +145,7 @@ test_that("autocovariance returns correct results", { ac2 <- acf(x, type = "covariance", lag.max = length(x), plot = FALSE)$acf[, 1, 1] expect_equal(ac1, ac2) }) + +test_that("NA quantile2 works", { + expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) +}) diff --git a/tests/testthat/test-rvar-dist.R b/tests/testthat/test-rvar-dist.R index e3d44b0..95f1a57 100755 --- a/tests/testthat/test-rvar-dist.R +++ b/tests/testthat/test-rvar-dist.R @@ -9,6 +9,10 @@ test_that("distributional functions work on a scalar rvar", { expect_equal(cdf(x, x_values), x_cdf) expect_equal(quantile(x, 1:4/4), quantile(x_values, 1:4/4, names = FALSE)) + + expect_equal(quantile(rvar(1:4), 0:4/4 + .Machine$double.eps, type = 1), c(1:4, 4)) + expect_equal(quantile(rvar(1:4), 0:4/4, type = 2), c(1, 1.5, 2.5, 3.5, 4)) + expect_equal(quantile(rvar(1:4), 0:4/4 + .Machine$double.eps, type = 3), c(1, 1:4)) }) test_that("distributional functions work on an rvar array", { @@ -83,6 +87,18 @@ test_that("weighted rvar works", { expect_equal(cdf(x, 0:9), ecdf(x1_draws)(0:9)/3 + ecdf(x2_draws)(0:9)*2/3) expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 1), c(x1_draws, x2_draws)) expect_equal(quantile(x, cdf(x, c(x1_draws, x2_draws)), type = 4), c(x1_draws, x2_draws)) + expect_equal(unname(quantile2(x, cdf(x, c(x1_draws, x2_draws)), type = 1)), c(x1_draws, x2_draws)) + expect_equal(unname(quantile2(x, cdf(x, c(x1_draws, x2_draws)), type = 4)), c(x1_draws, x2_draws)) + + x_na <- rvar(c(draws_of(x), NA_real_), log_weights = c(log_weights(x), 1)) + expect_equal(quantile(x_na, c(0.25, 0.5, 0.75), type = 4), c(NA_real_, NA_real_, NA_real_)) + expect_equal( + quantile(x_na, c(0.25, 0.5, 0.75), type = 7, na.rm = TRUE), + quantile(x, c(0.25, 0.5, 0.75), type = 7) + ) + + expect_equal(quantile(rvar(1), 0.5), 1) + expect_equal(quantile(rvar(), 0.5), numeric()) }) test_that("weighted rvar_factor works", { diff --git a/tests/testthat/test-summarise_draws.R b/tests/testthat/test-summarise_draws.R index 80ed971..7c14650 100644 --- a/tests/testthat/test-summarise_draws.R +++ b/tests/testthat/test-summarise_draws.R @@ -100,24 +100,24 @@ test_that(paste( x <- as_draws_array(test_array) sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) dimnames(x)$variable[2] <- reserved_variables()[1] sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) n <- 1 test_array <- array(data = rnorm(1000*nc*n), dim = c(1000,nc,n)) x <- as_draws_array(test_array) sum_x <- summarise_draws(x) parsum_x <- summarise_draws(x, .cores = cores) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) dimnames(x)$variable[1] <- reserved_variables()[1] suppressWarnings(sum_x <- summarise_draws(x)) suppressWarnings(parsum_x <- summarise_draws(x, .cores = cores)) - expect_identical(sum_x, parsum_x) + expect_equal(sum_x, parsum_x) }) test_that("summarise_draws supports tibble::set_num_opts correctly", { From cb44b4a487be241233a934bd6aed546ce866065c Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 28 Feb 2024 11:39:33 +0200 Subject: [PATCH 21/43] start on weighted convergence --- DESCRIPTION | 2 +- R/convergence.R | 30 +++++++++++++++++++++++------- R/summarise_draws.R | 2 +- man/posterior-package.Rd | 30 ++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6efe863..ce46243 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,5 +56,5 @@ LazyData: false URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/ BugReports: https://github.com/stan-dev/posterior/issues Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.0 VignetteBuilder: knitr diff --git a/R/convergence.R b/R/convergence.R index cf895f0..be5b1b1 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -182,7 +182,7 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { +ess_bulk.default <- function(x, weights = NULL, ...) { .ess(z_scale(.split_chains(x))) } @@ -319,14 +319,19 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) } #' Effective sample size for the standard deviation @@ -449,14 +454,18 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + .mcse_mean_weighted(x, weights, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) } #' Monte Carlo standard error for the standard deviation @@ -782,6 +791,12 @@ fold_draws <- function(x) { ess } +.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 6 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { @@ -801,3 +816,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } + diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755..05acd83 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary diff --git a/man/posterior-package.Rd b/man/posterior-package.Rd index 0f933f8..5396bf1 100644 --- a/man/posterior-package.Rd +++ b/man/posterior-package.Rd @@ -73,3 +73,33 @@ causes a warning can be controlled by this option. } } +\seealso{ +Useful links: +\itemize{ + \item \url{https://mc-stan.org/posterior/} + \item \url{https://discourse.mc-stan.org/} + \item Report bugs at \url{https://github.com/stan-dev/posterior/issues} +} + +} +\author{ +\strong{Maintainer}: Paul-Christian Bürkner \email{paul.buerkner@gmail.com} + +Authors: +\itemize{ + \item Jonah Gabry \email{jsg2201@columbia.edu} + \item Matthew Kay \email{mjskay@northwestern.edu} + \item Aki Vehtari \email{Aki.Vehtari@aalto.fi} +} + +Other contributors: +\itemize{ + \item Måns Magnusson [contributor] + \item Rok Češnovar [contributor] + \item Ben Lambert [contributor] + \item Ozan Adıgüzel [contributor] + \item Jacob Socolar [contributor] + \item Noa Kallioinen [contributor] +} + +} From 165fb4fd35fafdbe678a83138e2ec3365a1e52fc Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 7 Mar 2024 11:59:53 +0200 Subject: [PATCH 22/43] improvements to weighted ess, mcse --- R/convergence.R | 108 ++++++++++++++++++++++++++++++++++++---------- R/pareto_smooth.R | 5 --- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 270020b..c3b6e34 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -183,13 +183,18 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export ess_bulk.default <- function(x, weights = NULL, ...) { - .ess(z_scale(.split_chains(x))) + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + .ess_weighted(x, weights, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -220,16 +225,17 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles @@ -258,13 +264,17 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +284,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -297,6 +308,19 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- x <= weighted_quantile(x, prob, weights) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -324,14 +348,15 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + .ess_weighted(x, weights, ...) } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -358,14 +383,19 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x-mean(x)))) + } else { + .ess_weighted(abs(x - mean(x)), weights = weights, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } #' Monte Carlo standard error for quantiles @@ -394,23 +424,29 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -431,6 +467,18 @@ mcse_median <- function(x, ...) { as.vector((th2 - th1) / 2) } +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + ssims <- sort(x) + S <- length(ssims) + th1 <- ssims[max(floor(a[1] * S), 1)] + th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) +} + + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a @@ -458,14 +506,15 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_mean_weighted(x, weights, ...) + .mcse_weighted(x, weights, ...) } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -514,7 +563,8 @@ mcse_sd.default <- function(x, ...) { #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -793,10 +843,22 @@ fold_draws <- function(x) { ess } -.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, r_eff = 1, ...) { # Vehtari et al. 2022 equation 6 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + x <- as.numeric(x) + + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + +.ess_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 7 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + mean((x - weighted_mean)^2) / mcse } # should NA be returned by a convergence diagnostic? diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 6aa9ce0..60c45cd 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -66,7 +66,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { } else { # take the max of khat for x * weights and khat for weights - weights_diags <- pareto_khat( weights(x, log = TRUE), are_log_weights = TRUE, @@ -82,10 +81,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { ... ) - print(weights_diags) - - print(product_diags) - dim(product_diags) <- dim(product_diags) %||% length(product_diags) margins <- seq_along(dim(product_diags)) From 8e1e0dfdab3aa180283d21b21f30543ef373eb89 Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Mar 2024 12:38:05 +0200 Subject: [PATCH 23/43] tweak weighted diagnostics --- R/convergence.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index c3b6e34..65e5b0b 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -850,7 +850,9 @@ fold_draws <- function(x) { weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + out } .ess_weighted <- function(x, weights, r_eff = 1, ...) { From a9ba2b6b5f5ec60b364560355538969d19167caa Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 13 Mar 2024 11:54:03 +0200 Subject: [PATCH 24/43] add r_eff into calculation of weighted ess and mcse --- R/convergence.R | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 65e5b0b..7a0faa4 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,8 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -345,10 +346,11 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @export ess_mean.default <- function(x, weights = NULL, ...) { - if (is.null(weights)) { + if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -506,7 +508,8 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, ...) / r_eff } } @@ -843,23 +846,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - - weighted_mean <- matrixStats::weightedMean(x, w = weights) - out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + weighted_mean <- matrixStats::weightedMean(x, w = weights) - out + (weights^2 %*% (x - c(weighted_mean))^2) } -.ess_weighted <- function(x, weights, r_eff = 1, ...) { +.ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, r_eff, ...) - + mcse <- .mcse_weighted(x, weights, ...) / r_eff + mean((x - weighted_mean)^2) / mcse } @@ -882,4 +883,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - + From ccdfb2fff751bc2c0ffc77bd5a5990f12c2431b0 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 15 Mar 2024 17:46:57 +0200 Subject: [PATCH 25/43] fixes to weighted ess and mcse --- R/convergence.R | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 7a0faa4..3e49c68 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -274,7 +274,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -309,7 +310,7 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } -.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { if (should_return_NA(x)) { return(NA_real_) } @@ -318,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights) + I <- x <= quantile(x, prob) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -389,7 +390,8 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - .ess_weighted(abs(x - mean(x)), weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -435,7 +437,8 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -846,22 +849,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, ...) { +.mcse_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - weighted_mean <- matrixStats::weightedMean(x, w = weights) - (weights^2 %*% (x - c(weighted_mean))^2) + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) } .ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, ...) / r_eff + mcse <- .mcse_weighted(x, weights, r_eff, ...) - mean((x - weighted_mean)^2) / mcse + var <- mean((x - mean(x))^2) + var / mcse^2 } # should NA be returned by a convergence diagnostic? From eed7221cbd1604ad06a30df60f411b858495f9cb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:02:09 +0200 Subject: [PATCH 26/43] use weighted quantile in weighted mcse for quantile --- R/convergence.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index 3e49c68..ccf265c 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -319,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- x <= weighted_quantile(x, prob, weights = weights) .ess_weighted(I, weights = weights, r_eff = r_eff) } From 2c70d46e3671cc5cd0ef150c6ab623a32029a104 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:51:41 +0200 Subject: [PATCH 27/43] add tests for weighted convergence measures --- R/convergence.R | 2 +- tests/testthat/test-convergence.R | 58 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index ccf265c..2c78912 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -512,7 +512,7 @@ mcse_mean.default <- function(x, weights = NULL, ...) { sd(x) / sqrt(ess_mean(x)) } else { r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - .mcse_weighted(x, weights, ...) / r_eff + .mcse_weighted(x, weights, r_eff, ...) } } diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 4f9621a..a471e75 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -149,3 +149,61 @@ test_that("autocovariance returns correct results", { test_that("NA quantile2 works", { expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) }) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + + xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) From 1079cefe5c83d6c9578d1f126a5887a1505aa344 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Tue, 19 Mar 2024 23:54:27 -0500 Subject: [PATCH 28/43] check weights must be a vector --- R/weight_draws.R | 7 ++++--- tests/testthat/test-weight_draws.R | 8 +++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/R/weight_draws.R b/R/weight_draws.R index ecaf2ea..a336e17 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -213,9 +213,10 @@ log_weights.rvar <- function(object, ...) { # validate weights and return log weights validate_weights <- function(weights, ndraws, log = FALSE, pareto_smooth = FALSE) { if (is.null(weights)) return(NULL) - checkmate::assert_numeric(weights) - checkmate::assert_flag(log) - checkmate::assert_flag(pareto_smooth) + assert_numeric(weights) + assert_atomic_vector(weights) + assert_flag(log) + assert_flag(pareto_smooth) if (length(weights) != ndraws) { stop_no_call("Number of weights must match the number of draws.") diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 1602396..cd94a7a 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -154,7 +154,7 @@ test_that("pareto smoothing smooths weights in weight_draws", { expect_false(all(weights(weighted) == weights(smoothed))) }) -# weights must match draws ------------------------------------------------ +# assertions on weights vector ------------------------------------------------ test_that("weights must match draws", { x <- example_draws() @@ -163,3 +163,9 @@ test_that("weights must match draws", { expect_error(weight_draws((!!type)(x), 1), "weights must match .* draws") } }) + +test_that("weights must be a vector, not array/matrix", { + x <- example_draws() + w <- seq_len(ndraws(x)) + expect_error(weight_draws(x, matrix(w)), "Must be.*vector.*not.*matrix") +}) From 0bfe6e658699ab1ec18aaf826181429f88a16194 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 17 Jan 2024 17:28:32 +0200 Subject: [PATCH 29/43] updating pareto functions for weighted rvars --- R/pareto_smooth.R | 107 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index a0bb848..6aa9ce0 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,57 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) { - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + return_k = TRUE, + smooth_draws = FALSE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_khat, + ... + ) + print(weights_diags) + + print(product_diags) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -149,6 +185,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +205,35 @@ pareto_diags.rvar <- function(x, ...) { khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -279,7 +346,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +397,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -444,7 +511,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -458,10 +525,10 @@ pareto_convergence_rate.rvar <- function(x, ...) { draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -617,7 +684,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -630,7 +697,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) From 691d24088d669368fa6ac1b6c0397bcca3469626 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 28 Feb 2024 11:39:33 +0200 Subject: [PATCH 30/43] start on weighted convergence --- R/convergence.R | 30 +++++++++++++++++++++++------- R/summarise_draws.R | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index bbb8b12..270020b 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -182,7 +182,7 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { +ess_bulk.default <- function(x, weights = NULL, ...) { .ess(z_scale(.split_chains(x))) } @@ -319,14 +319,19 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) } #' Effective sample size for the standard deviation @@ -449,14 +454,18 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + .mcse_mean_weighted(x, weights, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) } #' Monte Carlo standard error for the standard deviation @@ -784,6 +793,12 @@ fold_draws <- function(x) { ess } +.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 6 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { @@ -803,3 +818,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } + diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755..05acd83 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary From 74b3b71d51f95c0f66bef4eb8f697e1826700940 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 7 Mar 2024 11:59:53 +0200 Subject: [PATCH 31/43] improvements to weighted ess, mcse --- R/convergence.R | 108 ++++++++++++++++++++++++++++++++++++---------- R/pareto_smooth.R | 5 --- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 270020b..c3b6e34 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -183,13 +183,18 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export ess_bulk.default <- function(x, weights = NULL, ...) { - .ess(z_scale(.split_chains(x))) + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + .ess_weighted(x, weights, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -220,16 +225,17 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles @@ -258,13 +264,17 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +284,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -297,6 +308,19 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- x <= weighted_quantile(x, prob, weights) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -324,14 +348,15 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + .ess_weighted(x, weights, ...) } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -358,14 +383,19 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x-mean(x)))) + } else { + .ess_weighted(abs(x - mean(x)), weights = weights, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } #' Monte Carlo standard error for quantiles @@ -394,23 +424,29 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -431,6 +467,18 @@ mcse_median <- function(x, ...) { as.vector((th2 - th1) / 2) } +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + ssims <- sort(x) + S <- length(ssims) + th1 <- ssims[max(floor(a[1] * S), 1)] + th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) +} + + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a @@ -458,14 +506,15 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_mean_weighted(x, weights, ...) + .mcse_weighted(x, weights, ...) } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -514,7 +563,8 @@ mcse_sd.default <- function(x, ...) { #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -793,10 +843,22 @@ fold_draws <- function(x) { ess } -.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, r_eff = 1, ...) { # Vehtari et al. 2022 equation 6 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + x <- as.numeric(x) + + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + +.ess_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 7 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + mean((x - weighted_mean)^2) / mcse } # should NA be returned by a convergence diagnostic? diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 6aa9ce0..60c45cd 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -66,7 +66,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { } else { # take the max of khat for x * weights and khat for weights - weights_diags <- pareto_khat( weights(x, log = TRUE), are_log_weights = TRUE, @@ -82,10 +81,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { ... ) - print(weights_diags) - - print(product_diags) - dim(product_diags) <- dim(product_diags) %||% length(product_diags) margins <- seq_along(dim(product_diags)) From a1b6564e37de4486d2a5547d02b2ca9681ce07af Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Mar 2024 12:38:05 +0200 Subject: [PATCH 32/43] tweak weighted diagnostics --- R/convergence.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index c3b6e34..65e5b0b 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -850,7 +850,9 @@ fold_draws <- function(x) { weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + out } .ess_weighted <- function(x, weights, r_eff = 1, ...) { From 6ccf496b512614186589ebd00b060756108aaecf Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 13 Mar 2024 11:54:03 +0200 Subject: [PATCH 33/43] add r_eff into calculation of weighted ess and mcse --- R/convergence.R | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 65e5b0b..7a0faa4 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,8 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -345,10 +346,11 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @export ess_mean.default <- function(x, weights = NULL, ...) { - if (is.null(weights)) { + if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -506,7 +508,8 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, ...) / r_eff } } @@ -843,23 +846,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - - weighted_mean <- matrixStats::weightedMean(x, w = weights) - out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + weighted_mean <- matrixStats::weightedMean(x, w = weights) - out + (weights^2 %*% (x - c(weighted_mean))^2) } -.ess_weighted <- function(x, weights, r_eff = 1, ...) { +.ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, r_eff, ...) - + mcse <- .mcse_weighted(x, weights, ...) / r_eff + mean((x - weighted_mean)^2) / mcse } @@ -882,4 +883,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - + From e47fd921e971db2617572b7f757b456ebc043b0d Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 15 Mar 2024 17:46:57 +0200 Subject: [PATCH 34/43] fixes to weighted ess and mcse --- R/convergence.R | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 7a0faa4..3e49c68 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -274,7 +274,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -309,7 +310,7 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } -.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { if (should_return_NA(x)) { return(NA_real_) } @@ -318,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights) + I <- x <= quantile(x, prob) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -389,7 +390,8 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - .ess_weighted(abs(x - mean(x)), weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -435,7 +437,8 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -846,22 +849,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, ...) { +.mcse_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - weighted_mean <- matrixStats::weightedMean(x, w = weights) - (weights^2 %*% (x - c(weighted_mean))^2) + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) } .ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, ...) / r_eff + mcse <- .mcse_weighted(x, weights, r_eff, ...) - mean((x - weighted_mean)^2) / mcse + var <- mean((x - mean(x))^2) + var / mcse^2 } # should NA be returned by a convergence diagnostic? From 1734e3d1e3f54a8565bb321a7035c040779d2954 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:02:09 +0200 Subject: [PATCH 35/43] use weighted quantile in weighted mcse for quantile --- R/convergence.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index 3e49c68..ccf265c 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -319,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- x <= weighted_quantile(x, prob, weights = weights) .ess_weighted(I, weights = weights, r_eff = r_eff) } From 222893a7ada08ecea07877168652ad7d984daa8a Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:51:41 +0200 Subject: [PATCH 36/43] add tests for weighted convergence measures --- R/convergence.R | 2 +- tests/testthat/test-convergence.R | 58 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index ccf265c..2c78912 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -512,7 +512,7 @@ mcse_mean.default <- function(x, weights = NULL, ...) { sd(x) / sqrt(ess_mean(x)) } else { r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - .mcse_weighted(x, weights, ...) / r_eff + .mcse_weighted(x, weights, r_eff, ...) } } diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 4f9621a..a471e75 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -149,3 +149,61 @@ test_that("autocovariance returns correct results", { test_that("NA quantile2 works", { expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) }) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + + xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) From a732ad86fcb85f3cdc4d2c327fbdaadd49331fad Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 22 Mar 2024 16:45:17 +0200 Subject: [PATCH 37/43] fix r_eff calculations for each quantity --- R/convergence.R | 46 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 2c78912..d66376e 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -82,6 +82,8 @@ rhat_basic.rvar <- function(x, split = TRUE, ...) { #' recommend the improved ESS convergence diagnostics implemented in #' [ess_bulk()] and [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -104,18 +106,25 @@ ess_basic <- function(x, ...) UseMethod("ess_basic") #' @rdname ess_basic #' @export -ess_basic.default <- function(x, split = TRUE, ...) { +ess_basic.default <- function(x, split = TRUE, weights = NULL, ...) { split <- as_one_logical(split) if (split) { x <- .split_chains(x) + .ess(x) + } + + if (is.null(weights)) { + r_eff <- .ess(x) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } - .ess(x) } #' @rdname ess_basic #' @export -ess_basic.rvar <- function(x, split = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_basic, split, ...) +ess_basic.rvar <- function(x, split = TRUE, weights = weights, ...) { + + summarise_rvar_by_element_with_chains(x, ess_basic, split, weights = weights, ...) + } #' Rhat convergence diagnostic @@ -162,6 +171,8 @@ rhat.rvar <- function(x, ...) { #' rank normalized values using split chains. For the tail effective sample size #' see [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -186,7 +197,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -206,6 +217,8 @@ ess_bulk.rvar <- function(x, ...) { #' sample sizes for 5% and 95% quantiles. For the bulk effective sample #' size see [ess_bulk()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -241,8 +254,9 @@ ess_tail.rvar <- function(x, ...) { #' Effective sample sizes for quantiles #' -#' Compute effective sample size estimates for quantile estimates of a single -#' variable. +#' Compute effective sample size estimates for quantile estimates of a +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -274,8 +288,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) + out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -367,6 +381,8 @@ ess_mean.rvar <- function(x, ...) { #' Compute an effective sample size estimate for the standard deviation (SD) #' estimate of a single variable. This is defined as the effective sample size #' estimate for the absolute deviation from mean. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -390,7 +406,7 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(abs(x-mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -405,7 +421,8 @@ ess_sd.rvar <- function(x, ...) { #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -487,7 +504,8 @@ mcse_median <- function(x, ...) { #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -564,6 +582,10 @@ mcse_sd.default <- function(x, ...) { # differentials of the moments can be neglected " varsd <- varvar / Evar / 4 sqrt(varsd) + + + #TODO: add weighted version + } #' @rdname mcse_sd From 6ce80dfd9eea365d6d9b74b17920e5a6bf45603e Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 22 Mar 2024 16:59:04 +0200 Subject: [PATCH 38/43] fix weighted mcse for sd --- R/convergence.R | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index d10d937..ea715fa 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -570,26 +570,33 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") #' @rdname mcse_sd #' @export -mcse_sd.default <- function(x, ...) { - # var/sd are not a simple expectation of g(X), e.g. variance - # has (X-E[X])^2. The following ESS is based on a relevant quantity - # in the computation and is empirically a good choice. - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - # Variance of variance estimate by Kenney and Keeping (1951, p. 141), - # which doesn't assume normality of sims. - Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess - # The first order Taylor series approximation of variance of sd. - # Kenney and Keeping (1951, p. 141) write "...since fluctuations of - # any moment are of order N^{-1/2}, squares and higher powers of - # differentials of the moments can be neglected " - varsd <- varvar / Evar / 4 - sqrt(varsd) - - - #TODO: add weighted version +mcse_sd.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + # var/sd are not a simple expectation of g(X), e.g. variance + # has (X-E[X])^2. The following ESS is based on a relevant quantity + # in the computation and is empirically a good choice. + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + # Variance of variance estimate by Kenney and Keeping (1951, p. 141), + # which doesn't assume normality of sims. + Evar <- mean(sims_c^2) + varvar <- (mean(sims_c^4) - Evar^2) / ess + # The first order Taylor series approximation of variance of sd. + # Kenney and Keeping (1951, p. 141) write "...since fluctuations of + # any moment are of order N^{-1/2}, squares and higher powers of + # differentials of the moments can be neglected " + varsd <- varvar / Evar / 4 + sqrt(varsd) + + } else { + + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + r_eff <- ess / (nrow(x) * ncol(x)) + .mcse_weighted(sims_c, weights, r_eff, ...) + } } #' @rdname mcse_sd From c0d8b8e236b0a76efe670566ac191172ebced87b Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 8 Apr 2024 13:01:56 +0300 Subject: [PATCH 39/43] fixes to pareto smoothing for weighted draws --- R/pareto_smooth.R | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 60c45cd..575ee8f 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -51,8 +51,8 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, - return_k = TRUE, smooth_draws = FALSE, + return_k = TRUE, verbose = verbose, ... ) @@ -74,10 +74,13 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { w <- weights(x) - x <- weight_draws(x, NULL) + xu <- weight_draws(x, NULL) + xu <- xu * rvar(w) + product_diags <- summarise_rvar_by_element_with_chains( - x * rvar(w, nchains = nchains(x)), - pareto_khat, + xu, + pareto_khat.default, + verbose = verbose, ... ) @@ -312,7 +315,7 @@ pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = 1, + r_eff = NULL, ndraws_tail = NULL, return_k = FALSE, extra_diags = FALSE, @@ -502,6 +505,8 @@ pareto_convergence_rate.rvar <- function(x, ...) { ... ) { + x <- as.numeric(x) + if (are_log_weights) { # shift log values for safe exponentiation x <- x - max(x) From dbdab633c5c52008dda257484e97e74a28b5808f Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:33:52 +0300 Subject: [PATCH 40/43] do not unintentionally merge chains in pareto smoothing --- R/pareto_smooth.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 575ee8f..794c718 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -505,8 +505,6 @@ pareto_convergence_rate.rvar <- function(x, ...) { ... ) { - x <- as.numeric(x) - if (are_log_weights) { # shift log values for safe exponentiation x <- x - max(x) From a6cb3fb6c5388ff8a4e7f1f27348232eccb77671 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:41:23 +0300 Subject: [PATCH 41/43] add test for pareto_khat on weighted rvar --- tests/testthat/test-pareto_smooth.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index 8963560..abf540a 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -203,3 +203,29 @@ test_that("pareto_smooth works for log_weights", { expect_true(ps$diagnostics$khat > 0.7) }) + + + +test_that("pareto khat works for weighted rvars", { + + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 1.2), should have high pareto-khat + w2 <- as.numeric(dnorm(x, sd = 5) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + k <- pareto_khat(xw2)$khat + kw <- pareto_khat(w2, are_log_weights = TRUE)$khat + kp <- pareto_khat(draws_of(xw2) * w2)$khat + + expect_true(k > 0.7) + expect_equal(k, max(kw, kp)) +}) From d928bd4134fbee0443f63cae081178757215d9fb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:46:33 +0300 Subject: [PATCH 42/43] updates to weighted mcse for sd --- R/convergence.R | 42 +++++++++++++++++++++---------- tests/testthat/test-convergence.R | 2 +- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index ea715fa..6b4dc5f 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -324,7 +324,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- (x <= quantile(x, prob)) .ess(.split_chains(I)) } @@ -337,7 +337,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights = weights) + I <- (x <= weighted_quantile(x, prob, weights = weights)) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -408,9 +408,9 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @export ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { - .ess(.split_chains(abs(x-mean(x)))) + .ess(.split_chains(abs(x - mean(x)))) } else { - r_eff <- .ess(.split_chains(abs(x-mean(x)))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -422,6 +422,8 @@ ess_sd.rvar <- function(x, ...) { summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } +# TODO: ess_weights + #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a @@ -482,6 +484,7 @@ mcse_median <- function(x, ...) { } # MCSE of a single quantile +# TODO: refer to paper .mcse_quantile <- function(x, prob) { ess <- ess_quantile(x, prob) p <- c(0.1586553, 0.8413447) @@ -499,8 +502,8 @@ mcse_median <- function(x, ...) { a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) ssims <- sort(x) S <- length(ssims) - th1 <- ssims[max(floor(a[1] * S), 1)] - th2 <- ssims[min(ceiling(a[2] * S), S)] + th1 <- ssims[max(floor(a[1] * S), 1)] # adjust to account for weights + th2 <- ssims[min(ceiling(a[2] * S), S)] #adjust to account for weights as.vector((th2 - th1) / 2) } @@ -573,7 +576,7 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") mcse_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { - + # var/sd are not a simple expectation of g(X), e.g. variance # has (X-E[X])^2. The following ESS is based on a relevant quantity # in the computation and is empirically a good choice. @@ -582,7 +585,8 @@ mcse_sd.default <- function(x, weights = NULL, ...) { # Variance of variance estimate by Kenney and Keeping (1951, p. 141), # which doesn't assume normality of sims. Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess + varvar <- (mean(sims_c^4) - Evar^2) / ess # (Equation 6.20) + # The first order Taylor series approximation of variance of sd. # Kenney and Keeping (1951, p. 141) write "...since fluctuations of # any moment are of order N^{-1/2}, squares and higher powers of @@ -592,10 +596,23 @@ mcse_sd.default <- function(x, weights = NULL, ...) { } else { - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - r_eff <- ess / (nrow(x) * ncol(x)) - .mcse_weighted(sims_c, weights, r_eff, ...) + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights + + first_moment_weighted <- weighted.mean(x, w = weights) + + x_centered <- x - first_moment_weighted + second_moment_weighted <- weighted.mean(x_centered^2, w = weights) + fourth_moment_weighted <- weighted.mean(x_centered^4, w = weights) + + r_eff <- .ess(x_centered^2) / (nrow(x) * ncol(x)) + weighted_ess <- .ess_weighted(x_centered^2, weights = weights, r_eff = r_eff) + + # Kenney and Keeping (1951, eq 6.20) + varvar_weighted <- (fourth_moment_weighted - second_moment_weighted^2) / weighted_ess + + # First-order Taylor series approximation + varsd <- varvar_weighted / second_moment_weighted / 4 + sqrt(varsd) } } @@ -918,4 +935,3 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index a471e75..1e9034a 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -196,8 +196,8 @@ test_that("weighted convergence measures work", { # here ess for mean and q95 should be lower, but for q5 it should be higher w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) w3 <- w3 / sum(w3) - xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) expect_true(mcse_mean(xw3) > mcse_mean(xr)) From 14c01dea26af8f866c2c75cf3aa74a298181cfeb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 14:29:52 +0300 Subject: [PATCH 43/43] updates to mcse for weighted draws --- R/convergence.R | 64 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 6b4dc5f..0fb6e95 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -199,6 +199,13 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } @@ -291,6 +298,12 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights out <- ulapply(probs, .ess_quantile, x = x) } else { + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) @@ -368,6 +381,13 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } @@ -410,6 +430,13 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x - mean(x)))) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } @@ -460,8 +487,14 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -493,6 +526,7 @@ mcse_median <- function(x, ...) { S <- length(ssims) th1 <- ssims[max(floor(a[1] * S), 1)] th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) } @@ -500,10 +534,15 @@ mcse_median <- function(x, ...) { ess <- ess_quantile(x, prob, weights = weights) p <- c(0.1586553, 0.8413447) a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) - ssims <- sort(x) - S <- length(ssims) - th1 <- ssims[max(floor(a[1] * S), 1)] # adjust to account for weights - th2 <- ssims[min(ceiling(a[2] * S), S)] #adjust to account for weights + x_idx <- order(x) + x_sorted <- x[x_idx] + weights_sorted <- weights[x_idx] + S <- length(x) + + cweights <- cumsum(weights_sorted) + th1 <- x_sorted[max(max(which(cweights < a[1])), 1)] + th2 <- x_sorted[min(min(which(cweights > a[2])), S)] + as.vector((th2 - th1) / 2) } @@ -536,6 +575,13 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .mcse_weighted(x, weights, r_eff, ...) } @@ -596,6 +642,12 @@ mcse_sd.default <- function(x, weights = NULL, ...) { } else { + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights first_moment_weighted <- weighted.mean(x, w = weights)