Skip to content

Commit

Permalink
Merge 22b9691 into c312846
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay authored Feb 3, 2024
2 parents c312846 + 22b9691 commit 88fa83d
Show file tree
Hide file tree
Showing 46 changed files with 1,102 additions and 232 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ S3method(iteration_ids,draws_rvars)
S3method(iteration_ids,rvar)
S3method(length,rvar)
S3method(levels,rvar)
S3method(log_weights,draws)
S3method(log_weights,draws_rvars)
S3method(log_weights,rvar)
S3method(mad,default)
S3method(mad,rvar)
S3method(mad,rvar_ordered)
Expand Down Expand Up @@ -394,7 +397,9 @@ S3method(weight_draws,draws_df)
S3method(weight_draws,draws_list)
S3method(weight_draws,draws_matrix)
S3method(weight_draws,draws_rvars)
S3method(weight_draws,rvar)
S3method(weights,draws)
S3method(weights,rvar)
export("%**%")
export("%in%")
export("draws_of<-")
Expand Down Expand Up @@ -454,6 +459,7 @@ export(is_rvar)
export(is_rvar_factor)
export(is_rvar_ordered)
export(iteration_ids)
export(log_weights)
export(mad)
export(match)
export(mcse_mean)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth
weights before adding to a draws object.
* Add support for applying weights to individual `rvar` objects.
* Add `log_weights()` function for easy access to raw internal weights.
* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.
* `variables()`, `variables<-()`, `set_variables()`, and `nvariables()` now
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_array.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ as_draws_array.draws_rvars <- function(x, ...) {
x <- check_variables_are_numeric(
x, to = "draws_array", is_non_numeric = is_rvar_factor, convert = FALSE
)
x <- promote_rvar_weights_to_variable(x)

# cbind discards class information when applied to vectors, which converts
# the underlying factors to numeric
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_df.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ as_draws_df.draws_rvars <- function(x, ...) {
if (ndraws(x) == 0L) {
return(empty_draws_df(variables(x)))
}
x <- promote_rvar_weights_to_variable(x)
out <- do.call(cbind, lapply(seq_along(x), function(i) {
# flatten each rvar so it only has two dimensions: draws and variables
# this also collapses indices into variable names in the format "var[i,j,k,...]"
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ as_draws_matrix.draws_rvars <- function(x, ...) {
x <- check_variables_are_numeric(
x, to = "draws_matrix", is_non_numeric = is_rvar_factor, convert = FALSE
)
x <- promote_rvar_weights_to_variable(x)

# cbind discards class information when applied to vectors, which converts
# the underlying factors to numeric
Expand Down
30 changes: 29 additions & 1 deletion R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,27 @@ as_draws_rvars.mcmc.list <- function(x, ...) {

check_new_variables(names(x))

x <- conform_rvar_ndraws_nchains(x)
x <- conform_rvar_nchains_ndraws_weights(x)

class(x) <- class_draws_rvars()

# move the .log_weight column into the log_weights attribute of each rvar,
# but only if there is no conflict between any existing weights on the rvars
if (".log_weight" %in% names(x)) {
existing_weights <- log_weights(x[[1]])
.log_weight <- as.vector(draws_of(x$.log_weight))
if (is.null(existing_weights)) {
x$.log_weight <- NULL
x <- weight_draws(x, .log_weight, log = TRUE)
} else {
# if we reach this point either existing_weights and .log_weight
# are identical (so we don't have to do anything) or they aren't
# and weights2_common will throw the appropriate error --- thus
# we don't need to do anything with its output
weights2_common(existing_weights, .log_weight)
}
}

x
}

Expand Down Expand Up @@ -258,3 +276,13 @@ empty_draws_rvars <- function(variables = character(0), nchains = 0) {
class(out) <- class_draws_rvars()
out
}

# When converting a draws_rvars object to another format, any log weights
# must first be promoted to a regular `.log_weight` variable.
#
# x: object passed to log_weights() and nchains(); the callers visible in
#   this file pass a draws_rvars.
# Returns `x`, with a `.log_weight` rvar added (or overwritten) when
# log_weights(x) is non-NULL, and `x` unchanged otherwise.
promote_rvar_weights_to_variable <- function(x) {
  # fetch the log weights once and reuse them for both the NULL check and
  # the construction of the new variable
  .log_weights <- log_weights(x)
  if (!is.null(.log_weights)) {
    x$.log_weight <- rvar(.log_weights, nchains = nchains(x))
  }
  x
}
16 changes: 9 additions & 7 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,9 @@ quantile2.default <- function(
) {
names <- as_one_logical(names)
na.rm <- as_one_logical(na.rm)
if (!na.rm && anyNA(x)) {
# quantile itself doesn't handle this case (#110)
out <- rep(NA_real_, length(probs))
} else {
out <- quantile(x, probs = probs, na.rm = na.rm, ...)
}

out <- weighted_quantile(x, probs = probs, na.rm = na.rm, ...)

if (names) {
names(out) <- paste0("q", probs * 100)
} else {
Expand All @@ -560,7 +557,12 @@ quantile2.default <- function(
quantile2.rvar <- function(
x, probs = c(0.05, 0.95), na.rm = FALSE, names = TRUE, ...
) {
summarise_rvar_by_element_with_chains(x, quantile2, probs, na.rm, names, ...)
weights <- weights(x)
summarise_rvar_by_element(x, function(draws) {
quantile2(
draws, probs = probs, weights = weights, na.rm = na.rm, names = names, ...
)
})
}

# internal ----------------------------------------------------------------
Expand Down
73 changes: 47 additions & 26 deletions R/discrete-summaries.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
#'
#' Normalized entropy, for measuring dispersion in draws from categorical distributions.
#'
#' @param x (multiple options) A vector to be interpreted as draws from
#' a categorical distribution, such as:
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-x-categorical
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Calculates the normalized Shannon entropy of the draws in `x`. This value is
Expand Down Expand Up @@ -51,14 +49,14 @@
#' xy
#' entropy(xy)
#' @export
entropy <- function(x) {
entropy <- function(x, ...) {
UseMethod("entropy")
}
#' @rdname entropy
#' @export
entropy.default <- function(x) {
entropy.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA_real_)
p <- prop.table(simple_table(x)$count)
p <- prop.table(weighted_simple_table(x, weights)$count)
n <- length(p)

if (n == 1) {
Expand All @@ -71,8 +69,8 @@ entropy.default <- function(x) {
}
#' @rdname entropy
#' @export
entropy.rvar <- function(x) {
summarise_rvar_by_element(x, entropy)
entropy.rvar <- function(x, ...) {
summarise_rvar_by_element(x, entropy, weights = weights(x))
}


Expand All @@ -85,6 +83,8 @@ entropy.rvar <- function(x) {
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Calculates Tastle and Wierman's (2007) *dissention* measure:
Expand Down Expand Up @@ -125,12 +125,12 @@ entropy.rvar <- function(x) {
#' xy
#' dissent(xy)
#' @export
dissent <- function(x) {
dissent <- function(x, ...) {
UseMethod("dissent")
}
#' @rdname dissent
#' @export
dissent.default <- function(x) {
dissent.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA_real_)
if (length(x) == 0) return(0)

Expand All @@ -141,33 +141,32 @@ dissent.default <- function(x) {
d <- diff(range(x))
}

tab <- simple_table(x)
tab <- weighted_simple_table(x, weights)
p <- prop.table(tab$count)

if (length(p) == 1) {
out <- 0
} else {
x_i <- tab$x
out <- -sum(p * log2(1 - abs(x_i - mean(x)) / d))
mean_x <- if (is.null(weights)) mean(x) else weighted.mean(x, weights)
out <- -sum(p * log2(1 - abs(x_i - mean_x) / d))
}
out
}
#' @rdname dissent
#' @export
dissent.rvar <- function(x) {
summarise_rvar_by_element(x, dissent)
dissent.rvar <- function(x, ...) {
summarise_rvar_by_element(x, dissent, weights = weights(x))
}


#' Modal category
#'
#' Modal category of a vector.
#'
#' @param x (multiple options) A vector to be interpreted as draws from
#' a categorical distribution, such as:
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-x-categorical
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Finds the modal category (i.e., most frequent value) in `x`. In the case of
Expand All @@ -192,20 +191,20 @@ dissent.rvar <- function(x) {
#' xy
#' modal_category(xy)
#' @export
modal_category <- function(x) {
modal_category <- function(x, ...) {
UseMethod("modal_category")
}
#' @rdname modal_category
#' @export
modal_category.default <- function(x) {
modal_category.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA)
tab <- simple_table(x)
tab <- weighted_simple_table(x, weights)
tab$x[which.max(tab$count)]
}
#' @rdname modal_category
#' @export
modal_category.rvar <- function(x) {
summarise_rvar_by_element(x, modal_category)
modal_category.rvar <- function(x, ...) {
summarise_rvar_by_element(x, modal_category, weights = weights(x))
}


Expand All @@ -231,3 +230,25 @@ simple_table <- function(x) {
count = tabulate(x_int, nbins = length(values))
)
}

#' Weighted counterpart of simple_table
#'
#' Sums `weights` within each distinct value of `x`.
#'
#' @param x a vector (numeric, factor, character, etc)
#' @param weights either `NULL`, in which case the result of
#'   `simple_table(x)` is returned, or a vector of weights with the same
#'   length as `x`
#' @returns a list with two components of the same length
#' - `x`: the levels of `x` if it is a factor, otherwise its unique values
#' - `count`: sum of weights for each entry of the `x` component
#' @noRd
weighted_simple_table <- function(x, weights) {
  if (is.null(weights)) {
    return(simple_table(x))
  }
  stopifnot(identical(length(x), length(weights)))

  # factors report every level (including unused ones); other vectors
  # report their unique values in order of first appearance
  values <- if (is.factor(x)) levels(x) else unique(x)

  # group the weights by value, then total each group
  weights_by_value <- split(weights, factor(x, values))
  totals <- vapply(weights_by_value, sum, numeric(1), USE.NAMES = FALSE)

  list(x = values, count = totals)
}
7 changes: 5 additions & 2 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,11 @@ nchains.rvar <- function(x) {
# attribute on an rvar, ALWAYS use this function so that the proxy
# cache is invalidated
`nchains_rvar<-` <- function(x, value) {
attr(x, "nchains") <- value
invalidate_rvar_cache(x)
if (attr(x, "nchains") != value) {
attr(x, "nchains") <- value
x <- invalidate_rvar_cache(x)
}
x
}


Expand Down
2 changes: 1 addition & 1 deletion R/mutate_variables.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mutate_variables.draws_rvars <- function(.x, ...) {
for (var in names(dots)) {
.x[[var]] <- as_rvar(eval_tidy(dots[[var]], .x, env))
}
conform_rvar_ndraws_nchains(.x)
conform_rvar_nchains_ndraws_weights(.x)
}

# evaluate an expression passed to 'mutate_variables' and check its validity
Expand Down
2 changes: 1 addition & 1 deletion R/resample_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ resample_draws.draws <- function(x, weights = NULL, method = "stratified",
weights <- rep.int(1/ndraws_total, ndraws_total)
}
# resampling invalidates stored weights
x <- remove_variables(x, ".log_weight")
x <- weight_draws(x, NULL)
} else {
weights <- weights / sum(weights)
}
Expand Down
Loading

0 comments on commit 88fa83d

Please sign in to comment.