Skip to content

Commit

Permalink
Fix bcaPrec
Browse files Browse the repository at this point in the history
  • Loading branch information
mwydmuch committed Jan 14, 2024
1 parent 74452a6 commit 87456ed
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ int main(int argc, char** argv) {
ofo(args);
else if (command == "bcaPrec")
bcaPrec(args);
// else if (command == "bcaF1")
// bcaF1(args);
else if (command == "testPredictionTime")
testPredictionTime(args);
else {
Expand Down
18 changes: 9 additions & 9 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,9 @@ std::vector<std::vector<Prediction>> Model::bcaPrec(SRMatrix& features, SRMatrix
// }
// }

tp.resize(m + 2, 0.0);
fp.resize(m + 2, 0.0);
fn.resize(m + 2, 0.0);
predPos.resize(m + 2, 0.0);
truePos.resize(m + 2, 0.0);
condPos.resize(m + 2, 0.0);

// Generate by predicting labels
if(!args.bcGreedy){
Expand All @@ -395,8 +395,8 @@ std::vector<std::vector<Prediction>> Model::bcaPrec(SRMatrix& features, SRMatrix

for(int j = 0; j < n; ++j){
for(auto &p : predictions[j]){
tp[p.label] += p.value;
fp[p.label] += (1 - p.value);
truePos[p.label] += p.value;
predPos[p.label] += 1;
}
}
}
Expand All @@ -418,17 +418,17 @@ std::vector<std::vector<Prediction>> Model::bcaPrec(SRMatrix& features, SRMatrix
printProgress(tj, n);
if(!args.bcGreedy){
for(auto& p : predictions[j]){
tp[p.label] -= p.value;
fp[p.label] -= (1 - p.value);
truePos[p.label] -= p.value;
predPos[p.label] -= 1;
}
}

predictions[j].clear();
this->predict(predictions[j], features[j], args);

for(auto& p : predictions[j]){
tp[p.label] += p.value;
fp[p.label] += (1 - p.value);
truePos[p.label] += p.value;
predPos[p.label] += 1;
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ class Model {
std::vector<Real> thresholds; // For prediction with thresholds
std::vector<Real> labelsWeights; // For prediction with label weights

std::vector<Real> tp;
std::vector<Real> fp;
std::vector<Real> fn;
std::vector<Real> truePos;
std::vector<Real> predPos;
std::vector<Real> condPos;
std::vector<double> a;
std::vector<double> b;

Expand Down
7 changes: 4 additions & 3 deletions src/models/plt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,15 @@ void PLT::predict(std::vector<Prediction>& prediction, SparseVector& features, A
// };
// }

if(!tp.empty() && !fp.empty()){
//Log(CERR) << "Using TP/FP scores ...\n";
// Precision
if(!truePos.empty() && !predPos.empty()){
//Log(CERR) << "Using Precision scores ...\n";
calculateValue = [&](TreeNode* node, Real prob) {

Real score = -9999999;
int bestL = -1;
for(auto& l : nodesLabels[node->index]){
Real tmpScore = (tp[l] + prob) / (fp[l] + 1) - tp[l] / (fp[l] + 0.000001);
Real tmpScore = (truePos[l] + prob) / (predPos[l] + 1) - truePos[l] / (predPos[l] + 0.000001);
if(tmpScore >= score){
bestL = l;
score = tmpScore;
Expand Down

0 comments on commit 87456ed

Please sign in to comment.