diff --git a/NAMESPACE b/NAMESPACE index f2539edac..5e4615112 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -275,8 +275,6 @@ S3method(restructure,brmsfit) S3method(rhat,brmsfit) S3method(shinystan::launch_shinystan,brmsfit) S3method(stan_log_lik,brmsterms) -S3method(stan_log_lik,family) -S3method(stan_log_lik,mixfamily) S3method(stan_log_lik,mvbrmsterms) S3method(stan_predictor,bframel) S3method(stan_predictor,bframenl) diff --git a/R/brmsterms.R b/R/brmsterms.R index 7addbcd84..e12e93ca4 100644 --- a/R/brmsterms.R +++ b/R/brmsterms.R @@ -64,13 +64,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE, y$cov_ranef <- x$cov_ranef class(y) <- "brmsterms" + y$resp <- "" if (check_response) { # extract response variables y$respform <- validate_resp_formula(formula, empty_ok = FALSE) if (mv) { y$resp <- terms_resp(y$respform) - } else { - y$resp <- "" } } @@ -97,6 +96,11 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE, x$pforms[[dp]] <- combine_formulas(formula, x$pforms[[dp]], dp) } x$pforms <- move2start(x$pforms, mu_dpars) + for (i in seq_along(family$mix)) { + # store the respective mixture index in each mixture component + # this enables them to be easily passed along, e.g. in stan_log_lik + y$family$mix[[i]]$mix <- i + } } else if (conv_cats_dpars(x$family)) { mu_dpars <- str_subset(x$family$dpars, "^mu") for (dp in mu_dpars) { @@ -109,7 +113,6 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE, } # predicted distributional parameters - resp <- ifelse(mv && !is.null(y$resp), y$resp, "") dpars <- intersect(names(x$pforms), valid_dpars(family)) dpar_forms <- x$pforms[dpars] nlpars <- setdiff(names(x$pforms), dpars) @@ -117,13 +120,13 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE, y$dpars <- named_list(dpars) for (dp in dpars) { if (get_nl(dpar_forms[[dp]])) { - y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, resp) + y$dpars[[dp]] <- terms_nlf(dpar_forms[[dp]], nlpars, y$resp) } else { y$dpars[[dp]] <- terms_lf(dpar_forms[[dp]]) } y$dpars[[dp]]$family <- dpar_family(family, dp) y$dpars[[dp]]$dpar <- dp - y$dpars[[dp]]$resp <- resp + y$dpars[[dp]]$resp <- y$resp if (dpar_class(dp) == "mu") { y$dpars[[dp]]$respform <- y$respform y$dpars[[dp]]$adforms <- y$adforms @@ -142,12 +145,12 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE, attr(nlpar_forms[[nlp]], "center") <- FALSE } if (get_nl(nlpar_forms[[nlp]])) { - y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, resp) + y$nlpars[[nlp]] <- terms_nlf(nlpar_forms[[nlp]], nlpars, y$resp) } else { y$nlpars[[nlp]] <- terms_lf(nlpar_forms[[nlp]]) } y$nlpars[[nlp]]$nlpar <- nlp - y$nlpars[[nlp]]$resp <- resp + y$nlpars[[nlp]]$resp <- y$resp check_cs(y$nlpars[[nlp]]) } used_nlpars <- ufrom_list(c(y$dpars, y$nlpars), "used_nlpars") @@ -592,6 +595,16 @@ is.btnl <- function(x) { inherits(x, "btnl") } +# figure out if a certain distributional parameter is predicted +is_pred_dpar <- function(bterms, dpar) { + stopifnot(is.brmsterms(bterms)) + if (!length(dpar)) { + return(FALSE) + } + mix <- get_mix_id(bterms) + any(paste0(dpar, mix) %in% names(bterms$dpars)) +} + # transform mvbrmsterms objects for use in stan_llh.brmsterms as.brmsterms <- function(x) { stopifnot(is.mvbrmsterms(x), x$rescor) @@ -599,13 +612,14 @@ as.brmsterms <- function(x) { stopifnot(all(families == families[1])) out <- structure(list(), class = "brmsterms") out$family <- structure( - list(family = paste0(families[1], "_mv"), link = "identity"), + list(family = families[1], link = "identity"), class = c("brmsfamily", "family") ) + out$family$fun <- paste0(out$family$family, "_mv") info <- get(paste0(".family_", families[1]))() out$family[names(info)] <- info out$sigma_pred <- any(ulapply(x$terms, - function(x) "sigma" %in% names(x$dpar) || is.formula(x$adforms$se) + function(x) is_pred_dpar(x, "sigma") || has_ad_terms(x, "se") )) weight_forms <- rmNULL(lapply(x$terms, function(x) x$adforms$weights)) if (length(weight_forms)) { diff --git a/R/conditional_effects.R b/R/conditional_effects.R index 15689154e..53acc3d6e 100644 --- a/R/conditional_effects.R +++ b/R/conditional_effects.R @@ -606,7 +606,8 @@ get_int_vars.mvbrmsterms <- function(x, ...) { #' @export get_int_vars.brmsterms <- function(x, ...) { - advars <- ulapply(rmNULL(x$adforms[c("trials", "thres", "vint")]), all_vars) + adterms <- c("trials", "thres", "vint") + advars <- ulapply(rmNULL(x$adforms[adterms]), all_vars) unique(c(advars, get_sp_vars(x, "mo"))) } diff --git a/R/data-response.R b/R/data-response.R index dfb90a4ac..958c12861 100644 --- a/R/data-response.R +++ b/R/data-response.R @@ -171,6 +171,7 @@ data_response.brmsframe <- function(x, data, check_response = TRUE, } # data for addition arguments of the response + # TODO: replace is.formula(x$adforms$term) pattern with has_ad_terms() if (has_trials(x$family) || is.formula(x$adforms$trials)) { if (!length(x$adforms$trials)) { stop2("Specifying 'trials' is required for this model.") diff --git a/R/families.R b/R/families.R index 0839aa7cd..07e870d9e 100644 --- a/R/families.R +++ b/R/families.R @@ -441,7 +441,7 @@ combine_family_info <- function(x, y, ...) { y <- as_one_character(y) unite <- c( "dpars", "type", "specials", "include", - "const", "cats", "ad", "normalized" + "const", "cats", "ad", "normalized", "mix" ) if (y %in% c("family", "link")) { x <- unlist(x) @@ -1785,6 +1785,11 @@ no_nu <- function(bterms) { isTRUE(bterms$rescor) && "student" %in% family_names(bterms) } +# get mixture index if specified +get_mix_id <- function(family) { + family_info(family, "mix") %||% "" +} + # does the family-link combination have a built-in Stan function? has_built_in_fun <- function(family, link = NULL, dpar = NULL, cdf = FALSE) { link <- link %||% family$link @@ -1802,19 +1807,18 @@ prepare_family <- function(x) { stopifnot(is.brmsformula(x) || is.brmsterms(x)) family <- x$family acframe <- frame_ac(x) + family$fun <- family[["fun"]] %||% family$family if (use_ac_cov_time(acframe) && has_natural_residuals(x)) { - family$fun <- paste0(family$family, "_time") + family$fun <- paste0(family$fun, "_time") } else if (has_ac_class(acframe, "sar")) { acframe_sar <- subset2(acframe, class = "sar") if (has_ac_subset(acframe_sar, type = "lag")) { - family$fun <- paste0(family$family, "_lagsar") + family$fun <- paste0(family$fun, "_lagsar") } else if (has_ac_subset(acframe_sar, type = "error")) { - family$fun <- paste0(family$family, "_errorsar") + family$fun <- paste0(family$fun, "_errorsar") } } else if (has_ac_class(acframe, "fcor")) { - family$fun <- paste0(family$family, "_fcor") - } else { - family$fun <- family$family + family$fun <- paste0(family$fun, "_fcor") } family } diff --git a/R/formula-ad.R b/R/formula-ad.R index c02c84b34..ea08f85f5 100644 --- a/R/formula-ad.R +++ b/R/formula-ad.R @@ -376,21 +376,25 @@ trunc_bounds <- function(bterms, data = NULL, incl_family = FALSE, out } -# check if addition argument 'subset' ist used in the model +# check if addition argument 'subset' is used in the model +# works for both univariate and multivariate models has_subset <- function(bterms) { - .has_subset <- function(x) { - is.formula(x$adforms$subset) - } if (is.brmsterms(bterms)) { - out <- .has_subset(bterms) + out <- has_ad_terms(bterms, "subset") } else if (is.mvbrmsterms(bterms)) { - out <- any(ulapply(bterms$terms, .has_subset)) + out <- any(ulapply(bterms$terms, has_ad_terms, "subset")) } else { out <- FALSE } out } +# check if a model has certain addition terms +has_ad_terms <- function(bterms, terms) { + stopifnot(is.brmsterms(bterms), is.character(terms)) + any(ulapply(bterms$adforms[terms], is.formula)) +} + # construct a list of indices for cross-formula referencing frame_index <- function(x, data) { out <- .frame_index(x, data) diff --git a/R/stan-helpers.R b/R/stan-helpers.R index 68af271c5..29e887d91 100644 --- a/R/stan-helpers.R +++ b/R/stan-helpers.R @@ -93,10 +93,9 @@ stan_cor_gen_comp <- function(cor, ncol) { stan_has_built_in_fun <- function(family, bterms) { stopifnot(all(c("family", "link") %in% names(family))) stopifnot(is.brmsterms(bterms)) - cens_or_trunc <- stan_log_lik_adj(bterms$adforms, c("cens", "trunc")) link <- family[["link"]] dpar <- family[["dpar"]] - if (cens_or_trunc) { + if (has_ad_terms(bterms, c("cens", "trunc"))) { # only few families have special lcdf and lccdf functions out <- has_built_in_fun(family, link, cdf = TRUE) || has_built_in_fun(bterms, link, dpar = dpar, cdf = TRUE) diff --git a/R/stan-likelihood.R b/R/stan-likelihood.R index 7aa932d29..5a8f14df7 100644 --- a/R/stan-likelihood.R +++ b/R/stan-likelihood.R @@ -1,28 +1,42 @@ # unless otherwise specified, functions return a single character # string defining the likelihood of the model in Stan language +# Stan code for the log likelihood stan_log_lik <- function(x, ...) { UseMethod("stan_log_lik") } -# Stan code for the model likelihood -# @param bterms object of class brmsterms -# @param mix optional mixture component ID -# @param ptheta are mixing proportions predicted? #' @export -stan_log_lik.family <- function(x, bterms, threads, normalize, - mix = "", ptheta = FALSE, ...) { - # TODO: find a better way to pass resp and mix through stan_log_lik_* functions +stan_log_lik.brmsterms <- function(x, ...) { + if (is.mixfamily(x$family)) { + out <- stan_log_lik_mixfamily(x, ...) + } else { + out <- stan_log_lik_family(x, ...) + } + out +} + +#' @export +stan_log_lik.mvbrmsterms <- function(x, ...) { + if (x$rescor) { + out <- stan_log_lik(as.brmsterms(x), ...) + } else { + out <- ulapply(x$terms, stan_log_lik, ...) + } + out +} + +# Stan code for the log likelihood of a regular family +stan_log_lik_family <- function(bterms, threads, ...) { stopifnot(is.brmsterms(bterms)) - stopifnot(length(mix) == 1L) - bterms$family <- x - resp <- usc(combine_prefix(bterms)) # prepare family part of the likelihood - log_lik_args <- nlist(bterms, resp, mix, threads) - log_lik_fun <- paste0("stan_log_lik_", prepare_family(bterms)$fun) + log_lik_args <- nlist(bterms, threads, ...) + log_lik_fun <- prepare_family(bterms)$fun + log_lik_fun <- paste0("stan_log_lik_", log_lik_fun) ll <- do_call(log_lik_fun, log_lik_args) # incorporate other parts into the likelihood - args <- nlist(ll, bterms, resp, threads, normalize, mix, ptheta) + args <- nlist(ll, bterms, threads, ...) + mix <- get_mix_id(bterms) if (nzchar(mix)) { out <- do_call(stan_log_lik_mix, args) } else if (is.formula(bterms$adforms$cens)) { @@ -34,6 +48,7 @@ stan_log_lik.family <- function(x, bterms, threads, normalize, } if (grepl(stan_nn_regex(), out) && !nzchar(mix)) { # loop over likelihood if it cannot be vectorized + resp <- usc(bterms$resp) out <- paste0( " for (n in 1:N", resp, ") {\n", stan_nn_def(threads), @@ -44,25 +59,25 @@ stan_log_lik.family <- function(x, bterms, threads, normalize, out } -#' @export -stan_log_lik.mixfamily <- function(x, bterms, threads, ...) { +# Stan code for the log likelihood of a mixture family +stan_log_lik_mixfamily <- function(bterms, threads, ...) { + stopifnot(is.brmsterms(bterms), is.mixfamily(bterms$family)) dp_ids <- dpar_id(names(bterms$dpars)) fdp_ids <- dpar_id(names(bterms$fdpars)) - resp <- usc(bterms$resp) - ptheta <- any(dpar_class(names(bterms$dpars)) %in% "theta") - ll <- rep(NA, length(x$mix)) - for (i in seq_along(x$mix)) { + pred_mix_prob <- any(dpar_class(names(bterms$dpars)) %in% "theta") + ll <- rep(NA, length(bterms$family$mix)) + for (i in seq_along(ll)) { sbterms <- bterms + sbterms$family <- sbterms$family$mix[[i]] sbterms$dpars <- sbterms$dpars[dp_ids == i] sbterms$fdpars <- sbterms$fdpars[fdp_ids == i] - ll[i] <- stan_log_lik( - x$mix[[i]], sbterms, mix = i, ptheta = ptheta, - threads = threads, ... + ll[i] <- stan_log_lik_family( + sbterms, pred_mix_prob = pred_mix_prob, threads = threads, ... ) } - resp <- usc(combine_prefix(bterms)) + resp <- usc(bterms$resp) n <- stan_nn(threads) - has_weights <- is.formula(bterms$adforms$weights) + has_weights <- has_ad_terms(bterms, "weights") weights <- str_if(has_weights, glue("weights{resp}{n} * ")) out <- glue( " // likelihood of the mixture model\n", @@ -78,41 +93,28 @@ stan_log_lik.mixfamily <- function(x, bterms, threads, ...) { out } -#' @export -stan_log_lik.brmsterms <- function(x, ...) { - stan_log_lik(x$family, bterms = x, ...) -} - -#' @export -stan_log_lik.mvbrmsterms <- function(x, ...) { - if (x$rescor) { - out <- stan_log_lik(as.brmsterms(x), ...) - } else { - out <- ulapply(x$terms, stan_log_lik, ...) - } - out -} - # default likelihood in Stan language -stan_log_lik_general <- function(ll, bterms, threads, normalize, resp = "", ...) { +stan_log_lik_general <- function(ll, bterms, threads, normalize, ...) { stopifnot(is.sdist(ll)) require_n <- grepl(stan_nn_regex(), ll$args) n <- str_if(require_n, stan_nn(threads), stan_slice(threads)) lpdf <- stan_log_lik_lpdf_name(bterms, normalize, dist = ll$dist) Y <- stan_log_lik_Y_name(bterms) - tr <- stan_log_lik_trunc(ll, bterms, resp = resp, threads = threads) + resp <- usc(bterms$resp) + tr <- stan_log_lik_trunc(ll, bterms, threads = threads, ...) glue("{tp()}{ll$dist}_{lpdf}({Y}{resp}{n}{ll$shift} | {ll$args}){tr};\n") } # censored likelihood in Stan language -stan_log_lik_cens <- function(ll, bterms, threads, normalize, resp = "", ...) { +stan_log_lik_cens <- function(ll, bterms, threads, normalize, ...) { stopifnot(is.sdist(ll)) cens <- eval_rhs(bterms$adforms$cens) lpdf <- stan_log_lik_lpdf_name(bterms, normalize, dist = ll$dist) Y <- stan_log_lik_Y_name(bterms) + resp <- usc(bterms$resp) tp <- tp() - has_weights <- is.formula(bterms$adforms$weights) - has_trunc <- is.formula(bterms$adforms$trunc) + has_weights <- has_ad_terms(bterms, "weights") + has_trunc <- has_ad_terms(bterms, "trunc") has_interval_cens <- cens$vars$y2 != "NA" if (ll$vec && !(has_weights || has_trunc)) { # vectorized log-likelihood contributions @@ -148,7 +150,7 @@ stan_log_lik_cens <- function(ll, bterms, threads, normalize, resp = "", ...) { # non-vectorized likelihood contributions n <- stan_nn(threads) w <- str_if(has_weights, glue("weights{resp}{n} * ")) - tr <- stan_log_lik_trunc(ll, bterms, resp = resp, threads = threads) + tr <- stan_log_lik_trunc(ll, bterms, threads = threads) out <- glue( " // special treatment of censored data\n", " if (cens{resp}{n} == 0) {{\n", @@ -173,11 +175,12 @@ stan_log_lik_cens <- function(ll, bterms, threads, normalize, resp = "", ...) { } # weighted likelihood in Stan language -stan_log_lik_weights <- function(ll, bterms, threads, normalize, resp = "", ...) { +stan_log_lik_weights <- function(ll, bterms, threads, normalize, ...) { stopifnot(is.sdist(ll)) - tr <- stan_log_lik_trunc(ll, bterms, resp = resp, threads = threads) + tr <- stan_log_lik_trunc(ll, bterms, threads = threads) lpdf <- stan_log_lik_lpdf_name(bterms, normalize, dist = ll$dist) Y <- stan_log_lik_Y_name(bterms) + resp <- usc(bterms$resp) n <- stan_nn(threads) glue( "{tp()}weights{resp}{n} * ({ll$dist}_{lpdf}", @@ -186,14 +189,17 @@ stan_log_lik_weights <- function(ll, bterms, threads, normalize, resp = "", ...) } # likelihood of a single mixture component -stan_log_lik_mix <- function(ll, bterms, mix, ptheta, threads, - normalize, resp = "", ...) { +# @param pred_mix_prob are mixing proportions predicted? +stan_log_lik_mix <- function(ll, bterms, pred_mix_prob, threads, + normalize, ...) { stopifnot(is.sdist(ll)) - theta <- str_if(ptheta, + resp <- usc(bterms$resp) + mix <- get_mix_id(bterms) + theta <- str_if(pred_mix_prob, glue("theta{mix}{resp}[n]"), glue("log(theta{mix}{resp})") ) - tr <- stan_log_lik_trunc(ll, bterms, resp = resp, threads = threads) + tr <- stan_log_lik_trunc(ll, bterms, threads = threads) lpdf <- stan_log_lik_lpdf_name(bterms, normalize, dist = ll$dist) Y <- stan_log_lik_Y_name(bterms) n <- stan_nn(threads) @@ -235,13 +241,13 @@ stan_log_lik_mix <- function(ll, bterms, mix, ptheta, threads, # truncated part of the likelihood # @param short use the T[, ] syntax? -stan_log_lik_trunc <- function(ll, bterms, threads, resp = "", - short = FALSE) { +stan_log_lik_trunc <- function(ll, bterms, threads, short = FALSE, ...) { stopifnot(is.sdist(ll)) bounds <- bterms$frame$resp$bounds if (!any(bounds$lb > -Inf | bounds$ub < Inf)) { return("") } + resp <- usc(bterms$resp) n <- stan_nn(threads) m1 <- str_if(use_int(bterms), " - 1") lb <- str_if(any(bounds$lb > -Inf), glue("lb{resp}{n}{m1}")) @@ -286,13 +292,17 @@ stan_log_lik_Y_name <- function(bterms) { ifelse(is.formula(bterms$adforms$mi), "Yl", "Y") } -# prepare names of distributional parameters +# prepare Stan code for distributional parameters # @param reqn will the likelihood be wrapped in a loop over n? # @param dpars optional names of distributional parameters to be prepared # if not specified will prepare all distributional parameters -stan_log_lik_dpars <- function(bterms, reqn = NULL, resp = "", mix = "", - dpars = NULL, type = NULL) { - reqn <- reqn %||% stan_log_lik_adj(bterms, mix = mix) +# @param type optional type of distribution parameters to be extract +# see valid_dpars() for details +# @return a named list with elements containing the Stan code per parameter +stan_log_lik_dpars <- function(bterms, reqn = stan_log_lik_adj(bterms), + dpars = NULL, type = NULL, ...) { + resp <- usc(bterms$resp) + mix <- get_mix_id(bterms) if (is.null(dpars)) { dpars <- paste0(valid_dpars(bterms, type = type), mix) } @@ -305,12 +315,21 @@ stan_log_lik_dpars <- function(bterms, reqn = NULL, resp = "", mix = "", named_list(dpars, out) } +# stan code for log likelihood variables originating from addition terms +stan_log_lik_advars <- function(bterms, advars, + reqn = stan_log_lik_adj(bterms), + threads = NULL, ...) { + slice <- str_if(reqn, stan_nn(threads), stan_slice(threads)) + out <- paste0(advars, usc(bterms$resp), slice) + named_list(advars, out) +} + # adjust lpdf name if a more efficient version is available # for a specific link. For instance 'poisson_log' stan_log_lik_simple_lpdf <- function(lpdf, link, bterms, sep = "_") { stopifnot(is.brmsterms(bterms)) - cens_or_trunc <- stan_log_lik_adj(bterms, c("cens", "trunc")) - if (bterms$family$link == link && !cens_or_trunc) { + has_cens_or_trunc <- has_ad_terms(bterms, c("cens", "trunc")) + if (bterms$family$link == link && !has_cens_or_trunc) { lpdf <- paste0(lpdf, sep, link) } lpdf @@ -318,22 +337,22 @@ stan_log_lik_simple_lpdf <- function(lpdf, link, bterms, sep = "_") { # prepare _logit suffix for distributional parameters # used in zero-inflated and hurdle models -stan_log_lik_dpar_usc_logit <- function(dpar, bterms) { - stopifnot(dpar %in% c("zi", "hu")) +stan_log_lik_dpar_usc_logit <- function(bterms, dpar) { stopifnot(is.brmsterms(bterms)) - cens_or_trunc <- stan_log_lik_adj(bterms, c("cens", "trunc")) + stopifnot(dpar %in% c("zi", "hu")) + has_cens_or_trunc <- has_ad_terms(bterms, c("cens", "trunc")) usc_logit <- isTRUE(bterms$dpars[[dpar]]$family$link == "logit") - str_if(usc_logit && !cens_or_trunc, "_logit") + str_if(usc_logit && !has_cens_or_trunc, "_logit") } # add 'se' to 'sigma' within the Stan likelihood -stan_log_lik_add_se <- function(sigma, bterms, reqn = NULL, resp = "", - threads = NULL) { - if (!is.formula(bterms$adforms$se)) { +stan_log_lik_add_se <- function(sigma, bterms, reqn = stan_log_lik_adj(bterms), + threads = NULL, ...) { + if (!has_ad_terms(bterms, "se")) { return(sigma) } - reqn <- reqn %||% stan_log_lik_adj(bterms) nse <- str_if(reqn, stan_nn(threads), stan_slice(threads)) + resp <- usc(bterms$resp) if (no_sigma(bterms)) { sigma <- glue("se{resp}{nse}") } else { @@ -344,15 +363,15 @@ stan_log_lik_add_se <- function(sigma, bterms, reqn = NULL, resp = "", # multiply 'dpar' by the 'rate' denominator within the Stan likelihood # @param log add the rate denominator on the log scale if sensible? -# @param reqn2 like reqn indicates if a loop over observations is needed -# which makes all the computations scalar (non-vectorized). However, -# censoring may turn non-vectorized into vectorized statements later on -# (see stan_log_lik_cens) which then makes the * operator invalid and -# requires .* instead. Accordingly, reqn2 should be FALSE if [n] is required -# only because of censoring. -stan_log_lik_multiply_rate_denom <- function(dpar, bterms, reqn, resp = "", - log = FALSE, transform = NULL, - threads = NULL, reqn2 = reqn) { +# @param req_dot_multiply Censoring may turn non-vectorized into vectorized +# statements later on (see stan_log_lik_cens) which then makes the * operator +# invalid and requires .* instead. Accordingly, req_dot_multiply should be +# FALSE if [n] is required only because of censoring. +stan_log_lik_multiply_rate_denom <- function( + dpar, bterms, reqn = stan_log_lik_adj(bterms), + req_dot_multiply = stan_log_lik_adj(bterms, c("trunc", "weights")), + log = FALSE, transform = NULL, threads = NULL, ...) { + dpar_transform <- dpar if (!is.null(transform)) { dpar_transform <- glue("{transform}({dpar})") @@ -360,67 +379,61 @@ stan_log_lik_multiply_rate_denom <- function(dpar, bterms, reqn, resp = "", if (!is.formula(bterms$adforms$rate)) { return(dpar_transform) } + resp <- usc(bterms$resp) ndenom <- str_if(reqn, stan_nn(threads), stan_slice(threads)) denom <- glue("denom{resp}{ndenom}") - cens_or_trunc <- stan_log_lik_adj(bterms, c("cens", "trunc")) - if (log && bterms$family$link == "log" && !cens_or_trunc) { + has_cens_or_trunc <- has_ad_terms(bterms, c("cens", "trunc")) + if (log && bterms$family$link == "log" && !has_cens_or_trunc) { denom <- glue("log_{denom}") operator <- "+" } else { # dpar without resp name or index dpar_clean <- sub("(_|\\[).*", "", dpar) is_pred <- dpar_clean %in% c("mu", names(bterms$dpars)) - operator <- str_if(reqn2 || !is_pred, "*", ".*") + operator <- str_if(req_dot_multiply || !is_pred, "*", ".*") } glue("{dpar_transform} {operator} {denom}") } # check if the log-likelihood needs to be adjusted to a non-vectorized form # either because of addition terms or mixture modeling -# @param x named list of formulas or brmsterms object -# @param adds vector of addition argument names -# @param mix optional mixture component index +# @param terms vector of addition term names # @return a single logical value -stan_log_lik_adj <- function(x, adds = c("weights", "cens", "trunc"), - mix = "") { - adds <- match.arg(adds, several.ok = TRUE) - mix <- as_one_character(mix) - if (is.brmsterms(x)) { - x <- x$adforms - } - any(ulapply(x[adds], is.formula)) || nzchar(mix) +stan_log_lik_adj <- function(bterms, terms = c("weights", "cens", "trunc")) { + stopifnot(is.brmsterms(bterms)) + terms <- match.arg(terms, several.ok = TRUE) + mix <- get_mix_id(bterms) + has_ad_terms(bterms, terms) || any(nzchar(mix)) } # one function per family -stan_log_lik_gaussian <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_gaussian <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) - p$sigma <- paste0("sigma", resp) + p <- args_glm_primitive(bterms$dpars$mu, ...) + p$sigma <- paste0("sigma", usc(bterms$resp)) out <- sdist("normal_id_glm", p$x, p$alpha, p$beta, p$sigma) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn, resp, threads) + p <- stan_log_lik_dpars(bterms) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, ...) out <- sdist("normal", p$mu, p$sigma) } out } -stan_log_lik_gaussian_mv <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) || bterms$sigma_pred - p <- list(Mu = paste0("Mu", if (reqn) "[n]")) - p$LSigma <- paste0("LSigma", if (bterms$sigma_pred) "[n]") +stan_log_lik_gaussian_mv <- function(bterms,...) { + reqn <- stan_log_lik_adj(bterms) || bterms$sigma_pred + p <- list(Mu = paste0("Mu", str_if(reqn, "[n]"))) + p$LSigma <- paste0("LSigma", str_if(bterms$sigma_pred, "[n]")) sdist("multi_normal_cholesky", p$Mu, p$LSigma) } -stan_log_lik_gaussian_time <- function(bterms, resp = "", mix = "", ...) { +stan_log_lik_gaussian_time <- function(bterms, ...) { if (stan_log_lik_adj(bterms)) { stop2("Invalid addition arguments for this model.") } has_se <- is.formula(bterms$adforms$se) flex <- has_ac_class(bterms$frame$ac, "unstr") - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) + p <- stan_log_lik_dpars(bterms, reqn = FALSE) v <- c("Lcortime", "nobs_tg", "begin_tg", "end_tg") if (has_se) { c(v) <- "se2" @@ -428,7 +441,7 @@ stan_log_lik_gaussian_time <- function(bterms, resp = "", mix = "", ...) { if (flex) { c(v) <- "Jtime_tg" } - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sfx <- str_if("sigma" %in% names(bterms$dpars), "het", "hom") sfx <- str_if(has_se, paste0(sfx, "_se"), sfx) sfx <- str_if(flex, paste0(sfx, "_flex"), sfx) @@ -438,58 +451,53 @@ stan_log_lik_gaussian_time <- function(bterms, resp = "", mix = "", ...) { ) } -stan_log_lik_gaussian_fcor <- function(bterms, resp = "", mix = "", ...) { - has_se <- is.formula(bterms$adforms$se) - if (stan_log_lik_adj(bterms) || has_se) { +stan_log_lik_gaussian_fcor <- function(bterms, ...) { + if (stan_log_lik_adj(bterms) || has_ad_terms(bterms, "se")) { stop2("Invalid addition arguments for this model.") } - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$Lfcor <- paste0("Lfcor", resp) + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$Lfcor <- paste0("Lfcor", usc(bterms$resp)) sfx <- str_if("sigma" %in% names(bterms$dpars), "het", "hom") sdist(glue("normal_fcor_{sfx}"), p$mu, p$sigma, p$Lfcor) } -stan_log_lik_gaussian_lagsar <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, FALSE, resp, threads) +stan_log_lik_gaussian_lagsar <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn = FALSE, ...) v <- c("lagsar", "Msar", "eigenMsar") - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sdist("normal_lagsar", p$mu, p$sigma, p$lagsar, p$Msar, p$eigenMsar) } -stan_log_lik_gaussian_errorsar <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, FALSE, resp, threads) +stan_log_lik_gaussian_errorsar <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn = FALSE, ...) v <- c("errorsar", "Msar", "eigenMsar") - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sdist("normal_errorsar", p$mu, p$sigma, p$errorsar, p$Msar, p$eigenMsar) } -stan_log_lik_student <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn, resp, threads) +stan_log_lik_student <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, ...) sdist("student_t", p$nu, p$mu, p$sigma) } -stan_log_lik_student_mv <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) || bterms$sigma_pred - p <- stan_log_lik_dpars(bterms, reqn, resp, mix, dpars = "nu") - p$Mu <- paste0("Mu", if (reqn) "[n]") - p$Sigma <- paste0("Sigma", if (bterms$sigma_pred) "[n]") +stan_log_lik_student_mv <- function(bterms, ...) { + reqn <- stan_log_lik_adj(bterms) || bterms$sigma_pred + p <- stan_log_lik_dpars(bterms, reqn = reqn, dpars = "nu") + p$Mu <- paste0("Mu", str_if(reqn, "[n]")) + p$Sigma <- paste0("Sigma", str_if(bterms$sigma_pred, "[n]")) sdist("multi_student_t", p$nu, p$Mu, p$Sigma) } -stan_log_lik_student_time <- function(bterms, resp = "", mix = "", ...) { +stan_log_lik_student_time <- function(bterms, ...) { if (stan_log_lik_adj(bterms)) { stop2("Invalid addition arguments for this model.") } has_se <- is.formula(bterms$adforms$se) flex <- has_ac_class(bterms$frame$ac, "unstr") - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) + p <- stan_log_lik_dpars(bterms, reqn = FALSE) v <- c("Lcortime", "nobs_tg", "begin_tg", "end_tg") if (has_se) { c(v) <- "se2" @@ -497,7 +505,7 @@ stan_log_lik_student_time <- function(bterms, resp = "", mix = "", ...) { if (flex) { c(v) <- "Jtime_tg" } - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sfx <- str_if("sigma" %in% names(bterms$dpars), "het", "hom") sfx <- str_if(has_se, paste0(sfx, "_se"), sfx) sfx <- str_if(flex, paste0(sfx, "_flex"), sfx) @@ -507,127 +515,99 @@ stan_log_lik_student_time <- function(bterms, resp = "", mix = "", ...) { ) } -stan_log_lik_student_fcor <- function(bterms, resp = "", mix = "", ...) { - has_se <- is.formula(bterms$adforms$se) - if (stan_log_lik_adj(bterms) || has_se) { +stan_log_lik_student_fcor <- function(bterms, ...) { + if (stan_log_lik_adj(bterms) || has_ad_terms(bterms, "se")) { stop2("Invalid addition arguments for this model.") } - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$Lfcor <- paste0("Lfcor", resp) + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$Lfcor <- paste0("Lfcor", usc(bterms$resp)) sfx <- str_if("sigma" %in% names(bterms$dpars), "het", "hom") sdist(glue("student_t_fcor_{sfx}"), p$nu, p$mu, p$sigma, p$Lfcor) } -stan_log_lik_student_lagsar <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, FALSE, resp, threads) +stan_log_lik_student_lagsar <- function(bterms,...) { + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn = FALSE, ...) v <- c("lagsar", "Msar", "eigenMsar") - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sdist("student_t_lagsar", p$nu, p$mu, p$sigma, p$lagsar, p$Msar, p$eigenMsar) } -stan_log_lik_student_errorsar <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, FALSE, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, FALSE, resp, threads) +stan_log_lik_student_errorsar <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = FALSE) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn = FALSE, ...) v <- c("errorsar", "Msar", "eigenMsar") - p[v] <- as.list(paste0(v, resp)) + p[v] <- as.list(paste0(v, usc(bterms$resp))) sdist("student_t_errorsar", p$nu, p$mu, p$sigma, p$errorsar, p$Msar, p$eigenMsar) } -stan_log_lik_lognormal <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_lognormal <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) sdist("lognormal", p$mu, p$sigma) } -stan_log_lik_shifted_lognormal <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_shifted_lognormal <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) sdist("lognormal", p$mu, p$sigma, shift = paste0(" - ", p$ndt)) } -stan_log_lik_asym_laplace <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_asym_laplace <- function(bterms,...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) sdist("asym_laplace", p$mu, p$sigma, p$quantile, vec = FALSE) } -stan_log_lik_skew_normal <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$sigma <- stan_log_lik_add_se(p$sigma, bterms, reqn, resp, threads) +stan_log_lik_skew_normal <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) + p$sigma <- stan_log_lik_add_se(p$sigma, bterms, ...) # required because of CP parameterization of mu and sigma - nomega <- any(grepl(stan_nn_regex(), c(p$sigma, p$alpha))) - nomega <- str_if(reqn && nomega, "[n]") - p$omega <- paste0("omega", mix, resp, nomega) + mix <- get_mix_id(bterms) + resp <- usc(bterms$resp) + reqn <- any(grepl(stan_nn_regex(), c(p$sigma, p$alpha))) + p$omega <- paste0("omega", mix, resp, str_if(reqn, "[n]")) sdist("skew_normal", p$mu, p$omega, p$alpha) } -stan_log_lik_poisson <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_poisson <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) + p <- args_glm_primitive(bterms$dpars$mu, ...) out <- sdist("poisson_log_glm", p$x, p$alpha, p$beta) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - reqn2 <- stan_log_lik_adj(bterms, c("trunc", "weights"), mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$mu <- stan_log_lik_multiply_rate_denom( - p$mu, bterms, reqn, resp, log = TRUE, - reqn2 = reqn2, threads = threads - ) + p <- stan_log_lik_dpars(bterms) + p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...) lpdf <- stan_log_lik_simple_lpdf("poisson", "log", bterms) out <- sdist(lpdf, p$mu) } out } -stan_log_lik_negbinomial <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_negbinomial <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) - p$shape <- paste0("shape", resp) + p <- args_glm_primitive(bterms$dpars$mu, ...) + p$shape <- paste0("shape", usc(bterms$resp)) out <- sdist("neg_binomial_2_log_glm", p$x, p$alpha, p$beta, p$shape) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - reqn2 <- stan_log_lik_adj(bterms, c("trunc", "weights"), mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$mu <- stan_log_lik_multiply_rate_denom( - p$mu, bterms, reqn, resp, log = TRUE, - reqn2 = reqn2, threads = threads - ) - p$shape <- stan_log_lik_multiply_rate_denom( - p$shape, bterms, reqn, resp, - reqn2 = reqn2, threads = threads - ) + p <- stan_log_lik_dpars(bterms) + p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...) + p$shape <- stan_log_lik_multiply_rate_denom(p$shape, bterms, ...) lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms) out <- sdist(lpdf, p$mu, p$shape) } out } -stan_log_lik_negbinomial2 <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_negbinomial2 <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) - p$sigma <- paste0("sigma", resp) + p <- args_glm_primitive(bterms$dpars$mu, ...) + p$sigma <- paste0("sigma", usc(bterms$resp)) p$shape <- paste0("inv(", p$sigma, ")") out <- sdist("neg_binomial_2_log_glm", p$x, p$alpha, p$beta, p$shape) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - reqn2 <- stan_log_lik_adj(bterms, c("trunc", "weights"), mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - p$mu <- stan_log_lik_multiply_rate_denom( - p$mu, bterms, reqn, resp, log = TRUE, - reqn2 = reqn2, threads = threads - ) + p <- stan_log_lik_dpars(bterms) + p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...) p$shape <- stan_log_lik_multiply_rate_denom( - p$sigma, bterms, reqn, resp, transform = "inv", - reqn2 = reqn2, threads = threads + p$sigma, bterms, transform = "inv", ... ) lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms) out <- sdist(lpdf, p$mu, p$shape) @@ -635,164 +615,138 @@ stan_log_lik_negbinomial2 <- function(bterms, resp = "", mix = "", threads = NUL out } -stan_log_lik_geometric <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_geometric <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) + p <- args_glm_primitive(bterms$dpars$mu, ...) p$shape <- "1" out <- sdist("neg_binomial_2_log_glm", p$x, p$alpha, p$beta, p$shape) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - reqn2 <- stan_log_lik_adj(bterms, c("trunc", "weights"), mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) + p <- stan_log_lik_dpars(bterms) p$shape <- "1" - p$mu <- stan_log_lik_multiply_rate_denom( - p$mu, bterms, reqn, resp, log = TRUE, - reqn2 = reqn2, threads = threads - ) - p$shape <- stan_log_lik_multiply_rate_denom( - p$shape, bterms, reqn, resp, - reqn2 = reqn2, threads = threads - ) + p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...) + p$shape <- stan_log_lik_multiply_rate_denom(p$shape, bterms, ...) lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms) out <- sdist(lpdf, p$mu, p$shape) } } -stan_log_lik_binomial <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - slice <- str_if(reqn, stan_nn(threads), stan_slice(threads)) - p$trials <- paste0("trials", resp, slice) +stan_log_lik_binomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) + p$trials <- stan_log_lik_advars(bterms, "trials", ...)$trials lpdf <- stan_log_lik_simple_lpdf("binomial", "logit", bterms) sdist(lpdf, p$trials, p$mu) } -stan_log_lik_beta_binomial <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - n <- stan_nn(threads) +stan_log_lik_beta_binomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + p$trials <- stan_log_lik_advars(bterms, "trials", ...)$trials sdist( "beta_binomial", - paste0("trials", resp, n), + p$trials, paste0(p$mu, " * ", p$phi), paste0("(1 - ", p$mu, ") * ", p$phi), vec = FALSE ) } -stan_log_lik_bernoulli <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { +stan_log_lik_bernoulli <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) + p <- args_glm_primitive(bterms$dpars$mu, ...) out <- sdist("bernoulli_logit_glm", p$x, p$alpha, p$beta) } else { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) + p <- stan_log_lik_dpars(bterms) lpdf <- stan_log_lik_simple_lpdf("bernoulli", "logit", bterms) out <- sdist(lpdf, p$mu) } out } -stan_log_lik_discrete_weibull <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_discrete_weibull <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) sdist("discrete_weibull", p$mu, p$shape, vec = FALSE) } -stan_log_lik_com_poisson <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_com_poisson <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) lpdf <- stan_log_lik_simple_lpdf("com_poisson", "log", bterms) sdist(lpdf, p$mu, p$shape, vec = FALSE) } -stan_log_lik_gamma <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) || - paste0("shape", mix) %in% names(bterms$dpars) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_gamma <- function(bterms, ...) { + reqn <- stan_log_lik_adj(bterms) || is_pred_dpar(bterms, "shape") + p <- stan_log_lik_dpars(bterms, reqn = reqn) # Stan uses shape-rate parameterization with rate = shape / mean div_op <- str_if(reqn, " / ", " ./ ") sdist("gamma", p$shape, paste0(p$shape, div_op, p$mu)) } -stan_log_lik_exponential <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_exponential <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) # Stan uses rate parameterization with rate = 1 / mean sdist("exponential", paste0("inv(", p$mu, ")")) } -stan_log_lik_weibull <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_weibull <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) # Stan uses shape-scale parameterization for weibull - need_dot_div <- !reqn && paste0("shape", mix) %in% names(bterms$dpars) + need_dot_div <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "shape") div_op <- str_if(need_dot_div, " ./ ", " / ") p$scale <- paste0(p$mu, div_op, "tgamma(1 + 1", div_op, p$shape, ")") sdist("weibull", p$shape, p$scale) } -stan_log_lik_frechet <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_frechet <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) # Stan uses shape-scale parameterization for frechet - need_dot_div <- !reqn && paste0("nu", mix) %in% names(bterms$dpars) + need_dot_div <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "nu") div_op <- str_if(need_dot_div, " ./ ", " / ") p$scale <- paste0(p$mu, div_op, "tgamma(1 - 1", div_op, p$nu, ")") sdist("frechet", p$nu, p$scale) } -stan_log_lik_gen_extreme_value <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_gen_extreme_value <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) sdist("gen_extreme_value", p$mu, p$sigma, p$xi, vec = FALSE) } -stan_log_lik_exgaussian <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_exgaussian <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) sdist( "exp_mod_normal", paste0(p$mu, " - ", p$beta), p$sigma, paste0("inv(", p$beta, ")") ) } -stan_log_lik_inverse.gaussian <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) || - glue("shape{mix}") %in% names(bterms$dpars) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) - n <- str_if(reqn, "[n]") +stan_log_lik_inverse.gaussian <- function(bterms, ...) { + reqn <- stan_log_lik_adj(bterms) || is_pred_dpar(bterms, "shape") + p <- stan_log_lik_dpars(bterms, reqn = reqn) sdist("inv_gaussian", p$mu, p$shape, vec = FALSE) } -stan_log_lik_wiener <- function(bterms, resp = "", mix = "", threads = NULL, - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - n <- stan_nn(threads) - p$dec <- paste0("dec", resp, n) +stan_log_lik_wiener <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + p$dec <- stan_log_lik_advars(bterms, "dec", reqn = TRUE, ...)$dec sdist("wiener_diffusion", p$dec, p$bs, p$ndt, p$bias, p$mu, vec = FALSE) } -stan_log_lik_beta <- function(bterms, resp = "", mix = "", ...) { - # TODO: check if we still require n when phi is predicted - # and check the same for other families too - reqn <- stan_log_lik_adj(bterms, mix = mix) || - paste0("phi", mix) %in% names(bterms$dpars) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_beta <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) + req_dot_multiply <- !stan_log_lik_adj(bterms) && is_pred_dpar(bterms, "phi") + multiply <- str_if(req_dot_multiply, " .* ", " * ") sdist("beta", - paste0(p$mu, " * ", p$phi), - paste0("(1 - ", p$mu, ") * ", p$phi) + paste0(p$mu, multiply, p$phi), + paste0("(1 - ", p$mu, ")", multiply, p$phi) ) } -stan_log_lik_von_mises <- function(bterms, resp = "", mix = "", ...) { - reqn <- stan_log_lik_adj(bterms, mix = mix) - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) +stan_log_lik_von_mises <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms) sdist("von_mises", p$mu, p$kappa) } -stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_cox <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + resp <- usc(bterms$resp) p$bhaz <- paste0("bhaz", resp, "[n]") p$cbhaz <- paste0("cbhaz", resp, "[n]") lpdf <- "cox" @@ -802,74 +756,66 @@ stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL, ...) { sdist(lpdf, p$mu, p$bhaz, p$cbhaz, vec = FALSE) } -stan_log_lik_cumulative <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { +stan_log_lik_cumulative <- function(bterms, ...) { if (use_glm_primitive(bterms)) { - p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads) + p <- args_glm_primitive(bterms$dpars$mu, ...) out <- sdist("ordered_logistic_glm", p$x, p$beta, p$alpha) } else { - out <- stan_log_lik_ordinal(bterms, resp, mix, threads, ...) + out <- stan_log_lik_ordinal(bterms, ...) } out } -stan_log_lik_sratio <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - stan_log_lik_ordinal(bterms, resp, mix, threads, ...) +stan_log_lik_sratio <- function(bterms, ...) { + stan_log_lik_ordinal(bterms, ...) } -stan_log_lik_cratio <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - stan_log_lik_ordinal(bterms, resp, mix, threads, ...) +stan_log_lik_cratio <- function(bterms, ...) { + stan_log_lik_ordinal(bterms, ...) } -stan_log_lik_acat <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - stan_log_lik_ordinal(bterms, resp, mix, threads, ...) +stan_log_lik_acat <- function(bterms, ...) { + stan_log_lik_ordinal(bterms, ...) } -stan_log_lik_categorical <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { +stan_log_lik_categorical <- function(bterms, ...) { stopifnot(bterms$family$link == "logit") - stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed if (use_glm_primitive_categorical(bterms)) { bterms1 <- bterms$dpars[[1]] bterms1$family <- bterms$family - p <- args_glm_primitive(bterms1, resp = resp, threads = threads) + p <- args_glm_primitive(bterms1, ...) out <- sdist("categorical_logit_glm", p$x, p$alpha, p$beta) } else { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi") + p <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi") out <- sdist("categorical_logit", p$mu, vec = FALSE) } out } -stan_log_lik_multinomial <- function(bterms, resp = "", mix = "", ...) { +stan_log_lik_multinomial <- function(bterms, ...) { stopifnot(bterms$family$link == "logit") - stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi") + p <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi") sdist("multinomial_logit2", p$mu, vec = FALSE) } -stan_log_lik_dirichlet <- function(bterms, resp = "", mix = "", ...) { +stan_log_lik_dirichlet <- function(bterms, ...) { stopifnot(bterms$family$link == "logit") - stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed - mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi")$mu - reqn <- glue("phi{mix}") %in% names(bterms$dpars) - phi <- stan_log_lik_dpars(bterms, reqn, resp, mix, dpars = "phi")$phi + mu <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi")$mu + reqn_phi <- is_pred_dpar(bterms, "phi") + phi <- stan_log_lik_dpars(bterms, reqn = reqn_phi, dpars = "phi")$phi sdist("dirichlet_logit", mu, phi, vec = FALSE) } -stan_log_lik_dirichlet2 <- function(bterms, resp = "", mix = "", ...) { - stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed - mu <- stan_log_lik_dpars(bterms, TRUE, resp, mix, dpars = "mu", type = "multi")$mu +stan_log_lik_dirichlet2 <- function(bterms,...) { + mu <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi")$mu sdist("dirichlet", mu, vec = FALSE) } -stan_log_lik_logistic_normal <- function(bterms, resp = "", mix = "", ...) { +stan_log_lik_logistic_normal <- function(bterms, ...) { stopifnot(bterms$family$link == "identity") - stopifnot(!isTRUE(nzchar(mix))) # mixture models are not allowed - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix, type = "multi") + resp <- usc(bterms$resp) + mix <- get_mix_id(bterms) + p <- stan_log_lik_dpars(bterms, reqn = TRUE, type = "multi") p$Llncor <- glue("Llncor{mix}{resp}") p$refcat <- get_refcat(bterms$family, int = TRUE) sdist( @@ -879,10 +825,8 @@ stan_log_lik_logistic_normal <- function(bterms, resp = "", mix = "", ...) { ) } -stan_log_lik_ordinal <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - prefix <- paste0(str_if(nzchar(mix), paste0("_mu", mix)), resp) - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_ordinal <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) if (use_ordered_builtin(bterms, "logit")) { lpdf <- "ordered_logistic" p[grepl("^disc", names(p))] <- NULL @@ -894,12 +838,14 @@ stan_log_lik_ordinal <- function(bterms, resp = "", mix = "", } if (has_thres_groups(bterms)) { str_add(lpdf) <- "_merged" - n <- stan_nn(threads) - p$Jthres <- paste0("Jthres", resp, n) + p$Jthres <- stan_log_lik_advars(bterms, "Jthres", reqn = TRUE, ...)$Jthres p$thres <- "merged_Intercept" } else { p$thres <- "Intercept" } + resp <- usc(bterms$resp) + mix <- get_mix_id(bterms) + prefix <- paste0(str_if(nzchar(mix), paste0("_mu", mix)), resp) str_add(p$thres) <- prefix if (has_sum_to_zero_thres(bterms)) { str_add(p$thres) <- "_stz" @@ -914,39 +860,37 @@ stan_log_lik_ordinal <- function(bterms, resp = "", mix = "", sdist(lpdf, p$mu, p$disc, p$thres, p$Jthres, vec = FALSE) } -stan_log_lik_hurdle_poisson <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_hurdle_poisson <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) lpdf <- stan_log_lik_simple_lpdf("hurdle_poisson", "log", bterms) - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("hu", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "hu")) sdist(lpdf, p$mu, p$hu, vec = FALSE) } -stan_log_lik_hurdle_negbinomial <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_hurdle_negbinomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) lpdf <- stan_log_lik_simple_lpdf("hurdle_neg_binomial", "log", bterms) - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("hu", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "hu")) sdist(lpdf, p$mu, p$shape, p$hu, vec = FALSE) } -stan_log_lik_hurdle_gamma <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - usc_logit <- stan_log_lik_dpar_usc_logit("hu", bterms) +stan_log_lik_hurdle_gamma <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + usc_logit <- stan_log_lik_dpar_usc_logit(bterms, "hu") lpdf <- paste0("hurdle_gamma", usc_logit) # Stan uses shape-rate parameterization for gamma with rate = shape / mean sdist(lpdf, p$shape, paste0(p$shape, " / ", p$mu), p$hu, vec = FALSE) } -stan_log_lik_hurdle_lognormal <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - usc_logit <- stan_log_lik_dpar_usc_logit("hu", bterms) +stan_log_lik_hurdle_lognormal <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + usc_logit <- stan_log_lik_dpar_usc_logit(bterms, "hu") lpdf <- paste0("hurdle_lognormal", usc_logit) sdist(lpdf, p$mu, p$sigma, p$hu, vec = FALSE) } -stan_log_lik_hurdle_cumulative <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - prefix <- paste0(str_if(nzchar(mix), paste0("_mu", mix)), resp) - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_hurdle_cumulative <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) if (use_ordered_builtin(bterms, "logit")) { lpdf <- "hurdle_cumulative_ordered_logistic" } else if (use_ordered_builtin(bterms, "probit")) { @@ -956,12 +900,14 @@ stan_log_lik_hurdle_cumulative <- function(bterms, resp = "", mix = "", } if (has_thres_groups(bterms)) { str_add(lpdf) <- "_merged" - n <- stan_nn(threads) - p$Jthres <- paste0("Jthres", resp, n) + p$Jthres <- stan_log_lik_advars(bterms, "Jthres", reqn = TRUE, ...)$Jthres p$thres <- "merged_Intercept" } else { p$thres <- "Intercept" } + resp <- usc(bterms$resp) + mix <- get_mix_id(bterms) + prefix <- paste0(str_if(nzchar(mix), paste0("_mu", mix)), resp) str_add(p$thres) <- prefix if (has_sum_to_zero_thres(bterms)) { str_add(p$thres) <- "_stz" @@ -976,74 +922,66 @@ stan_log_lik_hurdle_cumulative <- function(bterms, resp = "", mix = "", sdist(lpdf, p$mu, p$hu, p$disc, p$thres, p$Jthres, vec = FALSE) } -stan_log_lik_zero_inflated_poisson <- function(bterms, resp = "", mix = "", - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_zero_inflated_poisson <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) lpdf <- stan_log_lik_simple_lpdf("zero_inflated_poisson", "log", bterms) - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("zi", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi")) sdist(lpdf, p$mu, p$zi, vec = FALSE) } -stan_log_lik_zero_inflated_negbinomial <- function(bterms, resp = "", mix = "", - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_zero_inflated_negbinomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) lpdf <- stan_log_lik_simple_lpdf("zero_inflated_neg_binomial", "log", bterms) - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("zi", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi")) sdist(lpdf, p$mu, p$shape, p$zi, vec = FALSE) } -stan_log_lik_zero_inflated_binomial <- function(bterms, resp = "", mix = "", - threads = NULL, ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - n <- stan_nn(threads) - p$trials <- paste0("trials", resp, n) +stan_log_lik_zero_inflated_binomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + p$trials <- stan_log_lik_advars(bterms, "trials", reqn = TRUE, ...)$trials lpdf <- "zero_inflated_binomial" lpdf <- stan_log_lik_simple_lpdf(lpdf, "logit", bterms, sep = "_b") - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("zi", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi")) sdist(lpdf, p$trials, p$mu, p$zi, vec = FALSE) } -stan_log_lik_zero_inflated_beta_binomial <- function(bterms, resp = "", - mix = "", threads = NULL, - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - n <- stan_nn(threads) - p$trials <- paste0("trials", resp, n) +stan_log_lik_zero_inflated_beta_binomial <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + p$trials <- stan_log_lik_advars(bterms, "trials", reqn = TRUE, ...)$trials lpdf <- "zero_inflated_beta_binomial" - lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit("zi", bterms)) + lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi")) sdist(lpdf, p$trials, p$mu, p$phi, p$zi, vec = FALSE) } -stan_log_lik_zero_inflated_beta <- function(bterms, resp = "", mix = "", ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - usc_logit <- stan_log_lik_dpar_usc_logit("zi", bterms) +stan_log_lik_zero_inflated_beta <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + usc_logit <- stan_log_lik_dpar_usc_logit(bterms, "zi") lpdf <- paste0("zero_inflated_beta", usc_logit) sdist(lpdf, p$mu, p$phi, p$zi, vec = FALSE) } -stan_log_lik_zero_one_inflated_beta <- function(bterms, resp = "", mix = "", - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) +stan_log_lik_zero_one_inflated_beta <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) sdist("zero_one_inflated_beta", p$mu, p$phi, p$zoi, p$coi, vec = FALSE) } -stan_log_lik_zero_inflated_asym_laplace <- function(bterms, resp = "", mix = "", - ...) { - p <- stan_log_lik_dpars(bterms, TRUE, resp, mix) - usc_logit <- stan_log_lik_dpar_usc_logit("zi", bterms) +stan_log_lik_zero_inflated_asym_laplace <- function(bterms, ...) { + p <- stan_log_lik_dpars(bterms, reqn = TRUE) + usc_logit <- stan_log_lik_dpar_usc_logit(bterms, "zi") lpdf <- paste0("zero_inflated_asym_laplace", usc_logit) sdist(lpdf, p$mu, p$sigma, p$quantile, p$zi, vec = FALSE) } -stan_log_lik_custom <- function(bterms, resp = "", mix = "", threads = NULL, ...) { +stan_log_lik_custom <- function(bterms, threads = NULL, ...) { family <- bterms$family no_loop <- isFALSE(family$loop) - if (no_loop && (stan_log_lik_adj(bterms, mix = mix))) { + if (no_loop && (stan_log_lik_adj(bterms))) { stop2("This model requires evaluating the custom ", "likelihood as a loop over observations.") } - reqn <- !no_loop - p <- stan_log_lik_dpars(bterms, reqn, resp, mix) + resp <- usc(bterms$resp) + p <- stan_log_lik_dpars(bterms, reqn = !no_loop) + mix <- get_mix_id(bterms) dpars <- paste0(family$dpars, mix) if (is_ordinal(family)) { prefix <- paste0(resp, if (nzchar(mix)) paste0("_mu", mix)) @@ -1239,11 +1177,11 @@ stan_ordinal_lpmf <- function(family, link) { } # log probability density for hurdle ordinal models +# TODO: generalize to non-cumulative families? # @return a character string stan_hurdle_ordinal_lpmf <- function(family, link) { family <- as_one_character(family) link <- as_one_character(link) - # TODO: generalize to non-cumulative families? stopifnot(family == "hurdle_cumulative") inv_link <- stan_inv_link(link) th <- function(k) { @@ -1355,9 +1293,7 @@ stan_hurdle_ordinal_lpmf <- function(family, link) { out } -# use Stan GLM primitive functions? -# @param bterms a brmsterms object -# @return TRUE or FALSE +# use a Stan GLM primitive function? use_glm_primitive <- function(bterms) { stopifnot(is.brmsterms(bterms)) # the model can only have a single predicted parameter @@ -1366,7 +1302,7 @@ use_glm_primitive <- function(bterms) { non_glm_adterms <- c("se", "weights", "thres", "cens", "trunc", "rate") if (!is.btl(mu) || length(bterms$dpars) > 1L || isTRUE(bterms$rescor) || is.formula(mu$ac) || - any(names(bterms$adforms) %in% non_glm_adterms)) { + has_ad_terms(bterms, non_glm_adterms)) { return(FALSE) } # some primitives do not support special terms in the way @@ -1387,9 +1323,7 @@ use_glm_primitive <- function(bterms) { length(all_terms(mu$fe)) > 0 && !is_sparse(mu$fe) } -# use Stan categorical GLM primitive function? -# @param bterms a brmsterms object -# @return TRUE or FALSE +# use Stan's categorical GLM primitive function? use_glm_primitive_categorical <- function(bterms) { stopifnot(is.brmsterms(bterms)) if (!is_categorical(bterms)) { @@ -1411,10 +1345,10 @@ use_glm_primitive_categorical <- function(bterms) { # standard arguments for primitive Stan GLM functions # @param bterms a btl object -# @param resp optional name of the response variable # @return a named list of Stan code snippets -args_glm_primitive <- function(bterms, resp = "", threads = NULL) { +args_glm_primitive <- function(bterms, threads = NULL, ...) { stopifnot(is.btl(bterms)) + resp <- usc(bterms$resp) decomp <- get_decomp(bterms$fe) center_X <- stan_center_X(bterms) slice <- stan_slice(threads) diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index 544ecb528..8f4aadb1e 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -2406,7 +2406,7 @@ test_that("likelihood of distributional beta models is correct", { scode <- stancode( bf(prop ~ 1, phi ~ 1), data = dat, family = Beta() ) - expect_match2(scode, "beta_lpdf(Y[n] | mu[n] * phi[n], (1 - mu[n]) * phi[n])") + expect_match2(scode, "target += beta_lpdf(Y | mu .* phi, (1 - mu) .* phi);") }) test_that("student-t group-level effects work without errors", {