Skip to content

Commit

Permalink
Added compatibility for a tree.
Browse files Browse the repository at this point in the history
  • Loading branch information
xjlizji committed Apr 16, 2020
1 parent 8214ea8 commit 85000f7
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions src/prog/average_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,43 @@ using std::runtime_error;
using std::numeric_limits;
using std::to_string;

typedef vector<vector<double> > matrix;

static void
add_paths(const vector<vector<Path> > &paths,
vector<vector<double> > &average_paths) {
const size_t n_points = average_paths[0].size();
const double bin = paths[1][0].tot_time/(n_points - 1);
for (size_t site_id = 0; site_id < paths[1].size(); site_id++) {
size_t prev_state = paths[1][site_id].init_state;
average_paths[site_id][0] += prev_state;
double curr_time = bin;
for (size_t i = 1; i < average_paths[site_id].size(); i++) {
average_paths[site_id][i] +=
paths[1][site_id].state_at_time(curr_time);
add_paths(const vector<vector<Path> > &paths, vector<matrix> &average_paths,
const size_t n_points) {
for (size_t b = 1; b < paths.size(); b++) {
const double bin = paths[b].front().tot_time/(n_points - 1);
for (size_t site_id = 0; site_id < paths[b].size(); site_id++) {
size_t prev_state = paths[b][site_id].init_state;
average_paths[b][site_id].front() += prev_state;
double curr_time = bin;
for (size_t i = 1; i < n_points; i++) {
average_paths[b][site_id][i] +=
paths[1][site_id].state_at_time(curr_time);
curr_time += bin;
}
}
}
}


static void
write_output(const string &outfile, const vector<vector<double> > &states) {
write_output(const string &outfile, const vector<matrix> &states,
const vector<string> &node_names,
const vector<double> &branch_len) {
std::ofstream out(outfile);
if (!out)
throw runtime_error("bad output file: " + outfile);

const size_t n_sites = states.size();
for (size_t site_id = 0; site_id < n_sites; ++site_id) {
out << states[site_id].front();
for (size_t t = 1; t < states[site_id].size(); t++)
out << "\t" << states[site_id][t];
out << endl;
out << "NODE:" << node_names[0] << endl;
for (size_t b = 1; b < node_names.size(); b++) {
out << "NODE:" << node_names[b] << "\t" << branch_len[b] << endl;
for (size_t site_id = 0; site_id < states[b].size(); ++site_id) {
out << states[b][site_id].front();
for (size_t t = 1; t < states[b][site_id].size(); t++)
out << "\t" << states[b][site_id][t];
out << endl;
}
}
}

Expand Down Expand Up @@ -122,28 +128,40 @@ int main(int argc, const char **argv) {

vector<vector<Path> > paths; // along multiple branches
vector<string> node_names;
vector<matrix> average_paths;
read_paths(files[0], node_names, paths);

const size_t n_nodes = paths.size();
const size_t n_sites = paths[1].size();
vector<vector<double> > average_paths(n_sites, vector<double>(n_points, 0));
add_paths(paths, average_paths);
vector<double> branch_len(n_nodes);
average_paths.resize(n_nodes);
for (size_t b = 1; b < paths.size(); b++) {
average_paths[b].resize(n_sites);
branch_len[b] = paths[b].front().tot_time;
for (size_t site_id = 0; site_id < paths[b].size(); site_id++)
average_paths[b][site_id].resize(n_points, 0.0);
}

add_paths(paths, average_paths, n_points);

for (size_t i = 1; i < n_files; i++) {
paths.clear(); // along multiple branches
vector<string> node_names;
read_paths(files[i], node_names, paths);
add_paths(paths, average_paths);
add_paths(paths, average_paths, n_points);
}

/* AVERAGING PATHS */
for (size_t site_id = 0; site_id < average_paths.size(); ++site_id)
transform(begin(average_paths[site_id]), end(average_paths[site_id]),
begin(average_paths[site_id]),
[&](const double x) {return x/n_files;});
for (size_t b = 1; b < n_nodes; b++)
for (size_t site_id = 0; site_id < n_sites; ++site_id)
transform(begin(average_paths[b][site_id]),
end(average_paths[b][site_id]),
begin(average_paths[b][site_id]),
[&](const double x) {return x/n_files;});

if (VERBOSE)
cerr << "[WRITING OUTPUT TO: ]" << outfile << "]" << endl;
write_output(outfile, average_paths);
cerr << "[WRITING OUTPUT TO: " << outfile << "]" << endl;
write_output(outfile, average_paths, node_names, branch_len);
}
catch (const std::exception &e) {
cerr << e.what() << endl;
Expand Down

0 comments on commit 85000f7

Please sign in to comment.