Use multiple indexing in Stan code for varying effects #772

Open
paul-buerkner opened this issue Oct 8, 2019 · 5 comments

paul-buerkner (Owner) commented Oct 8, 2019

Currently, the Stan code of multilevel models looks rather verbose because it first extracts the columns of the group-level effects matrix and then loops over observations to select the right elements of the resulting vectors. Historically, this has been more efficient than the other indexing options available in Stan. However, with Stan's multiple indexing feature, a much less verbose option should be available.

Preliminary analysis (see branch 're-multiple-indexing') suggests that this may actually make sampling less efficient, but more testing is required before we can say anything reliable about efficiency.

Here is how the Stan code of a varying intercept, varying slope model currently looks:

```stan
data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  real<lower=0> sigma;  // residual SD
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  // cholesky factor of correlation matrix
  cholesky_factor_corr[M_1] L_1;
}
transformed parameters {
  // actual group-level effects
  matrix[N_1, M_1] r_1 = (diag_pre_multiply(sd_1, L_1) * z_1)';
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1 = r_1[, 1];
  vector[N_1] r_1_2 = r_1[, 2];
}
model {
  // initialize linear predictor term
  vector[N] mu = Intercept + Xc * b;
  for (n in 1:N) {
    // add more terms to the linear predictor
    mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n];
  }
  ...
}
```

Here is how the Stan code of the same model could look with multiple indexing:

```stan
data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  matrix[N, M_1] Z_1;
  int<lower=1> NC_1;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  real<lower=0> sigma;  // residual SD
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  // cholesky factor of correlation matrix
  cholesky_factor_corr[M_1] L_1;
}
transformed parameters {
  // actual group-level effects
  matrix[N_1, M_1] r_1 = (diag_pre_multiply(sd_1, L_1) * z_1)';
}
model {
  // initialize linear predictor term
  vector[N] mu = Intercept + Xc * b + rows_dot_product(Z_1, r_1[J_1]);
  ...
}
```
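
To check that the proposed multi-indexed expression computes the same linear predictor as the current loop, here is a small pure-Python sketch that mimics Stan's semantics: `r_1[J_1]` selects one row of `r_1` per observation (1-based indices), and `rows_dot_product` takes the dot product of corresponding rows. The toy values and helper names are hypothetical, and the population-level term `Xc * b` is left out for brevity. It only illustrates that the two formulations agree, not which one samples faster.

```python
# Toy example (hypothetical values): N = 5 observations, N_1 = 3 groups,
# M_1 = 2 coefficients per group (varying intercept and varying slope).
r_1 = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]  # like matrix[N_1, M_1] r_1
J_1 = [1, 3, 2, 3, 1]                           # 1-based group of each observation
Z_1 = [[1.0, 0.2], [1.0, -0.4], [1.0, 1.0], [1.0, 0.0], [1.0, 3.0]]  # like matrix[N, M_1]
intercept = 2.0


def mu_loop(intercept, r_1, J_1, Z_1):
    """Current style: extract the columns of r_1, then loop over observations."""
    r_1_1 = [row[0] for row in r_1]  # like r_1[, 1]
    r_1_2 = [row[1] for row in r_1]  # like r_1[, 2]
    mu = [intercept] * len(J_1)
    for n in range(len(J_1)):
        j = J_1[n] - 1  # convert Stan's 1-based index to Python's 0-based one
        mu[n] += r_1_1[j] * Z_1[n][0] + r_1_2[j] * Z_1[n][1]
    return mu


def multi_index_rows(mat, idx):
    """Mimics Stan's mat[idx] for a matrix and an int array: one row per index."""
    return [mat[i - 1] for i in idx]


def rows_dot_product(a, b):
    """Mimics Stan's rows_dot_product: dot product of corresponding rows."""
    return [sum(x * y for x, y in zip(ra, rb)) for ra, rb in zip(a, b)]


def mu_multi(intercept, r_1, J_1, Z_1):
    """Proposed style: Intercept + rows_dot_product(Z_1, r_1[J_1])."""
    re_term = rows_dot_product(Z_1, multi_index_rows(r_1, J_1))
    return [intercept + t for t in re_term]
```

Both functions return the same vector; the multi-indexed version simply builds the `N x M_1` matrix `r_1[J_1]` up front instead of indexing inside the loop, which is where the extra memory traffic (and any efficiency difference) would come from.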

SteveBronder commented

Just wanted to give an update: Stan 2.27 brings a few notable changes related to this issue that could make brms a good bit faster. I think we can try them once rstan updates to 2.27.

  1. Use Eigen Maps for data and transformed data (stan-dev/stanc3#865) makes data and transformed data be stored as an Eigen::Map<Eigen::Matrix>. The PR explains what that means in detail, but the tl;dr is that we should do large data manipulations once in the transformed data block; then data no longer needs to be copied when calling Stan Math functions in the model or transformed parameters blocks. This applies only to data, not to parameters.

  2. In stan-dev/math#2462 ("csr_matrix_time_vector data * var specialization") I added a specialization of csr_matrix_times_vector for a data matrix times a parameter vector, so sparse matrix code may become more efficient.

  3. We now have fma() specializations for matrices and vectors, so

```stan
  // add more terms to the linear predictor
  mu += fma(r_1_1[J_1], Z_1_1, r_1_2[J_1] .* Z_1_2);
```

should be faster than

```stan
  for (n in 1:N) {
    // add more terms to the linear predictor
    mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n];
  }
```

although we also have a compiler optimization that takes

```stan
  // add more terms to the linear predictor
  mu += r_1_1[J_1] .* Z_1_1 + r_1_2[J_1] .* Z_1_2;
```

and automatically rewrites it to use fma() where it can. That optimization is not exposed yet but should be by 2.28, so it might be worth just waiting.
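
As a quick sanity check that the three forms in item 3 compute the same thing, here is a pure-Python sketch with made-up toy values. `fma_vec` is a hypothetical stand-in for Stan's vectorized `fma()`: it computes `x * y + z` elementwise but, unlike a real fused multiply-add, rounds twice, so it illustrates only the algebra, not the performance or precision benefits.

```python
# Toy values (hypothetical); J_1 uses Stan's 1-based indexing.
r_1_1 = [0.5, 1.5, -0.75]   # varying intercepts, one per group
r_1_2 = [-1.0, 0.25, 2.0]   # varying slopes, one per group
J_1 = [1, 3, 2, 3, 1]       # group of each observation
Z_1_1 = [1.0, 1.0, 1.0, 1.0, 1.0]
Z_1_2 = [0.2, -0.4, 1.0, 0.0, 3.0]
mu0 = [2.0] * 5             # linear predictor before adding group-level terms


def index(v, idx):
    """Mimics Stan's v[idx] (multiple indexing with 1-based indices)."""
    return [v[i - 1] for i in idx]


def elt_multiply(a, b):
    """Mimics Stan's a .* b."""
    return [x * y for x, y in zip(a, b)]


def fma_vec(x, y, z):
    """Elementwise x * y + z. Rounds twice, unlike a true fused multiply-add."""
    return [a * b + c for a, b, c in zip(x, y, z)]


def add_loop(mu):
    """mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n], one n at a time."""
    out = list(mu)
    for n in range(len(out)):
        j = J_1[n] - 1
        out[n] += r_1_1[j] * Z_1_1[n] + r_1_2[j] * Z_1_2[n]
    return out


def add_elementwise(mu):
    """mu += r_1_1[J_1] .* Z_1_1 + r_1_2[J_1] .* Z_1_2"""
    a = elt_multiply(index(r_1_1, J_1), Z_1_1)
    b = elt_multiply(index(r_1_2, J_1), Z_1_2)
    return [m + (x + y) for m, x, y in zip(mu, a, b)]


def add_fma(mu):
    """mu += fma(r_1_1[J_1], Z_1_1, r_1_2[J_1] .* Z_1_2)"""
    term = fma_vec(index(r_1_1, J_1), Z_1_1, elt_multiply(index(r_1_2, J_1), Z_1_2))
    return [m + t for m, t in zip(mu, term)]
```

In Stan, the potential gain of the `fma()` form comes from computing the whole vector in one call (and one autodiff node) rather than from a different result.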

paul-buerkner (Owner, Author) commented

That sounds really nice and exciting! Out of interest, did you run any performance benchmarks for an actual sparse use case? I wasn't sure whether the benchmarks shown in stan-dev/math#2462 are specifically about sparse matrices.

SteveBronder commented

I haven't benchmarked it yet, but I can whip up a benchmark using brms this week; that should be pretty easy. My intuition is that it should be faster than calling the multi-indexing each time, but there is probably a trade-off depending on the number of groups and the size of the data that I need to think through.
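
To make the sparse alternative concrete: the varying-effects term can also be written as a sparse `N x (N_1 * M_1)` design matrix times the stacked group-level effects, with exactly `M_1` nonzeros per row. The pure-Python sketch below mimics what Stan's `csr_matrix_times_vector(m, n, w, v, u, b)` computes (values `w`, 1-based column indices `v`, 1-based row starts `u` of length `m + 1`). The helper `build_csr` and the group-by-group stacking of `r_1` are hypothetical choices for illustration, not how brms actually generates code, and the sketch says nothing about which approach is faster.

```python
# Toy values (hypothetical): N = 5 observations, N_1 = 3 groups, M_1 = 2.
r_1 = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]  # like matrix[N_1, M_1] r_1
J_1 = [1, 3, 2, 3, 1]                           # 1-based group of each observation
Z_1 = [[1.0, 0.2], [1.0, -0.4], [1.0, 1.0], [1.0, 0.0], [1.0, 3.0]]  # like matrix[N, M_1]
N, N_1, M_1 = len(J_1), len(r_1), len(r_1[0])


def build_csr(J_1, Z_1, M_1):
    """Hypothetical helper: CSR encoding of the N x (N_1 * M_1) varying-effects design.

    Row n holds Z_1[n][k] in (1-based) column (J_1[n] - 1) * M_1 + k + 1.
    Returns Stan-style 1-based (w, v, u): values, column indices, row starts.
    """
    w, v, u = [], [], [1]
    for n, row in enumerate(Z_1):
        for k in range(M_1):
            w.append(row[k])
            v.append((J_1[n] - 1) * M_1 + k + 1)
        u.append(len(w) + 1)  # start of the next row; u ends up with N + 1 entries
    return w, v, u


def csr_matrix_times_vector(m, n, w, v, u, b):
    """Python sketch of what Stan's csr_matrix_times_vector(m, n, w, v, u, b) computes."""
    assert len(u) == m + 1 and len(b) == n
    out = []
    for i in range(m):
        acc = 0.0
        for p in range(u[i] - 1, u[i + 1] - 1):  # positions of row i's nonzeros
            acc += w[p] * b[v[p] - 1]
        out.append(acc)
    return out


# Stack the group-level effects group by group (like to_vector(r_1') in Stan).
r_stacked = [x for row in r_1 for x in row]
w, v, u = build_csr(J_1, Z_1, M_1)
re_term = csr_matrix_times_vector(N, N_1 * M_1, w, v, u, r_stacked)
```

Since `w`, `v`, and `u` depend only on data, in Stan they could be computed once (e.g. in transformed data), leaving a single sparse product with the parameter vector per evaluation.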

wds15 (Contributor) commented May 31, 2021

These things can be really counterintuitive; I have personally stopped trusting my intuition and rely only on brute-force benchmarks.

It would be really nice to leave for loops behind...

paul-buerkner added this to the 2.21.0 milestone Jan 27, 2024

paul-buerkner (Owner, Author) commented

Now that rstan is more up to date, I have moved this issue higher on the agenda.
