Skip to content

Commit

Permalink
working code with initial values
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Apr 23, 2024
1 parent 4f60a9d commit c06effa
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 18 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ run_mcmc <- function(data, model_options, compute_options, priors, initial_value
.Call(`_BayesMallows_run_mcmc`, data, model_options, compute_options, priors, initial_values, pfun_values, pfun_estimate, progress_report)
}

run_smc <- function(data, new_data, model_options, smc_options, compute_options, priors, initial_values, pfun_values, pfun_estimate) {
.Call(`_BayesMallows_run_smc`, data, new_data, model_options, smc_options, compute_options, priors, initial_values, pfun_values, pfun_estimate)
run_smc <- function(data, new_data, model_options, smc_options, compute_options, priors, pfun_values, pfun_estimate) {
.Call(`_BayesMallows_run_smc`, data, new_data, model_options, smc_options, compute_options, priors, pfun_values, pfun_estimate)
}

8 changes: 1 addition & 7 deletions R/compute_mallows_sequentially.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#' @param data A list of objects of class "BayesMallowsData" returned from
#' [setup_rank_data()]. Each list element is interpreted as the data belonging
#' to a given timepoint.
#' @param initial_values An object of class "BayesMallowsPriorSamples" returned
#' from [sample_prior()].
#' @param model_options An object of class "BayesMallowsModelOptions" returned
#' from [set_model_options()].
#' @param smc_options An object of class "SMCOptions" returned from
Expand Down Expand Up @@ -44,13 +42,12 @@
#'
compute_mallows_sequentially <- function(
data,
initial_values,
model_options = set_model_options(),
smc_options = set_smc_options(),
compute_options = set_compute_options(),
priors = set_priors(),
pfun_estimate = NULL) {
validate_class(initial_values, "BayesMallowsPriorSamples")

if (!is.list(data) | !all(vapply(data, inherits, logical(1), "BayesMallowsData"))) {
stop("data must be a list of BayesMallowsData objects.")
}
Expand All @@ -70,8 +67,6 @@ compute_mallows_sequentially <- function(
x
})
pfun_values <- extract_pfun_values(model_options$metric, data[[1]]$n_items, pfun_estimate)
alpha_init <- sample(initial_values$alpha, smc_options$n_particles, replace = TRUE)
rho_init <- initial_values$rho[, sample(ncol(initial_values$rho), smc_options$n_particles, replace = TRUE)]

ret <- run_smc(
data = flush(data[[1]]),
Expand All @@ -80,7 +75,6 @@ compute_mallows_sequentially <- function(
smc_options = smc_options,
compute_options = compute_options,
priors = priors,
initial_values = list(alpha_init = alpha_init, rho_init = rho_init, aug_init = NULL),
pfun_values = pfun_values,
pfun_estimate = pfun_estimate
)
Expand Down
9 changes: 4 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ BEGIN_RCPP
END_RCPP
}
// run_smc
Rcpp::List run_smc(Rcpp::List data, Rcpp::List new_data, Rcpp::List model_options, Rcpp::List smc_options, Rcpp::List compute_options, Rcpp::List priors, Rcpp::List initial_values, Rcpp::Nullable<arma::mat> pfun_values, Rcpp::Nullable<arma::mat> pfun_estimate);
RcppExport SEXP _BayesMallows_run_smc(SEXP dataSEXP, SEXP new_dataSEXP, SEXP model_optionsSEXP, SEXP smc_optionsSEXP, SEXP compute_optionsSEXP, SEXP priorsSEXP, SEXP initial_valuesSEXP, SEXP pfun_valuesSEXP, SEXP pfun_estimateSEXP) {
Rcpp::List run_smc(Rcpp::List data, Rcpp::List new_data, Rcpp::List model_options, Rcpp::List smc_options, Rcpp::List compute_options, Rcpp::List priors, Rcpp::Nullable<arma::mat> pfun_values, Rcpp::Nullable<arma::mat> pfun_estimate);
RcppExport SEXP _BayesMallows_run_smc(SEXP dataSEXP, SEXP new_dataSEXP, SEXP model_optionsSEXP, SEXP smc_optionsSEXP, SEXP compute_optionsSEXP, SEXP priorsSEXP, SEXP pfun_valuesSEXP, SEXP pfun_estimateSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -155,10 +155,9 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< Rcpp::List >::type smc_options(smc_optionsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type compute_options(compute_optionsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type priors(priorsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type initial_values(initial_valuesSEXP);
Rcpp::traits::input_parameter< Rcpp::Nullable<arma::mat> >::type pfun_values(pfun_valuesSEXP);
Rcpp::traits::input_parameter< Rcpp::Nullable<arma::mat> >::type pfun_estimate(pfun_estimateSEXP);
rcpp_result_gen = Rcpp::wrap(run_smc(data, new_data, model_options, smc_options, compute_options, priors, initial_values, pfun_values, pfun_estimate));
rcpp_result_gen = Rcpp::wrap(run_smc(data, new_data, model_options, smc_options, compute_options, priors, pfun_values, pfun_estimate));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -173,7 +172,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_BayesMallows_get_partition_function", (DL_FUNC) &_BayesMallows_get_partition_function, 4},
{"_BayesMallows_rmallows", (DL_FUNC) &_BayesMallows_rmallows, 7},
{"_BayesMallows_run_mcmc", (DL_FUNC) &_BayesMallows_run_mcmc, 8},
{"_BayesMallows_run_smc", (DL_FUNC) &_BayesMallows_run_smc, 9},
{"_BayesMallows_run_smc", (DL_FUNC) &_BayesMallows_run_smc, 8},
{NULL, NULL, 0}
};

Expand Down
11 changes: 10 additions & 1 deletion src/run_smc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Rcpp::List run_smc(
Rcpp::List smc_options,
Rcpp::List compute_options,
Rcpp::List priors,
Rcpp::List initial_values,
Rcpp::Nullable<arma::mat> pfun_values,
Rcpp::Nullable<arma::mat> pfun_estimate) {

Expand All @@ -26,6 +25,16 @@ Rcpp::List run_smc(
Priors pris{priors};
SMCAugmentation aug{compute_options, smc_options};

vec alpha_init = randg(pars.n_particles, distr_param(pris.gamma, 1 / pris.lambda));
Rcpp::List initial_values;
initial_values["alpha_init"] = alpha_init;
mat rho_init(dat.n_items, pars.n_particles);
for(size_t i{}; i < pars.n_particles; i++) {
rho_init.col(i) = conv_to<vec>::from(randperm(dat.n_items) + 1);
}
initial_values["rho_init"] = rho_init;
initial_values["aug_init"] = R_NilValue;

std::vector<StaticParticle> particle_vector =
initialize_particles(initial_values, pars.n_particles, pars.n_particle_filters, dat);

Expand Down
7 changes: 4 additions & 3 deletions work-docs/smc2/sushi_clustering.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ library(ggplot2)

dat <- lapply(1:50, function(i) {
dd <- sushi_rankings[i, ]
dd[runif(10) > .7] <- NA
#dd[runif(10) > .7] <- NA
setup_rank_data(dd, user_ids = i)
})

Expand All @@ -14,8 +14,8 @@ initial_values <- sample_prior(

mod <- compute_mallows_sequentially(
data = dat,
initial_values = initial_values,
smc_options = set_smc_options(n_particles = 500, n_particle_filters = 10, resampling_threshold = 250)
smc_options = set_smc_options(n_particles = 2000, n_particle_filters = 1, resampling_threshold = 1000),
priors = set_priors(gamma = 3, lambda = .1)
)

hist(mod$alpha_samples)
Expand All @@ -29,6 +29,7 @@ plot_dat <- data.frame(
)



ggplot(plot_dat[-1, ], aes(x = n_obs, y = alpha_mean, ymin = alpha_mean - alpha_sd,
ymax = alpha_mean + alpha_sd)) +
geom_line() +
Expand Down

0 comments on commit c06effa

Please sign in to comment.