Skip to content

Commit

Permalink
remove residuals of old array syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 26, 2023
1 parent 177e02f commit a587ba6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion R/make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ make_stancode <- function(formula, data, family = gaussian(),
partial_log_lik <- gsub(" target \\+=", " ptarget +=", partial_log_lik)
partial_log_lik <- paste0(
"// compute partial sums of the log-likelihood\n",
"real partial_log_lik", resp, "_lpmf(int[] seq", resp,
"real partial_log_lik", resp, "_lpmf(array[] int seq", resp,
", int start, int end", pll_args$typed, ") {\n",
" real ptarget = 0;\n",
" int N = end - start + 1;\n",
Expand Down
32 changes: 16 additions & 16 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ stan_predictor.mvbrmsterms <- function(x, prior, threads, normalize, ...) {
" Y[n] = {stan_vector(glue('Y_{resp}[n]'))};\n",
" }}\n"
)
str_add(out$pll_args) <- ", data vector[] Y"
str_add(out$pll_args) <- ", data array[] vector Y"
if (any(adnames %in% "weights")) {
str_add(out$tdata_def) <- glue(
" // weights of the pointwise log-likelihood\n",
Expand All @@ -168,8 +168,8 @@ stan_predictor.mvbrmsterms <- function(x, prior, threads, normalize, ...) {
}
miforms <- rmNULL(from_list(adforms, "mi"))
if (length(miforms)) {
str_add(out$model_no_pll_def) <- " vector[nresp] Yl[N] = Y;\n"
str_add(out$pll_args) <- ", vector[] Yl"
str_add(out$model_no_pll_def) <- " array[N] vector[nresp] Yl = Y;\n"
str_add(out$pll_args) <- ", array[] vector Yl"
for (i in seq_along(miforms)) {
j <- match(names(miforms)[i], resp)
# needs to happen outside of reduce_sum
Expand Down Expand Up @@ -510,7 +510,7 @@ stan_re <- function(ranef, prior, normalize, ...) {
" // multi-membership weights\n"
)
str_add(out$pll_args) <- cglue(
", data int[] J_{id}{res}_{ng}, data real[] W_{id}{res}_{ng}"
", data array[] int J_{id}{res}_{ng}, data array[] real W_{id}{res}_{ng}"
)
}
} else {
Expand All @@ -519,7 +519,7 @@ stan_re <- function(ranef, prior, normalize, ...) {
" // grouping indicator per observation\n"
)
str_add(out$pll_args) <- cglue(
", data int[] J_{id}{uresp}"
", data array[] int J_{id}{uresp}"
)
}
if (has_by) {
Expand Down Expand Up @@ -970,7 +970,7 @@ stan_sp <- function(bterms, data, prior, stanvars, ranef, meef, threads,
" array[N{resp}] int Xmo{p}_{j}; // monotonic variable\n"
)
str_add(out$pll_args) <- glue(
", int[] Xmo{p}_{j}, vector simo{p}_{j}"
", array[] int Xmo{p}_{j}, vector simo{p}_{j}"
)
if (is.na(id) || j_id == j) {
# no ID or first appearance of the ID
Expand Down Expand Up @@ -1001,7 +1001,7 @@ stan_sp <- function(bterms, data, prior, stanvars, ranef, meef, threads,
str_add(out$data) <- glue(
" array[N{resp}] int {idxl}; // matching indices\n"
)
str_add(out$pll_args) <- glue(", data int[] {idxl}")
str_add(out$pll_args) <- glue(", data array[] int {idxl}")
}

# prepare special effects coefficients
Expand Down Expand Up @@ -1095,7 +1095,7 @@ stan_gp <- function(bterms, data, prior, threads, normalize, ...) {
" vector[{Ngp}[{J}]] Cgp{pi}_{J};\n"
)
str_add(out$pll_args) <- cglue(
", data int[] {Igp}, data vector Cgp{pi}_{J}"
", data array[] int {Igp}, data vector Cgp{pi}_{J}"
)
str_add_list(out) <- stan_prior(
prior, class = "lscale", coef = sfx2,
Expand All @@ -1111,7 +1111,7 @@ stan_gp <- function(bterms, data, prior, threads, normalize, ...) {
" // indices of latent GP groups per observation\n",
" array[{Ngp}[{J}]] int<lower=1> Jgp{pi}_{J};\n"
)
str_add(out$pll_args) <- cglue(", data int[] Jgp{pi}_{J}")
str_add(out$pll_args) <- cglue(", data array[] int Jgp{pi}_{J}")
}
if (is_approx) {
str_add(out$data) <-
Expand Down Expand Up @@ -1185,7 +1185,7 @@ stan_gp <- function(bterms, data, prior, threads, normalize, ...) {
" // indices of latent GP groups per observation\n",
" array[N{resp}] int<lower=1> Jgp{pi};\n"
)
str_add(out$pll_args) <- glue(", data int[] Jgp{pi}")
str_add(out$pll_args) <- glue(", data array[] int Jgp{pi}")
}
Cgp <- ""
if (bynum) {
Expand Down Expand Up @@ -1445,7 +1445,7 @@ stan_ac <- function(bterms, data, prior, threads, normalize, ...) {
" array[N_tg{p}] int<lower=1> nobs_tg{p};\n"
)
str_add(out$pll_args) <- glue(
", int[] begin_tg{p}, int[] end_tg{p}, int[] nobs_tg{p}"
", array[] int begin_tg{p}, array[] int end_tg{p}, array[] int nobs_tg{p}"
)
str_add(out$tdata_def) <- glue(
" int max_nobs_tg{p} = max(nobs_tg{p});",
Expand All @@ -1460,7 +1460,7 @@ stan_ac <- function(bterms, data, prior, threads, normalize, ...) {
" int n_unique_t{p}; // total number of unique time points\n",
" int n_unique_cortime{p}; // number of unique correlations\n"
)
str_add(out$pll_args) <- glue(", int[,] Jtime_tg{p}")
str_add(out$pll_args) <- glue(", array[,] int Jtime_tg{p}")
if (has_ac_latent_residuals) {
str_add(out$tpar_comp) <- glue(
" // compute correlated time-series residuals\n",
Expand Down Expand Up @@ -1572,7 +1572,7 @@ stan_ac <- function(bterms, data, prior, threads, normalize, ...) {
comment = "SD of the CAR structure", normalize = normalize
)
}
str_add(out$pll_args) <- glue(", vector rcar{p}, data int[] Jloc{p}")
str_add(out$pll_args) <- glue(", vector rcar{p}, data array[] int Jloc{p}")
str_add(out$loopeta) <- glue(" + rcar{p}[Jloc{p}{n}]")
if (acef_car$type %in% c("escar", "esicar")) {
str_add(out$data) <- glue(
Expand Down Expand Up @@ -1746,12 +1746,12 @@ stan_nl <- function(bterms, data, nlpars, threads, ...) {
str_add(out$data) <- glue(
" array[N{resp}, {dim2}] int C{p}_{i};\n"
)
str_add(out$pll_args) <- glue(", data int[,] C{p}_{i}")
str_add(out$pll_args) <- glue(", data array[,] int C{p}_{i}")
} else {
str_add(out$data) <- glue(
" array[N{resp}] int C{p}_{i};\n"
)
str_add(out$pll_args) <- glue(", data int[] C{p}_{i}")
str_add(out$pll_args) <- glue(", data array[] int C{p}_{i}")
}
} else {
if (is_matrix) {
Expand Down Expand Up @@ -1834,7 +1834,7 @@ stan_Xme <- function(meef, prior, threads, normalize) {
" int<lower=0> Nme_{i}; // number of latent values\n",
" array[N] int<lower=1> Jme_{i}; // group index per observation\n"
)
str_add(out$pll_args) <- glue(", data int[] Jme_{i}")
str_add(out$pll_args) <- glue(", data array[] int Jme_{i}")
} else {
Nme <- "N"
}
Expand Down
32 changes: 16 additions & 16 deletions R/stan-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] vector[ncat{resp}] Y{resp}; // response array\n"
)
str_add(out$pll_args) <- glue(", data vector[] Y{resp}")
str_add(out$pll_args) <- glue(", data array[] vector Y{resp}")
} else if (rtype == "int") {
str_add(out$data) <- glue(
" array[N{resp}, ncat{resp}] int Y{resp}; // response array\n"
Expand All @@ -40,7 +40,7 @@ stan_response <- function(bterms, data, normalize) {
}
} else {
if (rtype == "real") {
# type vector (instead of real[]) is required by some PDFs
# type vector (instead of array real) is required by some PDFs
str_add(out$data) <- glue(
" vector[N{resp}] Y{resp}; // response variable\n"
)
Expand All @@ -49,7 +49,7 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] int Y{resp}; // response variable\n"
)
str_add(out$pll_args) <- glue(", data int[] Y{resp}")
str_add(out$pll_args) <- glue(", data array[] int Y{resp}")
}
}
if (has_ndt(family)) {
Expand All @@ -61,7 +61,7 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] int trials{resp}; // number of trials\n"
)
str_add(out$pll_args) <- glue(", data int[] trials{resp}")
str_add(out$pll_args) <- glue(", data array[] int trials{resp}")
}
if (is.formula(bterms$adforms$weights)) {
str_add(out$data) <- glue(
Expand Down Expand Up @@ -94,7 +94,7 @@ stan_response <- function(bterms, data, normalize) {
" }}\n"
)
str_add(out$pll_args) <- glue(
", data int[] nthres{resp}, data int[,] Jthres{resp}"
", data array[] int nthres{resp}, data array[,] int Jthres{resp}"
)
} else {
str_add(out$data) <- glue(
Expand All @@ -118,7 +118,7 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] int<lower=0,upper=1> dec{resp}; // decisions\n"
)
str_add(out$pll_args) <- glue(", data int[] dec{resp}")
str_add(out$pll_args) <- glue(", data array[] int dec{resp}")
}
if (is.formula(bterms$adforms$rate)) {
str_add(out$data) <- glue(
Expand All @@ -137,15 +137,15 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] int<lower=-1,upper=2> cens{resp}; // indicates censoring\n"
)
str_add(out$pll_args) <- glue(", data int[] cens{resp}")
str_add(out$pll_args) <- glue(", data array[] int cens{resp}")
y2_expr <- get_ad_expr(bterms, "cens", "y2")
if (!is.null(y2_expr)) {
# interval censoring is required
if (rtype == "int") {
str_add(out$data) <- glue(
" array[N{resp}] int rcens{resp};"
)
str_add(out$pll_args) <- glue(", data int[] rcens{resp}")
str_add(out$pll_args) <- glue(", data array[] int rcens{resp}")
} else {
str_add(out$data) <- glue(
" vector[N{resp}] rcens{resp};"
Expand All @@ -160,13 +160,13 @@ stan_response <- function(bterms, data, normalize) {
str_add(out$data) <- glue(
" array[N{resp}] {rtype} lb{resp}; // lower truncation bounds;\n"
)
str_add(out$pll_args) <- glue(", data {rtype}[] lb{resp}")
str_add(out$pll_args) <- glue(", data array[] {rtype} lb{resp}")
}
if (any(bounds$ub < Inf)) {
str_add(out$data) <- glue(
" array[N{resp}] {rtype} ub{resp}; // upper truncation bounds\n"
)
str_add(out$pll_args) <- glue(", data {rtype}[] ub{resp}")
str_add(out$pll_args) <- glue(", data array[] {rtype} ub{resp}")
}
if (is.formula(bterms$adforms$mi)) {
# TODO: pass 'Ybounds' via 'standata' instead of hardcoding them
Expand Down Expand Up @@ -215,7 +215,7 @@ stan_response <- function(bterms, data, normalize) {
" // data for custom real vectors\n",
" array[N{resp}] real vreal{seq_len(k)}{resp};\n"
)
str_add(out$pll_args) <- cglue(", data real[] vreal{seq_len(k)}{resp}")
str_add(out$pll_args) <- cglue(", data array[] real vreal{seq_len(k)}{resp}")
}
if (is.formula(bterms$adforms$vint)) {
# vectors of integer values for use in custom families
Expand All @@ -225,7 +225,7 @@ stan_response <- function(bterms, data, normalize) {
" // data for custom integer vectors\n",
" array[N{resp}] int vint{seq_len(k)}{resp};\n"
)
str_add(out$pll_args) <- cglue(", data int[] vint{seq_len(k)}{resp}")
str_add(out$pll_args) <- cglue(", data array[] int vint{seq_len(k)}{resp}")
}
out
}
Expand Down Expand Up @@ -639,7 +639,7 @@ stan_ordinal_lpmf <- function(family, link) {
" * a scalar to be added to the log posterior\n",
" */\n",
" real {family}_{link}_merged_lpmf(",
"int y, real mu, real disc, vector thres, int[] j) {{\n",
"int y, real mu, real disc, vector thres, array[] int j) {{\n",
" return {family}_{link}_lpmf(y | mu, disc, thres[j[1]:j[2]]);\n",
" }}\n"
)
Expand All @@ -656,7 +656,7 @@ stan_ordinal_lpmf <- function(family, link) {
" * a scalar to be added to the log posterior\n",
" */\n",
" real ordered_logistic_merged_lpmf(",
"int y, real mu, vector thres, int[] j) {{\n",
"int y, real mu, vector thres, array[] int j) {{\n",
" return ordered_logistic_lpmf(y | mu, thres[j[1]:j[2]]);\n",
" }}\n"
)
Expand Down Expand Up @@ -743,7 +743,7 @@ stan_hurdle_ordinal_lpmf <- function(family, link) {
" * a scalar to be added to the log posterior\n",
" */\n",
" real {family}_{link}_merged_lpmf(",
"int y, real mu, real hu, real disc, vector thres, int[] j) {{\n",
"int y, real mu, real hu, real disc, vector thres, array[] int j) {{\n",
" return {family}_{link}_lpmf(y | mu, hu, disc, thres[j[1]:j[2]]);\n",
" }}\n"
)
Expand Down Expand Up @@ -774,7 +774,7 @@ stan_hurdle_ordinal_lpmf <- function(family, link) {
" * a scalar to be added to the log posterior\n",
" */\n",
" real hurdle_cumulative_ordered_logistic_merged_lpmf(",
"int y, real mu, real hu, real disc, vector thres, int[] j) {{\n",
"int y, real mu, real hu, real disc, vector thres, array[] int j) {{\n",
" return hurdle_cumulative_ordered_logistic_lpmf(y | mu, hu, disc, thres[j[1]:j[2]]);\n",
" }}\n"
)
Expand Down
20 changes: 10 additions & 10 deletions tests/testthat/tests.make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -2470,9 +2470,9 @@ test_that("threaded Stan code is correct", {
)
scode <- make_stancode(bform, dat, family = student(), threads = threads)
expect_match2(scode, "real partial_log_lik_lpmf(array[] int seq, int start,")
expect_match2(scode, "mu[n] += bsp[1] * mo(simo_1, Xmo_1[nn])")
expect_match2(scode, "ptarget += student_t_lpdf(Y[start : end] | nu, mu, sigma);")
expect_match2(scode, "+ gp_pred_sigma_1[Jgp_sigma_1[start : end]]")
expect_match2(scode, "mu[n] += (bsp[1]) * mo(simo_1, Xmo_1[nn])")
expect_match2(scode, "ptarget += student_t_lpdf(Y[start:end] | nu, mu, sigma);")
expect_match2(scode, "+ gp_pred_sigma_1[Jgp_sigma_1[start:end]]")
expect_match2(scode, ".* gp_pred_sigma_2_1[Jgp_sigma_2_1[which_gp_sigma_2_1]];")
expect_match2(scode, "sigma[start_at_one(Igp_sigma_2_2[which_gp_sigma_2_2], start)] +=")
expect_match2(scode, "target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y,")
Expand All @@ -2481,7 +2481,7 @@ test_that("threaded Stan code is correct", {
visit ~ cs(Trt) + Age, dat, family = sratio(),
threads = threads,
)
expect_match2(scode, "matrix[N, nthres] mucs = Xcs[start : end] * bcs;")
expect_match2(scode, "matrix[N, nthres] mucs = Xcs[start:end] * bcs;")
expect_match2(scode,
"ptarget += sratio_logit_lpmf(Y[nn] | mu[n], disc, Intercept")
expect_match2(scode, " - transpose(mucs[n]));")
Expand All @@ -2493,18 +2493,18 @@ test_that("threaded Stan code is correct", {
threads = threads
)
expect_match2(scode, "mu[n] = exp(nlp_a[n] * C_1[nn] ^ nlp_b[n]);")
expect_match2(scode, "ptarget += gamma_lpdf(Y[start : end] | shape, shape ./ mu);")
expect_match2(scode, "ptarget += gamma_lpdf(Y[start:end] | shape, shape ./ mu);")

bform <- bf(mvbind(count, Exp) ~ Trt) + set_rescor(TRUE)
scode <- make_stancode(bform, dat, gaussian(), threads = threads)
expect_match2(scode, "ptarget += multi_normal_cholesky_lpdf(Y[start : end] | Mu, LSigma);")
expect_match2(scode, "ptarget += multi_normal_cholesky_lpdf(Y[start:end] | Mu, LSigma);")

bform <- bf(brms::mvbind(count, Exp) ~ Trt) + set_rescor(FALSE)
scode <- make_stancode(bform, dat, gaussian(), threads = threads)
expect_match2(scode, "target += reduce_sum(partial_log_lik_count_lpmf, seq_count,")
expect_match2(scode, "target += reduce_sum(partial_log_lik_Exp_lpmf, seq_Exp,")
expect_match2(scode,
"ptarget += normal_id_glm_lpdf(Y_Exp[start : end] | Xc_Exp[start : end], Intercept_Exp, b_Exp, sigma_Exp);"
"ptarget += normal_id_glm_lpdf(Y_Exp[start:end] | Xc_Exp[start:end], Intercept_Exp, b_Exp, sigma_Exp);"
)

scode <- make_stancode(
Expand Down Expand Up @@ -2548,8 +2548,8 @@ test_that("Un-normalized Stan code is correct", {
normalize = FALSE, threads = threading(2)
)
expect_match2(scode, "target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y, Xc, b,")
expect_match2(scode, " Intercept, J_1, Z_1_1, r_1_1, J_2, Z_2_1, r_2_1);")
expect_match2(scode, "ptarget += poisson_log_glm_lupmf(Y[start : end] | Xc[start : end], mu, b);")
expect_match2(scode, "Intercept, J_1, Z_1_1, r_1_1, J_2, Z_2_1, r_2_1);")
expect_match2(scode, "ptarget += poisson_log_glm_lupmf(Y[start:end] | Xc[start:end], mu, b);")
expect_match2(scode, "lprior += student_t_lupdf(b | 5, 0, 10);")
expect_match2(scode, "lprior += student_t_lupdf(Intercept | 3, 1.4, 2.5);")
expect_match2(scode, "lprior += cauchy_lupdf(sd_1 | 0, 2);")
Expand All @@ -2562,7 +2562,7 @@ test_that("Un-normalized Stan code is correct", {
normalize = FALSE
)
expect_match2(scode, "target += sratio_cloglog_lpmf(Y[n] | mu[n], disc, Intercept")
expect_match2(scode, " - transpose(mucs[n]));")
expect_match2(scode, "- transpose(mucs[n]));")

# Check that user-specified custom distributions stay normalized
dat <- data.frame(size = 10, y = sample(0:10, 20, TRUE), x = rnorm(20))
Expand Down

0 comments on commit a587ba6

Please sign in to comment.