Skip to content

Commit

Permalink
adding final model train loop
Browse files Browse the repository at this point in the history
  • Loading branch information
cmaceves committed Aug 21, 2024
1 parent b3283e0 commit 5a77f58
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 78 deletions.
145 changes: 67 additions & 78 deletions src/gmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cmath>
#include <algorithm>
#include <limits>
#include <unordered_map>

void assign_clusters(std::vector<variant> &variants, gaussian_mixture_model gmodel){
std::vector<std::vector<double>> tv = transpose_vector(gmodel.prob_matrix);
Expand All @@ -17,96 +18,74 @@ void assign_clusters(std::vector<variant> &variants, gaussian_mixture_model gmod
assign_variants_simple(variants, gmodel.prob_matrix, index, gmodel.means);

}
std::vector<double> count_repeated_values(const std::vector<double>& vec) {
std::unordered_map<double, int> frequency_map;
for (double value : vec) {
frequency_map[value]++;
}

std::vector<double> repeats;
for (const auto& pair : frequency_map) {
if (pair.second > 1) {
repeats.push_back(pair.first);
}
}
return repeats;
}

gaussian_mixture_model retrain_model(uint32_t n, arma::mat data, float var_floor, float universal_cluster, float noise_cluster, bool fix_clusters){
gaussian_mixture_model gmodel;
arma::mat mmeans;
std::vector<double> kmeans;
std::vector<uint32_t> n_counts(n, 0);
bool status2 = arma::kmeans(mmeans, data, n, arma::random_spread, 15, true);
for(auto m : mmeans){
std::cerr << m << std::endl;
kmeans.push_back((double)m);
}
for(auto d : data){
double dd = (double)d;
auto it = std::min_element(kmeans.begin(), kmeans.end(),
[dd](double a, double b) {
return std::abs(a - dd) < std::abs(b - dd);
});
uint32_t index = std::distance(kmeans.begin(), it);
n_counts[index] += 1;
//std::cerr << d << " " << index << std::endl;
}
n = 0;
std::vector<double> means;
std::cerr << "percents" << std::endl;
for(uint32_t i =0; i < mmeans.size(); i++){
double percent_assigned = n_counts[i] / (double)data.size();
std::cerr << percent_assigned << std::endl;
if(percent_assigned > 0.10){
means.push_back((double) mmeans[i]);
n += 1;
}
}
float dcov_first = 0.0001;
float dcov_second = 0.0001;
gmodel.n = n;
float dcov_first = 0.001;
var_floor = 0.01;
auto min_iterator = std::min_element(means.begin(), means.end());
uint32_t min_index = std::distance(means.begin(), min_iterator);
auto max_iterator = std::max_element(means.begin(), means.end());
uint32_t max_index = std::distance(means.begin(), max_iterator);

arma::mat mean_fill (1, n, arma::fill::zeros);
for(uint32_t l=0; l < n; l++){
if(l == min_index){
mean_fill.col(l) = noise_cluster;
} else if(l == max_index){
mean_fill.col(l) = universal_cluster;
} else if(means[l] > universal_cluster || means[l] < noise_cluster){
continue;
} else{
mean_fill.col(l) = means[l];
}
}

//gaussian_mixture_model gmodel;
arma::mat cov (1, n, arma::fill::zeros);
for(uint32_t l=0; l < n;l++){
if(means[l] >= universal_cluster){
cov.col(l) = dcov_first;
} else if (means[l] <= noise_cluster) {
cov.col(l) = dcov_first;
}else {
cov.col(l) = dcov_second;
}
}
float val = 1 / (float)n;
arma::rowvec hef (n);
hef.fill(val);
arma::gmm_diag model;
model.reset(1, n);
model.set_means(mean_fill);
model.set_hefts(hef);
model.set_dcovs(cov);
std::cerr << "pre dcovs " << model.dcovs << std::endl;
std::cerr << "pre hefts " << model.hefts << std::endl;
std::cerr << universal_cluster << " " << noise_cluster << " " << fix_clusters << std::endl;

//train model
bool status = model.learn(data, n, arma::eucl_dist, arma::keep_existing, 10, 15, var_floor, false);
bool status = model.learn(data, n, arma::eucl_dist, arma::static_subset, 10, 15, var_floor,false);
if(!status){
std::cerr << "model failed to converge" << std::endl;
}
std::cerr << "means " << model.means << std::endl;
std::cerr << "hefts " << model.hefts << std::endl;
std::cerr << "dcovs " << model.dcovs << std::endl;
means.clear();

std::vector<double> means;
std::vector<double> hefts;
for(uint32_t i=0; i < model.means.size(); i++){
means.push_back((double)model.means[i]);

for(auto m : model.means){
double factor = std::pow(10.0, 2);
double rounded = std::round((double)m * factor) / factor;
means.push_back(rounded);
}
std::vector<double> repeats = count_repeated_values(means);
auto min_iterator = std::min_element(means.begin(), means.end());
uint32_t min_index = std::distance(means.begin(), min_iterator);
auto max_iterator = std::max_element(means.begin(), means.end());
uint32_t max_index = std::distance(means.begin(), max_iterator);
arma::mat mean_fill2 (1, n, arma::fill::zeros);
for(uint32_t i=0; i < means.size(); i++){
if(i == min_index || means[i] < noise_cluster){
mean_fill2.col(i) = noise_cluster;
means[i] = (double)0.03;
} else if(i == max_index || means[i] > universal_cluster){
mean_fill2.col(i) = universal_cluster;;
means[i] = (double)0.97;
}else{
mean_fill2.col(i) = means[i];
}
hefts.push_back((double)model.hefts[i]);
//if this is not one of the repeats but is < 0.05 of the heft, remove it
auto it = std::find(repeats.begin(), repeats.end(), means[i]);
if((double)model.hefts[i] < 0.05 && it == repeats.end()){
n -= 1;
}
}

for(auto m : means){
std::cerr << m << std::endl;
}
std::cerr << "repeats " << repeats.size() << std::endl;
n -= repeats.size();
model.set_means(mean_fill2);
std::vector<std::vector<double>> prob_matrix;
std::vector<double> tmp;
for(uint32_t i=0; i < n; i++){
Expand All @@ -117,8 +96,8 @@ gaussian_mixture_model retrain_model(uint32_t n, arma::mat data, float var_floor
}
prob_matrix.push_back(tmp);
}
gmodel.new_n = n;
gmodel.prob_matrix = prob_matrix;

gmodel.means = means;
gmodel.hefts = hefts;
return(gmodel);
Expand All @@ -133,7 +112,7 @@ gaussian_mixture_model train_model(uint32_t n, arma::mat data, float var_floor,
std::vector<double> hefts;

//train model
bool status = model.learn(data, n, arma::eucl_dist, arma::random_spread, 10, 20, var_floor, false);
bool status = model.learn(data, n, arma::eucl_dist, arma::static_subset, 10, 15, var_floor, false);
if(!status){
std::cerr << "model failed to converge" << std::endl;
}
Expand Down Expand Up @@ -940,7 +919,7 @@ std::vector<variant> gmm_model(std::string prefix, uint32_t n, std::string outp
parse_internal_variants(prefix, base_variants, depth_cutoff, lower_bound, upper_bound, deletion_positions, low_quality_positions, round_val);
for(uint32_t i=0; i < base_variants.size(); i++){
if(!base_variants[i].amplicon_flux && !base_variants[i].depth_flag && !base_variants[i].outside_freq_range && !base_variants[i].qual_flag && !base_variants[i].del_flag && !base_variants[i].amplicon_masked && !base_variants[i].primer_masked){
variants.push_back(base_variants[i]);
variants.push_back(base_variants[i]);
}
}

Expand All @@ -951,7 +930,17 @@ std::vector<variant> gmm_model(std::string prefix, uint32_t n, std::string outp
float var_floor = 0.01;
arma::gmm_diag model;
gaussian_mixture_model gmodel = train_model(n, data, var_floor, universal_cluster, noise_cluster, true, model);
uint32_t original_n=6;
uint32_t new_n=6;
gaussian_mixture_model retrained = retrain_model(n, data, var_floor, universal_cluster, noise_cluster, true);
new_n = retrained.new_n;
while(original_n != new_n){
std::cerr << "original n " << original_n << " new n " << new_n << std::endl;
gaussian_mixture_model retrained = retrain_model(new_n, data, var_floor, universal_cluster, noise_cluster, true);
original_n = retrained.n;
new_n = retrained.new_n;
}

assign_clusters(variants, retrained);
/*
std::vector<variant> retraining_set;
Expand Down
1 change: 1 addition & 0 deletions src/gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
struct gaussian_mixture_model {
std::vector<std::vector<double>> prob_matrix;
uint32_t n;
uint32_t new_n;
float var_floor;
std::vector<double> means;
std::vector<double> hefts;
Expand Down

0 comments on commit 5a77f58

Please sign in to comment.