Skip to content

Commit

Permalink
Add priority sort to filterdb
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-steinegger committed Oct 29, 2024
1 parent 7c8c4a6 commit 54f8983
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/commons/Parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Parameters::Parameters():
PARAM_EXTRACT_LINES(PARAM_EXTRACT_LINES_ID, "--extract-lines", "Extract N lines", "Extract n lines of each entry", typeid(int), (void *) &extractLines, "^[1-9]{1}[0-9]*$"),
PARAM_COMP_OPERATOR(PARAM_COMP_OPERATOR_ID, "--comparison-operator", "Numerical comparison operator", "Filter by comparing each entry row numerically by using the le) less-than-equal, ge) greater-than-equal or e) equal operator", typeid(std::string), (void *) &compOperator, ""),
PARAM_COMP_VALUE(PARAM_COMP_VALUE_ID, "--comparison-value", "Numerical comparison value", "Filter by comparing each entry to this value", typeid(double), (void *) &compValue, "^.*$"),
PARAM_SORT_ENTRIES(PARAM_SORT_ENTRIES_ID, "--sort-entries", "Sort entries", "Sort column set by --filter-column, by 0: no sorting, 1: increasing, 2: decreasing, 3: random shuffle", typeid(int), (void *) &sortEntries, "^[1-9]{1}[0-9]*$"),
PARAM_SORT_ENTRIES(PARAM_SORT_ENTRIES_ID, "--sort-entries", "Sort entries", "Sort column set by --filter-column, by 0: no sorting, 1: increasing, 2: decreasing, 3: random shuffle, 4: priority", typeid(int), (void *) &sortEntries, "^[0-4]{1}$"),
PARAM_BEATS_FIRST(PARAM_BEATS_FIRST_ID, "--beats-first", "Beats first", "Filter by comparing each entry to the first entry", typeid(bool), (void *) &beatsFirst, ""),
PARAM_JOIN_DB(PARAM_JOIN_DB_ID, "--join-db", "join to DB", "Join another database entry with respect to the database identifier in the chosen column", typeid(std::string), (void *) &joinDB, ""),
// besthitperset
Expand Down Expand Up @@ -866,6 +866,7 @@ Parameters::Parameters():
filterDb.push_back(&PARAM_FILTER_FILE);
filterDb.push_back(&PARAM_BEATS_FIRST);
filterDb.push_back(&PARAM_MAPPING_FILE);
filterDb.push_back(&PARAM_WEIGHT_FILE);
filterDb.push_back(&PARAM_TRIM_TO_ONE_COL);
filterDb.push_back(&PARAM_EXTRACT_LINES);
filterDb.push_back(&PARAM_COMP_OPERATOR);
Expand Down
50 changes: 46 additions & 4 deletions src/util/filterdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <fstream>
#include <random>
#include <iostream>
#include <unordered_map>

#include <regex.h>

Expand Down Expand Up @@ -51,6 +52,8 @@ ComparisonOperator mapOperator(const std::string& op) {
#define INCREASING 1
#define DECREASING 2
#define SHUFFLE 3
#define PRIORITY 4


struct compareString {
bool operator() (const std::string& lhs, const std::string& rhs) const{
Expand Down Expand Up @@ -84,6 +87,8 @@ struct compareFirstEntryDecreasing {

int filterdb(int argc, const char **argv, const Command &command) {
Parameters &par = Parameters::getInstance();
par.PARAM_WEIGHT_FILE.replaceCategory(MMseqsParameter::COMMAND_MISC);

par.parseParameters(argc, argv, command, true, 0, 0);

const size_t column = static_cast<size_t>(par.filterColumn);
Expand All @@ -108,7 +113,7 @@ int filterdb(int argc, const char **argv, const Command &command) {

// JOIN_DB
DBReader<unsigned int>* helper = NULL;

std::unordered_map<unsigned int, float> weights;
// REGEX_FILTERING
regex_t regex;
std::random_device rng;
Expand All @@ -117,6 +122,32 @@ int filterdb(int argc, const char **argv, const Command &command) {
if (par.sortEntries != 0) {
mode = SORT_ENTRIES;
Debug(Debug::INFO) << "Filtering by sorting entries\n";
if (par.sortEntries == PRIORITY) {
if (par.weightFile.empty()) {
Debug(Debug::ERROR) << "Weights file (--weights) must be specified for priority sorting.\n";
EXIT(EXIT_FAILURE);
}
Debug(Debug::INFO) << "Sorting entries by priority\n";
// Read the weights
std::ifstream weightsFile(par.weightFile);
if (!weightsFile) {
Debug(Debug::ERROR) << "Cannot open weights file " << par.weightFile << "\n";
EXIT(EXIT_FAILURE);
}

std::string line;
while (std::getline(weightsFile, line)) {
std::istringstream iss(line);
unsigned int key;
float weight;
if (!(iss >> key >> weight)) {
Debug(Debug::WARNING) << "Invalid line in weights file: " << line << "\n";
continue;
}
weights[key] = weight;
}
weightsFile.close();
}
} else if (par.filteringFile.empty() == false) {
mode = FILE_FILTERING;
Debug(Debug::INFO) << "Filtering using file(s)\n";
Expand Down Expand Up @@ -453,8 +484,19 @@ int filterdb(int argc, const char **argv, const Command &command) {
memcpy(lineBuffer, newLineBuffer, newLineBufferIndex + 1);
}
} else if (mode == SORT_ENTRIES) {
toSort.emplace_back(std::strtod(columnValue, NULL), lineBuffer);
// do not put anything in the output buffer
if (par.sortEntries == PRIORITY) {
unsigned int key = static_cast<unsigned int>(strtoul(columnPointer[column - 1], NULL, 10));
float weight = 0.0f;
auto it = weights.find(key);
if (it != weights.end()) {
weight = it->second;
}
toSort.emplace_back(weight, std::string(lineBuffer));
} else {
// Existing code
toSort.emplace_back(std::strtod(columnValue, NULL), lineBuffer);
}
// Do not put anything in the output buffer
nomatch = 1;
} else {
// Unknown filtering mode, keep all entries
Expand Down Expand Up @@ -482,7 +524,7 @@ int filterdb(int argc, const char **argv, const Command &command) {
if (mode == SORT_ENTRIES) {
if (par.sortEntries == INCREASING) {
std::stable_sort(toSort.begin(), toSort.end(), compareFirstEntry());
} else if (par.sortEntries == DECREASING) {
} else if (par.sortEntries == DECREASING || par.sortEntries == PRIORITY) {
std::stable_sort(toSort.begin(), toSort.end(), compareFirstEntryDecreasing());
} else if (par.sortEntries == SHUFFLE) {
std::shuffle(toSort.begin(), toSort.end(), urng);
Expand Down

0 comments on commit 54f8983

Please sign in to comment.