Skip to content

Commit

Permalink
simplify argument structure of stan_log_lik_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 19, 2024
1 parent 2e6c14e commit c0eb374
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 407 deletions.
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,6 @@ S3method(restructure,brmsfit)
S3method(rhat,brmsfit)
S3method(shinystan::launch_shinystan,brmsfit)
S3method(stan_log_lik,brmsterms)
S3method(stan_log_lik,family)
S3method(stan_log_lik,mixfamily)
S3method(stan_log_lik,mvbrmsterms)
S3method(stan_predictor,bframel)
S3method(stan_predictor,bframenl)
Expand Down
32 changes: 23 additions & 9 deletions R/brmsterms.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
y$cov_ranef <- x$cov_ranef
class(y) <- "brmsterms"

y$resp <- ""
if (check_response) {
# extract response variables
y$respform <- validate_resp_formula(formula, empty_ok = FALSE)
if (mv) {
y$resp <- terms_resp(y$respform)
} else {
y$resp <- ""
}
}

Expand All @@ -97,6 +96,11 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
x$pforms[[dp]] <- combine_formulas(formula, x$pforms[[dp]], dp)
}
x$pforms <- move2start(x$pforms, mu_dpars)
for (i in seq_along(family$mix)) {
# store the respective mixture index in each mixture component
# this enables them to be easily passed along, e.g. in stan_log_lik
y$family$mix[[i]]$mix <- i
}
} else if (conv_cats_dpars(x$family)) {
mu_dpars <- str_subset(x$family$dpars, "^mu")
for (dp in mu_dpars) {
Expand All @@ -109,21 +113,20 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
}

# predicted distributional parameters
resp <- ifelse(mv && !is.null(y$resp), y$resp, "")
dpars <- intersect(names(x$pforms), valid_dpars(family))
dpar_forms <- x$pforms[dpars]
nlpars <- setdiff(names(x$pforms), dpars)

y$dpars <- named_list(dpars)
for (dp in dpars) {
if (get_nl(dpar_forms[[dp]])) {
y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, resp)
y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, y$resp)
} else {
y$dpars[[dp]] <- terms_lf(dpar_forms[[dp]])
}
y$dpars[[dp]]$family <- dpar_family(family, dp)
y$dpars[[dp]]$dpar <- dp
y$dpars[[dp]]$resp <- resp
y$dpars[[dp]]$resp <- y$resp
if (dpar_class(dp) == "mu") {
y$dpars[[dp]]$respform <- y$respform
y$dpars[[dp]]$adforms <- y$adforms
Expand All @@ -142,12 +145,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
attr(nlpar_forms[[nlp]], "center") <- FALSE
}
if (get_nl(nlpar_forms[[nlp]])) {
y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, resp)
y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, y$resp)
} else {
y$nlpars[[nlp]] <- terms_lf(nlpar_forms[[nlp]])
}
y$nlpars[[nlp]]$nlpar <- nlp
y$nlpars[[nlp]]$resp <- resp
y$nlpars[[nlp]]$resp <- y$resp
check_cs(y$nlpars[[nlp]])
}
used_nlpars <- ufrom_list(c(y$dpars, y$nlpars), "used_nlpars")
Expand Down Expand Up @@ -592,20 +595,31 @@ is.btnl <- function(x) {
inherits(x, "btnl")
}

# figure out if a certain distributional parameter is predicted
is_pred_dpar <- function(bterms, dpar) {
stopifnot(is.brmsterms(bterms))
if (!length(dpar)) {
return(FALSE)
}
mix <- get_mix_id(bterms)
any(paste0(dpar, mix) %in% names(bterms$dpars))
}

# transform mvbrmsterms objects for use in stan_llh.brmsterms
as.brmsterms <- function(x) {
stopifnot(is.mvbrmsterms(x), x$rescor)
families <- ulapply(x$terms, function(y) y$family$family)
stopifnot(all(families == families[1]))
out <- structure(list(), class = "brmsterms")
out$family <- structure(
list(family = paste0(families[1], "_mv"), link = "identity"),
list(family = families[1], link = "identity"),
class = c("brmsfamily", "family")
)
out$family$fun <- paste0(out$family$family, "_mv")
info <- get(paste0(".family_", families[1]))()
out$family[names(info)] <- info
out$sigma_pred <- any(ulapply(x$terms,
function(x) "sigma" %in% names(x$dpar) || is.formula(x$adforms$se)
function(x) is_pred_dpar(x, "sigma") || has_ad_terms(x, "se")
))
weight_forms <- rmNULL(lapply(x$terms, function(x) x$adforms$weights))
if (length(weight_forms)) {
Expand Down
3 changes: 2 additions & 1 deletion R/conditional_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ get_int_vars.mvbrmsterms <- function(x, ...) {

#' @export
get_int_vars.brmsterms <- function(x, ...) {
advars <- ulapply(rmNULL(x$adforms[c("trials", "thres", "vint")]), all_vars)
adterms <- c("trials", "thres", "vint")
advars <- ulapply(rmNULL(x$adforms[adterms]), all_vars)
unique(c(advars, get_sp_vars(x, "mo")))
}

Expand Down
1 change: 1 addition & 0 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ data_response.brmsframe <- function(x, data, check_response = TRUE,
}

# data for addition arguments of the response
# TODO: replace is.formula(x$adforms$term) pattern with has_ad_terms()
if (has_trials(x$family) || is.formula(x$adforms$trials)) {
if (!length(x$adforms$trials)) {
stop2("Specifying 'trials' is required for this model.")
Expand Down
18 changes: 11 additions & 7 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ combine_family_info <- function(x, y, ...) {
y <- as_one_character(y)
unite <- c(
"dpars", "type", "specials", "include",
"const", "cats", "ad", "normalized"
"const", "cats", "ad", "normalized", "mix"
)
if (y %in% c("family", "link")) {
x <- unlist(x)
Expand Down Expand Up @@ -1785,6 +1785,11 @@ no_nu <- function(bterms) {
isTRUE(bterms$rescor) && "student" %in% family_names(bterms)
}

# get mixture index if specified
get_mix_id <- function(family) {
family_info(family, "mix") %||% ""
}

# does the family-link combination have a built-in Stan function?
has_built_in_fun <- function(family, link = NULL, dpar = NULL, cdf = FALSE) {
link <- link %||% family$link
Expand All @@ -1802,19 +1807,18 @@ prepare_family <- function(x) {
stopifnot(is.brmsformula(x) || is.brmsterms(x))
family <- x$family
acframe <- frame_ac(x)
family$fun <- family[["fun"]] %||% family$family
if (use_ac_cov_time(acframe) && has_natural_residuals(x)) {
family$fun <- paste0(family$family, "_time")
family$fun <- paste0(family$fun, "_time")
} else if (has_ac_class(acframe, "sar")) {
acframe_sar <- subset2(acframe, class = "sar")
if (has_ac_subset(acframe_sar, type = "lag")) {
family$fun <- paste0(family$family, "_lagsar")
family$fun <- paste0(family$fun, "_lagsar")
} else if (has_ac_subset(acframe_sar, type = "error")) {
family$fun <- paste0(family$family, "_errorsar")
family$fun <- paste0(family$fun, "_errorsar")
}
} else if (has_ac_class(acframe, "fcor")) {
family$fun <- paste0(family$family, "_fcor")
} else {
family$fun <- family$family
family$fun <- paste0(family$fun, "_fcor")
}
family
}
Expand Down
16 changes: 10 additions & 6 deletions R/formula-ad.R
Original file line number Diff line number Diff line change
Expand Up @@ -376,21 +376,25 @@ trunc_bounds <- function(bterms, data = NULL, incl_family = FALSE,
out
}

# check if addition argument 'subset' ist used in the model
# check if addition argument 'subset' is used in the model
# works for both univariate and multivariate models
has_subset <- function(bterms) {
.has_subset <- function(x) {
is.formula(x$adforms$subset)
}
if (is.brmsterms(bterms)) {
out <- .has_subset(bterms)
out <- has_ad_terms(bterms, "subset")
} else if (is.mvbrmsterms(bterms)) {
out <- any(ulapply(bterms$terms, .has_subset))
out <- any(ulapply(bterms$terms, has_ad_terms, "subset"))
} else {
out <- FALSE
}
out
}

# check if a model has certain addition terms
has_ad_terms <- function(bterms, terms) {
stopifnot(is.brmsterms(bterms), is.character(terms))
any(ulapply(bterms$adforms[terms], is.formula))
}

# construct a list of indices for cross-formula referencing
frame_index <- function(x, data) {
out <- .frame_index(x, data)
Expand Down
3 changes: 1 addition & 2 deletions R/stan-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,9 @@ stan_cor_gen_comp <- function(cor, ncol) {
stan_has_built_in_fun <- function(family, bterms) {
stopifnot(all(c("family", "link") %in% names(family)))
stopifnot(is.brmsterms(bterms))
cens_or_trunc <- stan_log_lik_adj(bterms$adforms, c("cens", "trunc"))
link <- family[["link"]]
dpar <- family[["dpar"]]
if (cens_or_trunc) {
if (has_ad_terms(bterms, c("cens", "trunc"))) {
# only few families have special lcdf and lccdf functions
out <- has_built_in_fun(family, link, cdf = TRUE) ||
has_built_in_fun(bterms, link, dpar = dpar, cdf = TRUE)
Expand Down
Loading

0 comments on commit c0eb374

Please sign in to comment.