Skip to content

Commit

Permalink
Add gradient computation
Browse files Browse the repository at this point in the history
  • Loading branch information
jqujqu committed Feb 1, 2018
1 parent d25940c commit f01b30c
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions src/prog/estparam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void get_jumps(const vector<Path> &paths, vector<Jump> &jumps) {
}

void get_suff_stat(const vector<Jump> &jumps,
vector<size_t> &tot_freq,
vector<double> &tot_freq,
vector<double> &weights) {
tot_freq.resize(8, 0);
weights.resize(8, 0);
Expand All @@ -91,7 +91,7 @@ void get_suff_stat(const vector<Jump> &jumps,


double llk(const vector<Jump> &jumps,
const vector<size_t> &tot_freq,
const vector<double> &tot_freq,
const vector<double> &weights,
const vector<double> &rates) {
double l = 0;
Expand All @@ -103,6 +103,34 @@ double llk(const vector<Jump> &jumps,
}


void grad_llk(const vector<Jump> &jumps,
const vector<double> &tot_freq,
const vector<double> &weights,
const vector<double> &rates,
vector<double> &gradient) {
gradient.resize(8, 0);

gradient[0] = (tot_freq[0] + tot_freq[7])/rates[0] - weights[0] -
weights[7]*rates[7]/rates[0];

gradient[2] = (tot_freq[2] - tot_freq[7])/rates[2] - weights[2] +
weights[7]*rates[7]/rates[2];

gradient[5] = (tot_freq[5] + tot_freq[7])/rates[5] - weights[5] -
weights[7]*rates[7]/rates[5];

gradient[1] = (tot_freq[1] + tot_freq[4] - 2*tot_freq[7])/rates[1] -
(weights[1] + weights[4]) + 2*weights[7]*rates[7]/rates[1];

gradient[3] = (tot_freq[3] + tot_freq[6] + 2*tot_freq[7])/rates[3] -
(weights[3] + weights[6]) - 2*weights[7]*rates[7]/rates[3];

gradient[4] = gradient[1];
gradient[6] = gradient[3];
// gradient[7] remains 0, and rates[7] is determined by other rates.
}


void est_trans_prob(const vector<bool> &seq,
vector<vector<double> > & trans_prob) {
double c00(0), c01(0), c10(0), c11(0), c0(0), c1(1);
Expand Down Expand Up @@ -203,7 +231,7 @@ int main(int argc, const char **argv) {
if (!param_file.empty()) {
model_param p;
read_param(param_file, p);
vector<size_t> tot_freq;
vector<double> tot_freq;
vector<double> weights;
get_suff_stat(jumps, tot_freq, weights);

Expand All @@ -230,6 +258,20 @@ int main(int argc, const char **argv) {
for (size_t i = 0; i < rates.size(); ++i) cerr << rates[i] << "\t";
cerr << endl;
cerr << "log-likelihood= " << llk(jumps, tot_freq, weights, rates) << endl;

cerr << "tot_freq:" << endl;
for (size_t i =0; i < 8; ++i)
cerr << "[" << i << "]\t" << tot_freq[i] << endl;

cerr << "weights:" << endl;
for (size_t i =0; i < 8; ++i)
cerr << "[" << i << "]\t" << weights[i] << endl;

vector<double> gradient;
grad_llk(jumps, tot_freq, weights, rates, gradient);
cerr << "gradient :" << endl;
for (size_t i =0; i < gradient.size(); ++i)
cerr << "[" << i << "]\t" << gradient[i] << endl;
}

}
Expand Down

0 comments on commit f01b30c

Please sign in to comment.