-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathbrms_tidiers.R
405 lines (387 loc) · 16.1 KB
/
brms_tidiers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
#' Tidying methods for a brms model
#'
#' These methods tidy the estimates from
#' \code{\link[brms:brmsfit-class]{brmsfit-objects}}
#' (fitted model objects from the \pkg{brms} package) into a summary.
#'
#' @return All tidying methods return a \code{data.frame} without rownames.
#' The structure depends on the method chosen.
#'
#' @seealso \code{\link[brms]{brms}}, \code{\link[brms]{brmsfit-class}}
#'
#' @name brms_tidiers
#'
#' @param x Fitted model object from the \pkg{brms} package. See
#' \code{\link[brms]{brmsfit-class}}.
#' @examples
#' ## original model
#' \dontrun{
#' brms_crossedRE <- brm(mpg ~ wt + (1|cyl) + (1+wt|gear), data = mtcars,
#' iter = 500, chains = 2)
#' }
#' \donttest{
#' ## too slow for CRAN (>5 seconds)
#' ## load stored object
#' if (require("rstan") && require("brms")) {
#' load(system.file("extdata", "brms_example.rda", package="broom.mixed"))
#'
#' fit <- brms_crossedRE
#' tidy(fit)
#' tidy(fit, parameters = "^sd_", conf.int = FALSE)
#' tidy(fit, effects = "fixed", conf.method="HPDinterval")
#' tidy(fit, effects = "ran_vals")
#' tidy(fit, effects = "ran_pars", robust = TRUE)
#' if (require("posterior")) {
#' tidy(fit, effects = "ran_pars", rhat = TRUE, ess = TRUE)
#'
#' }
#' }
#' if (require("rstan") && require("brms")) {
#' # glance method
#' glance(fit)
#' ## this example will give a warning that it should be run with
#' ## reloo=TRUE; however, doing this will fail
#' ## because the \code{fit} object has been stripped down to save space
#' suppressWarnings(glance(fit, looic = TRUE, cores = 1))
#' head(augment(fit))
#' }
#' }
#'
NULL
## examples for all methods (tidy/glance/augment) included in the same
## block so we can surround them with a single "if (require(brms))" block
#' @rdname brms_tidiers
#' @param parameters Names of parameters for which a summary should be
#' returned, as given by a character vector or regular expressions.
#' If \code{NA} (the default) summarized parameters are specified
#' by the \code{effects} argument.
#' @param effects A character vector including one or more of \code{"fixed"},
#' \code{"ran_vals"}, or \code{"ran_pars"}.
#' See the Value section for details.
#' @param robust Whether to use median and median absolute deviation of
#' the posterior distribution, rather
#' than mean and standard deviation, to derive point estimates and uncertainty
#' @param conf.int If \code{TRUE} columns for the lower (\code{conf.low})
#' and upper bounds (\code{conf.high}) of posterior uncertainty intervals are included.
#' @param exponentiate whether to exponentiate the fixed-effect coefficient estimates and confidence intervals (common for logistic regression); if \code{TRUE}, also scales the standard errors by the exponentiated coefficient, transforming them to the new scale
#' @param conf.level Defines the range of the posterior uncertainty conf.int,
#' such that \code{100 * conf.level}\% of the parameter's posterior distribution
#' lies within the corresponding interval.
#' Only used if \code{conf.int = TRUE}.
#' @param conf.method method for computing confidence intervals
#' ("quantile" or "HPDinterval")
#' @param rhat whether to calculate the *Rhat* convergence metric
#' (\code{FALSE} by default)
#' @param ess whether to calculate the *effective sample size* (ESS) convergence metric
#' (\code{FALSE} by default)
#' @param fix.intercept rename "Intercept" parameter to "(Intercept)", to match
#' behaviour of other model types?
#' @param looic Should the LOO Information Criterion (and related info) be
#' included? See \code{\link[rstan]{loo.stanfit}} for details. (This
#' can be slow for models fit to large datasets.)
#' @param ... Extra arguments, not used
#' @return
#' When \code{parameters = NA}, the \code{effects} argument is used
#' to determine which parameters to summarize.
#'
#' Generally, \code{tidy.brmsfit} returns
#' one row for each coefficient, with at least three columns:
#' \item{term}{The name of the model parameter.}
#' \item{estimate}{A point estimate of the coefficient (mean or median).}
#' \item{std.error}{A standard error for the point estimate (sd or mad).}
#'
#' When \code{effects = "fixed"}, only population-level
#' effects are returned.
#'
#' When \code{effects = "ran_vals"}, only group-level effects are returned.
#' In this case, two additional columns are added:
#' \item{group}{The name of the grouping factor.}
#' \item{level}{The name of the level of the grouping factor.}
#'
#' Specifying \code{effects = "ran_pars"} selects the
#' standard deviations and correlations of the group-level parameters.
#'
#' If \code{conf.int = TRUE}, columns for the lower (\code{conf.low}) and
#' upper (\code{conf.high}) bounds of the posterior uncertainty interval
#' are included.
#'
#' @note The names \sQuote{fixed}, \sQuote{ran_pars}, and \sQuote{ran_vals}
#' (corresponding to "non-varying", "hierarchical", and "varying" respectively
#' in previous versions of the package), while technically inappropriate in
#' a Bayesian setting where "fixed" and "random" effects are not well-defined,
#' are used for compatibility with other (frequentist) mixed model types.
#' @note At present, the components of parameter estimates are separated by parsing the column names of \code{as_draws} (e.g. \code{r_patient[1,Intercept]} for the random effect on the intercept for patient 1, or \code{b_Trt1} for the fixed effect \code{Trt1}). We try to detect underscores in parameter names and warn, but detection may be imperfect.
#' @export
tidy.brmsfit <- function(x, parameters = NA,
                         effects = c("fixed", "ran_pars"),
                         robust = FALSE,
                         conf.int = TRUE, conf.level = 0.95,
                         conf.method = c("quantile", "HPDinterval"),
                         rhat = FALSE, ess = FALSE,
                         fix.intercept = TRUE,
                         exponentiate = FALSE,
                         ...) {
  ## Overview:
  ##  1. select posterior draws by regular expression (the user-supplied
  ##     `parameters`, or prefixes derived from `effects` when
  ##     `parameters` contains NA);
  ##  2. when selecting by `effects`, build effect/group/term/level labels
  ##     by parsing the names of the selected draw variables;
  ##  3. append point estimate, std.error, optional interval bounds and
  ##     optional rhat/ess, one row per selected draw variable.
  ## The label table (step 2) and the draws (step 3) are matched purely by
  ## position, so the number and order of label rows must agree with the
  ## columns of `samples`.
  check_dots(...)
  bad_effects <- setdiff(effects, c("fixed", "ran_pars", "ran_vals", "ran_coefs"))
  if (length(bad_effects) > 0) {
    stop("unrecognized effects: ", paste(bad_effects, collapse = ", "))
  }
  std.error <- NULL ## NSE/code check
  if (!requireNamespace("brms", quietly = TRUE)) {
    stop("can't tidy brms objects without brms installed")
  }
  xr <- brms::restructure(x)
  has_ranef <- nrow(xr$ranef) > 0
  ## term parsing below splits on underscores, so underscores inside
  ## coefficient or grouping-factor names can confuse it
  if (any(grepl("_", rownames(fixef(x)))) ||
      (has_ranef && any(grepl("_", names(ranef(x)))))) {
    warning("some parameter names contain underscores: term naming may be unreliable!")
  }
  use_effects <- anyNA(parameters)
  conf.method <- match.arg(conf.method)
  is.multiresp <- length(x$formula$forms) > 1
  ## make regular expression from a list of prefixes; with LB = TRUE the
  ## leading "(^|_)" is wrapped in a look-behind so that it is not consumed
  mkRE <- function(x, LB = FALSE) {
    pref <- "(^|_)"
    if (LB) pref <- sprintf("(?<=%s)", pref)
    sprintf("%s(%s)", pref, paste(unlist(x), collapse = "|"))
  }
  ## prefixes distinguishing fixed effects, random-effect values,
  ## random-effect parameters and model components in draw names
  prefs <- list(
    fixed = "b_", ran_vals = "r_",
    ## no lookahead (doesn't work with grep[l])
    ran_pars = c("sd_", "cor_", "sigma"),
    components = c("zi_", "disp_")
  )
  pref_RE <- mkRE(prefs[effects])
  if (use_effects) {
    ## prefixes distinguishing fixed, random effects
    parameters <- pref_RE
  }
  samples_perchain <- brms::as_draws_array(x, parameters, regex = TRUE)
  if (is.null(samples_perchain) || posterior::nvariables(samples_perchain) == 0) {
    stop("No parameter name matches the specified pattern.",
         call. = FALSE
    )
  }
  samples <- brms::as_draws_matrix(samples_perchain)
  terms <- colnames(samples)
  if (use_effects) {
    if (is.multiresp) {
      if ("ran_pars" %in% effects && any(grepl("^sd", terms))) {
        warning("ran_pars response/group tidying for multi-response models is currently incorrect")
      }
      ## FIXME: unfinished attempt to fix GH #39
      ## extract response component from terms
      ## resp0 <- strsplit(terms, "_+")
      ## resp1 <- sapply(resp0,
      ## function(x) if (length(x)==2) x[2] else x[length(x)-1])
      ## ## put the pieces back together
      ## t0 <- lapply(resp0,
      ## function(x) if (length(x)==2) x[1] else x[-(length(x)-1)])
      ## t1 <- lapply(t0,
      ## function(x)
      ## case_when(
      ## x[[1]]=="b" ~ sprintf("b%s",x[[2]]),
      ## x[[2]]=="sd" ~ sprintf("sd_%s__%s",x[[2]],x[[3]]),
      ## x[[3]]=="cor" ~ sprintf("cor_%s_%s_%s_%s",
      ## x[[2]],x[[3]],x[[4]],x[[5]])
      ## ))
      ## resp0 <- stringr::str_extract_all(terms, "_[^_]+")
      ## resp1 <- lapply(resp0, gsub, pattern= "^_", replacement="")
      ## the first "_xxx" chunk of each name is taken as the response and
      ## removed from the term
      response <- gsub("^_", "", stringr::str_extract(terms, "_[^_]+"))
      terms <- sub("_[^_]+", "", terms)
    }
    res_list <- list()
    fixed.only <- identical(effects, "fixed")
    ## anchored fixed-effect pattern, shared by counting and prefix removal
    fixed_RE <- mkRE(prefs["fixed"])
    if ("fixed" %in% effects) {
      ## empty tibble: NA columns will be filled in as appropriate.
      ## Count with the anchored pattern: a bare "b_" also matches inside
      ## other terms (e.g. "sd_sub__Intercept"), which would create more
      ## label rows than there are draw columns.
      nfixed <- sum(grepl(fixed_RE, terms))
      res_list$fixed <- as_tibble(matrix(nrow = nfixed, ncol = 0))
    }
    grpfun <- function(x) {
      if (grepl("sigma", x[[1]])) "Residual" else x[[2]]
    }
    if ("ran_pars" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_pars), terms, value = TRUE)
      ss <- strsplit(rterms, "__")
      pp <- "^(cor|sd)(?=(_))"
      nodash <- function(x) gsub("^_", "", x)
      ## split the first term (cor/sd) into tag + group
      ss2 <- lapply(
        ss,
        function(x) {
          if (!is.na(pref <- stringr::str_extract(x[1], pp))) {
            return(c(pref, nodash(stringr::str_remove(x[1], pp)), x[-1]))
          }
          return(x)
        }
      )
      sep <- getOption("broom.mixed.sep1")
      termfun <- function(x) {
        if (grepl("^sigma", x[[1]])) {
          paste("sd", "Observation", sep = sep)
        } else {
          ## re-attach remaining terms
          paste(x[[1]],
                paste(x[3:length(x)], collapse = "."),
                sep = sep
          )
        }
      }
      ## vapply (not sapply) so that zero matching terms still give
      ## character columns rather than empty lists
      res_list$ran_pars <-
        dplyr::tibble(
          group = vapply(ss2, grpfun, character(1)),
          term = vapply(ss2, termfun, character(1))
        )
    }
    ## nice, but needs to be done outside averaging loop ...
    ## meltfun <- function(a) {
    ## dd <- as.data.frame(ftable(a)) |>
    ## setNames(c("level", "var", "term", "value")) |>
    ## tidyr::pivot_wider(names_from = var, values_from = value) |>
    ## rename(estimate = "Estimate",
    ## std.error = "Est.Error",
    ## ## FIXME: not robust to changing levels
    ## conf.low = "Q2.5",
    ## conf.high = "Q97.5")
    ## }
    ## ## purrr:::map_dfr(ranef(x), meltfun, .id = "group")
    ## if ("ran_coefs" %in% effects) {
    ## res_list$ran_coefs <- purrr:::map_dfr(coef(x), meltfun, .id = "group")
    ## }
    if ("ran_vals" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_vals), terms, value = TRUE)
      ## names look like r_<group>[<level>,<term>]
      vals <- stringr::str_match_all(rterms, "_(.+?)\\[(.+?),(.+?)\\]")
      res_list$ran_vals <-
        dplyr::tibble(
          group = purrr::map_chr(vals, function (v) { v[[2]] }),
          term = purrr::map_chr(vals, function (v) { v[[4]] }),
          level = purrr::map_chr(vals, function (v) { v[[3]] })
        )
    }
    out <- dplyr::bind_rows(res_list, .id = "effect")
    ## Defensive: should a label column still arrive as a list of NULLs,
    ## replace it by a vector of NA.
    for (col in c("group", "term")) {
      if (is.list(out[[col]]) &&
          all(vapply(out[[col]], is.null, logical(1)))) {
        out[[col]] <- rep(NA, nrow(out))
      }
    }
    ## rows still lacking a term are the fixed-effect placeholders
    v <- if (fixed.only) seq_len(nrow(out)) else is.na(out$term)
    newterms <- stringr::str_remove(terms[v], fixed_RE)
    if (length(newterms) > 0) {
      if (fixed.only) {
        out$term <- newterms
      } else {
        out$term[v] <- newterms
      }
    }
    if (is.multiresp) {
      out$response <- response
    }
    ## prefixes already removed for ran_vals; don't remove for ran_pars
  } else {
    ## if !use_effects
    out <- dplyr::tibble(term = terms)
  }
  pointfun <- if (robust) stats::median else base::mean
  stdfun <- if (robust) stats::mad else stats::sd
  out$estimate <- apply(samples, 2, pointfun)
  out$std.error <- apply(samples, 2, stdfun)
  if (conf.int) {
    stopifnot(length(conf.level) == 1L)
    probs <- c((1 - conf.level) / 2, 1 - (1 - conf.level) / 2)
    if (conf.method == "HPDinterval") {
      cc <- coda::HPDinterval(coda::as.mcmc(samples), prob = conf.level)
    } else {
      cc <- t(apply(samples, 2, stats::quantile, probs = probs))
    }
    out$conf.low <- cc[, 1]
    out$conf.high <- cc[, 2]
  }
  ## optional convergence diagnostics, computed from the per-chain draws
  metric_names <- character(0)
  if (rhat) {
    metric_names <- c(metric_names, "rhat")
  }
  if (ess) {
    metric_names <- c(metric_names, "ess")
  }
  if (length(metric_names) > 0) {
    ## check availability *before* touching posterior:: so that the
    ## informative message below can actually be reached
    if (!requireNamespace("posterior", quietly = TRUE)) {
      stop(paste0(paste0(metric_names, collapse = ", "),
                  " calculation for brmsfit objects requires posterior package"))
    }
    posterior_metrics <- list(
      rhat = posterior::rhat,
      ess = posterior::ess_basic
    )[metric_names]
    out[metric_names] <-
      posterior::summarise_draws(samples_perchain, posterior_metrics)[metric_names]
  }
  ## figure out component
  out$component <- dplyr::case_when(grepl("(^|_)zi", out$term) ~ "zi",
                                    ## ??? is this possible in brms models
                                    grepl("^disp", out$term) ~ "disp",
                                    TRUE ~ "cond")
  if (exponentiate) {
    vv <- c("estimate", "conf.low", "conf.high")
    ## std.error is scaled by the already-exponentiated estimate
    out <- (out
      %>% mutate(across(contains(vv), exp))
      %>% mutate(across(std.error, ~ . * estimate))
    )
  }
  ## drop the component prefix (zi_/disp_) from the term itself
  out$term <- stringr::str_remove(out$term, mkRE(prefs[["components"]],
                                                  LB = TRUE))
  if (fix.intercept) {
    ## use lookahead/lookbehind: replace Intercept with word boundary
    ## or underscore before/after by (Intercept) - without removing
    ## underscores!
    out$term <- stringr::str_replace(out$term,
                                     "(?<=(\\b|_))Intercept(?=(\\b|_))",
                                     "(Intercept)")
  }
  out <- reorder_cols(out)
  return(out)
}
#' @importFrom stats quantile
#' @export
sigma.brmsfit <- function (object, ...) {
  ## Residual standard deviation for a brmsfit: the 0.5 quantile of the
  ## draws of the variable named "sigma", or 1 when the model has no
  ## variable of that name.
  has_sigma <- "sigma" %in% brms::variables(object)
  if (has_sigma) {
    sigma_draws <- brms::as_draws_array(object, "sigma")
    stats::quantile(sigma_draws, probs = 0.5)
  } else {
    1
  }
}
#' @rdname brms_tidiers
#' @export
glance.brmsfit <- function(x, looic = FALSE, ...) {
  ## All the work is delegated to the shared Stan helper (defined in
  ## rstanarm_tidiers.R), tagged with the "brmsfit" model type; extra
  ## arguments are forwarded unchanged.
  model_summary <- glance_stan(x, looic = looic, type = "brmsfit", ...)
  model_summary
}
#' @rdname brms_tidiers
#' @param data data frame
#' @param newdata new data frame
#' @param se.fit return standard errors of fit?
#' @export
augment.brmsfit <- function(x, data = stats::model.frame(x), newdata = NULL,
                            se.fit = TRUE, ...) {
  ## Returns the data (or `newdata`, when given) with `.fitted`, optionally
  ## `.se.fit`, and -- only when `newdata` is NULL -- `.resid` appended.
  ##
  ## augment_columns is not used here because residuals.brmsfit returns
  ## a 4-column matrix (summary=TRUE by default, with no way to suppress
  ## that from within augment_columns)
  ## ... add resids.arg to augment_columns?
  predict_args <- list(x, se.fit = se.fit)
  if (!missing(newdata)) {
    predict_args$newdata <- newdata
  }
  ## FIXME: influence measures??
  ## allow optional arguments to augment, e.g. pred.type,
  ## residual.type, re.form ...
  predictions <- do.call(stats::predict, predict_args)
  new_cols <- dplyr::tibble(.fitted = predictions[, "Estimate"])
  if (se.fit) {
    new_cols[[".se.fit"]] <- predictions[, "Est.Error"]
  }
  ## predictions for new data: no residuals are available
  if (!is.null(newdata)) {
    return(dplyr::bind_cols(as_tibble(newdata), new_cols))
  }
  new_cols[[".resid"]] <- stats::residuals(x)[, "Estimate"]
  dplyr::bind_cols(as_tibble(data), new_cols)
}