Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Nov 8, 2023
1 parent a3e91ff commit 4d89f52
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 40 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ run_mcmc <- function(rankings, obs_freq, nmc, constraints, cardinalities, logz_e
.Call(`_BayesMallows_run_mcmc`, rankings, obs_freq, nmc, constraints, cardinalities, logz_estimate, rho_init, metric, error_model, Lswap, n_clusters, include_wcd, leap_size, alpha_prop_sd, alpha_init, alpha_jump, lambda, alpha_max, psi, rho_thinning, aug_thinning, clus_thin, save_aug, verbose, kappa_1, kappa_2, save_ind_clus)
}

smc_mallows_new_users <- function(rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd = 0.5, lambda = 0.1, aug_method = "uniform", logz_estimate = NULL, cardinalities = NULL, metric = "footrule", leap_size = 1L, aug_init = NULL, num_obs = 0L) {
.Call(`_BayesMallows_smc_mallows_new_users`, rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd, lambda, aug_method, logz_estimate, cardinalities, metric, leap_size, aug_init, num_obs)
smc_mallows_new_users <- function(rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd = 0.5, lambda = 0.1, aug_method = "uniform", logz_estimate = NULL, cardinalities = NULL, metric = "footrule", leap_size = 1L, aug_init = NULL) {
.Call(`_BayesMallows_smc_mallows_new_users`, rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd, lambda, aug_method, logz_estimate, cardinalities, metric, leap_size, aug_init)
}

1 change: 0 additions & 1 deletion R/update_mallows.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ update_mallows.SMCMallows <- function(model, new_rankings) {
logz_estimate = model$logz_list$logz_estimate,
cardinalities = model$logz_list$cardinalities,
metric = model$metric,
num_obs = nrow(model$rankings),
aug_init = aug_init
)

Expand Down
9 changes: 4 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ BEGIN_RCPP
END_RCPP
}
// smc_mallows_new_users
Rcpp::List smc_mallows_new_users(arma::mat rankings, arma::mat new_rankings, const arma::mat rho_init, const arma::vec alpha_init, const int n_particles, const int mcmc_steps, const double alpha_prop_sd, const double lambda, const std::string aug_method, const Rcpp::Nullable<arma::vec>& logz_estimate, const Rcpp::Nullable<arma::vec>& cardinalities, const std::string& metric, const int& leap_size, const Rcpp::Nullable<arma::cube> aug_init, int num_obs);
RcppExport SEXP _BayesMallows_smc_mallows_new_users(SEXP rankingsSEXP, SEXP new_rankingsSEXP, SEXP rho_initSEXP, SEXP alpha_initSEXP, SEXP n_particlesSEXP, SEXP mcmc_stepsSEXP, SEXP alpha_prop_sdSEXP, SEXP lambdaSEXP, SEXP aug_methodSEXP, SEXP logz_estimateSEXP, SEXP cardinalitiesSEXP, SEXP metricSEXP, SEXP leap_sizeSEXP, SEXP aug_initSEXP, SEXP num_obsSEXP) {
Rcpp::List smc_mallows_new_users(arma::mat rankings, arma::mat new_rankings, const arma::mat rho_init, const arma::vec alpha_init, const int n_particles, const int mcmc_steps, const double alpha_prop_sd, const double lambda, const std::string aug_method, const Rcpp::Nullable<arma::vec>& logz_estimate, const Rcpp::Nullable<arma::vec>& cardinalities, const std::string& metric, const int& leap_size, const Rcpp::Nullable<arma::cube> aug_init);
RcppExport SEXP _BayesMallows_smc_mallows_new_users(SEXP rankingsSEXP, SEXP new_rankingsSEXP, SEXP rho_initSEXP, SEXP alpha_initSEXP, SEXP n_particlesSEXP, SEXP mcmc_stepsSEXP, SEXP alpha_prop_sdSEXP, SEXP lambdaSEXP, SEXP aug_methodSEXP, SEXP logz_estimateSEXP, SEXP cardinalitiesSEXP, SEXP metricSEXP, SEXP leap_sizeSEXP, SEXP aug_initSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -185,8 +185,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const std::string& >::type metric(metricSEXP);
Rcpp::traits::input_parameter< const int& >::type leap_size(leap_sizeSEXP);
Rcpp::traits::input_parameter< const Rcpp::Nullable<arma::cube> >::type aug_init(aug_initSEXP);
Rcpp::traits::input_parameter< int >::type num_obs(num_obsSEXP);
rcpp_result_gen = Rcpp::wrap(smc_mallows_new_users(rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd, lambda, aug_method, logz_estimate, cardinalities, metric, leap_size, aug_init, num_obs));
rcpp_result_gen = Rcpp::wrap(smc_mallows_new_users(rankings, new_rankings, rho_init, alpha_init, n_particles, mcmc_steps, alpha_prop_sd, lambda, aug_method, logz_estimate, cardinalities, metric, leap_size, aug_init));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -203,7 +202,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_BayesMallows_asymptotic_partition_function", (DL_FUNC) &_BayesMallows_asymptotic_partition_function, 6},
{"_BayesMallows_rmallows", (DL_FUNC) &_BayesMallows_rmallows, 7},
{"_BayesMallows_run_mcmc", (DL_FUNC) &_BayesMallows_run_mcmc, 27},
{"_BayesMallows_smc_mallows_new_users", (DL_FUNC) &_BayesMallows_smc_mallows_new_users, 15},
{"_BayesMallows_smc_mallows_new_users", (DL_FUNC) &_BayesMallows_smc_mallows_new_users, 14},
{"run_testthat_tests", (DL_FUNC) &run_testthat_tests, 1},
{NULL, NULL, 0}
};
Expand Down
6 changes: 6 additions & 0 deletions src/missing_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ vec propose_augmentation(const vec& ranks, const uvec& indicator){
return proposal;
}

// Prepare a rankings matrix that contains missing values for augmentation.
//
// rankings          - modified in place: every NaN entry is replaced by 0.
// missing_indicator - overwritten with a matrix of the same dimensions as
//                     `rankings`, holding 1 where the (converted) ranking
//                     is 0 and 0 everywhere else.
//
// Note that 0 is used as the "missing" sentinel after the replacement, so an
// entry that was already 0 on input is flagged exactly like a former NaN.
void set_up_missing(arma::mat& rankings, arma::umat& missing_indicator) {
  // Recode NaN as 0 so the matrix can be converted to an unsigned type.
  rankings.replace(arma::datum::nan, 0);

  // Start from the unsigned-integer copy of the rankings ...
  missing_indicator = arma::conv_to<arma::umat>::from(rankings);

  // ... and turn it into a 0/1 mask. The lambda works on the matrix's own
  // element type (arma::uword) so no narrowing conversion to int takes place.
  missing_indicator.transform(
    [](arma::uword val) -> arma::uword { return (val == 0) ? 1 : 0; }
  );
}

void initialize_missing_ranks(mat& rankings, const umat& missing_indicator) {
int n_assessors = rankings.n_cols;

Expand Down
2 changes: 2 additions & 0 deletions src/missing_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ arma::vec make_new_augmentation(const arma::vec& rankings, const arma::uvec& mis
const double& alpha, const arma::vec& rho,
const std::string& metric);

void set_up_missing(arma::mat& rankings, arma::umat& missing_indicator);

void initialize_missing_ranks(arma::mat& rankings, const arma::umat& missing_indicator);

void update_missing_ranks(arma::mat& rankings, const arma::uvec& current_cluster_assignment,
Expand Down
7 changes: 2 additions & 5 deletions src/run_mcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,13 @@ Rcpp::List run_mcmc(arma::mat rankings, arma::vec obs_freq, int nmc,
bool any_missing = !is_finite(rankings);

umat missing_indicator{};
cube augmented_data{};

if(any_missing){
rankings.replace(datum::nan, 0);
missing_indicator = conv_to<umat>::from(rankings);
missing_indicator.transform( [](int val) { return (val == 0) ? 1 : 0; } );
set_up_missing(rankings, missing_indicator);
initialize_missing_ranks(rankings, missing_indicator);
}

// If the user wants to save augmented data, we need a cube
cube augmented_data{};
if(save_aug){
augmented_data.set_size(n_items, n_assessors, std::ceil(static_cast<double>(nmc * 1.0 / aug_thinning)));
augmented_data.slice(0) = rankings;
Expand Down
28 changes: 10 additions & 18 deletions src/smc_mallows_new_users.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ Rcpp::List smc_mallows_new_users(
const Rcpp::Nullable<arma::vec>& cardinalities = R_NilValue,
const std::string& metric = "footrule",
const int& leap_size = 1,
const Rcpp::Nullable<arma::cube> aug_init = R_NilValue,
int num_obs = 0
const Rcpp::Nullable<arma::cube> aug_init = R_NilValue
) {

int num_new_obs = new_rankings.n_cols;
Expand All @@ -42,30 +41,23 @@ Rcpp::List smc_mallows_new_users(
vec alpha_samples = zeros(n_particles);
double effective_sample_size;

cube augmented_data;
umat missing_indicator;

cube augmented_data{};
umat missing_indicator{};

if(any_missing){

rankings.replace(datum::nan, 0);
missing_indicator = conv_to<umat>::from(rankings);
missing_indicator.transform( [](int val) { return (val == 0) ? 1 : 0; } );
augmented_data = zeros(n_items, n_users, n_particles);
set_up_missing(rankings, missing_indicator);
augmented_data.set_size(n_items, n_users, n_particles);

for(int i{}; i < n_particles; i++) {
augmented_data.slice(i) = rankings;
initialize_missing_ranks(augmented_data.slice(i), missing_indicator);
}

if(aug_init.isNotNull()) {
augmented_data(span::all, span(0, num_obs - 1), span::all) = Rcpp::as<cube>(aug_init);
augmented_data(span::all, span(0, rankings.n_cols - new_rankings.n_cols - 1), span::all) = Rcpp::as<cube>(aug_init);
}
} else {
missing_indicator.reset();
}

num_obs += num_new_obs;
mat new_observed_rankings, all_observed_rankings;
if(!any_missing){
new_observed_rankings = new_rankings;
Expand All @@ -78,19 +70,19 @@ Rcpp::List smc_mallows_new_users(

if(any_missing){
smc_mallows_new_users_augment_partial(
augmented_data, aug_prob, rho_samples, alpha_samples, num_obs, num_new_obs,
augmented_data, aug_prob, rho_samples, alpha_samples, num_new_obs,
aug_method, missing_indicator, metric);
}


vec norm_wgt;
smc_mallows_new_users_reweight(
log_inc_wgt, effective_sample_size, norm_wgt, augmented_data, new_observed_rankings, rho_samples,
alpha_samples, logz_estimate, cardinalities, num_obs, num_new_obs, aug_prob,
alpha_samples, logz_estimate, cardinalities, num_new_obs, aug_prob,
any_missing, metric);

smc_mallows_new_users_resample(
rho_samples, alpha_samples, augmented_data, norm_wgt, num_obs,
rho_samples, alpha_samples, augmented_data, norm_wgt,
any_missing);

for (int ii = 0; ii < n_particles; ++ii) {
Expand All @@ -110,7 +102,7 @@ Rcpp::List smc_mallows_new_users(
}

if(any_missing) {

int num_obs = rankings.n_cols;
for (int jj = num_obs - num_new_obs; jj < num_obs; ++jj) {
double log_hastings_correction = 0;

Expand Down
6 changes: 3 additions & 3 deletions src/smc_mallows_new_users.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
#include <RcppArmadillo.h>

void smc_mallows_new_users_augment_partial(arma::cube&, arma::vec&,
const arma::mat&, const arma::vec&, const int&, const int&,
const arma::mat&, const arma::vec&, const int&,
const std::string&,
const arma::umat& missing_indicator, const std::string&);
void smc_mallows_new_users_reweight(
arma::vec&, double&, arma::vec&,
const arma::cube&, const arma::mat&, const arma::mat&,
const arma::vec&, const Rcpp::Nullable<arma::vec>,
const Rcpp::Nullable<arma::vec>,
const int&, const int&, const arma::vec&, const bool&,
const int&, const arma::vec&, const bool&,
const std::string&);
void smc_mallows_new_users_resample(
arma::mat&, arma::vec&, arma::cube&, const arma::vec&,
const int& num_obs, const bool& partial);
const bool& partial);

Rcpp::List make_pseudo_proposal(
arma::uvec unranked_items, arma::vec rankings, const double& alpha,
Expand Down
10 changes: 4 additions & 6 deletions src/smc_mallows_new_users_funs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ void smc_mallows_new_users_augment_partial(
arma::vec& aug_prob,
const arma::mat& rho_samples,
const arma::vec& alpha_samples,
const int& num_obs,
const int& num_new_obs,
const std::string& aug_method,
const umat& missing_indicator,
const std::string& metric = "footrule"
){
int n_particles = rho_samples.n_cols;
int num_obs = augmented_data.n_cols;

for (int ii{}; ii < n_particles; ++ii) {
double alpha = alpha_samples(ii);
Expand Down Expand Up @@ -93,7 +93,7 @@ void smc_mallows_new_users_augment_partial(

void smc_mallows_new_users_resample(
mat& rho_samples, vec& alpha_samples, cube& augmented_data,
const vec& norm_wgt, const int& num_obs,
const vec& norm_wgt,
const bool& partial
){
int n_particles = rho_samples.n_cols;
Expand All @@ -103,9 +103,7 @@ void smc_mallows_new_users_resample(
alpha_samples = alpha_samples.rows(index);

if(partial){
cube augmented_data_index = augmented_data.slices(index);
augmented_data.cols(0, num_obs - 1) =
augmented_data_index(span::all, span(0, num_obs - 1), span::all);
augmented_data = augmented_data.slices(index);
}
}

Expand All @@ -119,14 +117,14 @@ void smc_mallows_new_users_reweight(
const vec& alpha_samples,
const Rcpp::Nullable<vec> logz_estimate,
const Rcpp::Nullable<vec> cardinalities,
const int& num_obs,
const int& num_new_obs,
const vec& aug_prob,
const bool& partial,
const std::string& metric = "footrule"
){
int n_particles = rho_samples.n_cols;
int n_items = rho_samples.n_rows;
int num_obs = augmented_data.n_cols;
for (int ii{}; ii < n_particles; ++ii) {

/* Calculating log_z_alpha and log_likelihood ----------- */
Expand Down

0 comments on commit 4d89f52

Please sign in to comment.