Skip to content

Commit

Permalink
Added batch to get average J/D before estimation triplet rates.
Browse files Browse the repository at this point in the history
  • Loading branch information
xjlizji committed Apr 18, 2020
1 parent 0400e93 commit 60261ef
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions src/prog/epievo_initialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,47 @@ initialize_paths(RandEngine &gen, const TreeHelper &th,
}


static void
sample_summary_stats(const vector<double> &rates, const TreeHelper &th,
vector<vector<Path> > &paths,
vector<vector<double> > &J_all_sites,
vector<vector<double> > &D_all_sites,
std::mt19937 &gen,
const size_t batch) {
const size_t n_triples = 8;
J_all_sites.resize(th.n_nodes);
D_all_sites.resize(th.n_nodes);
for(size_t b = 1; b < th.n_nodes; ++b) {
J_all_sites[b].clear();
J_all_sites[b].resize(n_triples, 0.0);
D_all_sites[b].clear();
D_all_sites[b].resize(n_triples, 0.0);
}

for (size_t i = 0; i < batch; i++) {
update_paths_indep(rates, th, paths, gen);

vector<vector<double> > J_one_site, D_one_site;
get_sufficient_statistics(paths, J_one_site, D_one_site);
for (size_t b = 1; b < th.n_nodes; b++) {
for (size_t i = 0; i < n_triples; i++) {
J_all_sites[b][i] += J_one_site[b][i];
D_all_sites[b][i] += D_one_site[b][i];
}
}
}

/* CALCULATE BATCH AVERAGE */
for (size_t b = 1; b < th.n_nodes; ++b) {
for (size_t i = 0; i < n_triples; i++) {
J_all_sites[b][i] /= batch;
D_all_sites[b][i] /= batch;
}
}
}



static void
initialize_model_from_indep_rates(EpiEvoModel &the_model,
const vector<double> rates) {
Expand Down Expand Up @@ -259,6 +300,7 @@ int main(int argc, const char **argv) {

size_t rng_seed = numeric_limits<size_t>::max();
size_t iterations = 10;
size_t batch = 10; // MCMC iterations

string paramfile;
string pathfile;
Expand All @@ -273,6 +315,7 @@ int main(int argc, const char **argv) {
opt_parse.add_opt("seed", 's', "rng seed", false, rng_seed);
opt_parse.add_opt("iterations", 'i', "number of iterations",
false, iterations);
opt_parse.add_opt("batch", 'B', "number of MCMC iteration", false, batch);
opt_parse.add_opt("param", 'p', "output file of parameters",
false, paramfile);
opt_parse.add_opt("outtree", 't', "output file of tree",
Expand Down Expand Up @@ -363,9 +406,9 @@ int main(int argc, const char **argv) {

for (size_t itr = 0; itr < iterations; itr++) {
if (!optimize_branches)
estimate_rates(J, D, rates, th);
estimate_rates_indep(J, D, rates, th);
else {
estimate_rates_and_branches(J, D, rates, th, paths);
estimate_rates_and_branches_indep(J, D, rates, th, paths);
the_tree.set_branch_lengths(th.branches);
}

Expand All @@ -376,8 +419,8 @@ int main(int argc, const char **argv) {
}

/* Re-sample a better initial path */
vector<vector<Path> > sampled_paths(paths);
sample_paths(rates, th, paths, gen, sampled_paths);
vector<vector<double> > J_trip, D_trip;
sample_summary_stats(rates, th, paths, J_trip, D_trip, gen, batch);

/*******************************************************/
/* Generate initial parameters of context-dependent model */
Expand All @@ -392,15 +435,16 @@ int main(int argc, const char **argv) {

// re-estimate triplet rates from paths
if (!optimize_branches) {
estimate_rates(false, param_tol, sampled_paths, the_model);
estimate_rates(false, param_tol, J_trip, D_trip, the_model);
set_one_change_per_site_per_unit_time(the_model.triplet_rates,
th.branches);
}
else {
estimate_rates_and_branches(false, param_tol, sampled_paths, th, the_model);
estimate_rates_and_branches(false, param_tol, J_trip, D_trip, th,
the_model);
the_tree.set_branch_lengths(th.branches);
}
scale_jump_times(sampled_paths, th);
scale_jump_times(paths, th);

// write path file
if (VERBOSE)
Expand All @@ -409,7 +453,7 @@ int main(int argc, const char **argv) {
write_root_to_pathfile_local(pathfile, th.node_names.front());
for (size_t node_id = 1; node_id < th.n_nodes; ++node_id)
append_to_pathfile_local(pathfile, th.node_names[node_id],
sampled_paths, node_id);
paths, node_id);

// write parameters
if (VERBOSE)
Expand Down

0 comments on commit 60261ef

Please sign in to comment.