Skip to content

Commit

Permalink
enable saving (dense) predictions as npy files
Browse files Browse the repository at this point in the history
  • Loading branch information
ngc92 committed Sep 7, 2022
1 parent 2564807 commit 62cbe70
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 11 deletions.
44 changes: 43 additions & 1 deletion src/io/numpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,27 @@ namespace {
THROW_ERROR("Currently, only row-major npy files can be read");
}

// load the matrix row-by-row, to make sure this works even if Eigen decides to include padding
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> target(header.Rows, header.Cols);
io::binary_load(source, target.data(), target.data() + header.Rows * header.Cols);
for(int row = 0; row < target.rows(); ++row) {
auto row_data = target.row(row);
io::binary_load(source, row_data.data(), row_data.data() + row_data.size());
}

return target;
}

template<class T>
void save_matrix_to_npy_imp(std::streambuf& target, const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
io::write_npy_header(target, io::make_npy_description(matrix));

// save the matrix row-by-row, to make sure this works even if Eigen decides to include padding
for(int row = 0; row < matrix.rows(); ++row) {
const auto& row_data = matrix.row(row);
io::binary_dump(target, row_data.data(), row_data.data() + row_data.size());
}
}

}

Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> io::load_matrix_from_npy(std::istream& source) {
Expand All @@ -335,6 +351,20 @@ Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> io::load_
return load_matrix_from_npy(file);
}

void io::save_matrix_to_npy(std::ostream& source,
const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
save_matrix_to_npy_imp(*source.rdbuf(), matrix);
}

void io::save_matrix_to_npy(const std::string& path,
const Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& matrix) {
std::ofstream file(path);
if(!file.is_open()) {
THROW_ERROR("Could not open file {} for writing.", path)
}
return save_matrix_to_npy(file, matrix);
}


#include "doctest.h"
#include <sstream>
Expand Down Expand Up @@ -491,4 +521,16 @@ TEST_CASE("make description") {
CHECK(io::make_npy_description("<f8", false, 5) == "{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (5,)}");
CHECK(io::make_npy_description(">i4", true, 17) == "{\"descr\": \">i4\", \"fortran_order\": True, \"shape\": (17,)}");
CHECK(io::make_npy_description("<f8", false, 7, 5) == "{\"descr\": \"<f8\", \"fortran_order\": False, \"shape\": (7, 5)}");
}

TEST_CASE("save/load round trip") {
std::ostringstream save_stream;
types::DenseRowMajor<real_t> matrix = types::DenseRowMajor<real_t>::Random(4, 5);
io::save_matrix_to_npy(save_stream, matrix);

std::istringstream load_stream;
load_stream.str(save_stream.str());
auto ref = io::load_matrix_from_npy(load_stream);

CHECK( matrix == ref );
}
10 changes: 8 additions & 2 deletions src/io/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,14 @@ namespace dismec::io {
/*!
* \brief Loads a matrix from a numpy array.
*/
Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> load_matrix_from_npy(std::istream& source);
Eigen::Matrix<real_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> load_matrix_from_npy(const std::string& path);
types::DenseRowMajor<real_t> load_matrix_from_npy(std::istream& source);
types::DenseRowMajor<real_t> load_matrix_from_npy(const std::string& path);

/*!
* \brief Saves a matrix to a numpy array.
*/
void save_matrix_to_npy(std::ostream& source, const types::DenseRowMajor<real_t>&);
void save_matrix_to_npy(const std::string& path, const types::DenseRowMajor<real_t>&);
}

#endif //DISMEC_NUMPY_H
17 changes: 13 additions & 4 deletions src/io/prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "io/prediction.h"
#include "io/common.h"
#include "io/numpy.h"
#include <fstream>

using namespace dismec;
Expand Down Expand Up @@ -99,17 +100,25 @@ std::pair<IndexMatrix, PredictionMatrix> prediction::read_sparse_prediction(cons
return read_sparse_prediction(stream);
}

void prediction::save_dense_predictions(const path& target, const PredictionMatrix & values) {
void prediction::save_dense_predictions_as_txt(const path& target, const PredictionMatrix & values) {
std::fstream file(target, std::fstream::out);
save_dense_predictions(file, values);
save_dense_predictions_as_txt(file, values);
}

void prediction::save_dense_predictions(std::ostream& target, const PredictionMatrix& values) {
void prediction::save_dense_predictions_as_txt(std::ostream& target, const PredictionMatrix& values) {
target << values.rows() << " " << values.cols() << "\n";
for(int row = 0; row < values.rows(); ++row) {
io::write_vector_as_text(target, values.row(row)) << '\n';
}
}
void prediction::save_dense_predictions_as_npy(const path& target, const PredictionMatrix & values) {
std::fstream file(target, std::fstream::out);
save_dense_predictions_as_npy(file, values);
}

void prediction::save_dense_predictions_as_npy(std::ostream& target, const PredictionMatrix& values) {
io::save_matrix_to_npy(target, values);
}

#include "doctest.h"

Expand Down Expand Up @@ -220,6 +229,6 @@ TEST_CASE("save_dense_predictions")
"1.5 0.9 0.4\n";

std::stringstream target;
prediction::save_dense_predictions(target, values);
prediction::save_dense_predictions_as_txt(target, values);
CHECK(target.str() == as_text);
}
15 changes: 12 additions & 3 deletions src/io/prediction.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,19 @@ namespace dismec::io::prediction
* \param target Path to the file which will be created or overwritten, or output stream.
* \param values Matrix with the results. Each row corresponds to an instance and each column to a label.
*/
void save_dense_predictions(const path& target, const PredictionMatrix& values);
void save_dense_predictions_as_txt(const path& target, const PredictionMatrix& values);

/// \copydoc save_dense_predictions()
void save_dense_predictions(std::ostream& target, const PredictionMatrix& values);
/// \copydoc save_dense_predictions_as_txt()
void save_dense_predictions_as_txt(std::ostream& target, const PredictionMatrix& values);
/*!
* \brief Saves predictions as a dense npy file.
* \param target Path to the file which will be created or overwritten, or output stream.
* \param values Matrix with the results. Each row corresponds to an instance and each column to a label.
*/
void save_dense_predictions_as_npy(const path& target, const PredictionMatrix& values);

/// \copydoc save_dense_predictions_as_npy()
void save_dense_predictions_as_npy(std::ostream& target, const PredictionMatrix& values);
}

#endif //DISMEC_IO_PREDICTION_H
8 changes: 7 additions & 1 deletion src/predict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ int main(int argc, const char** argv) {
std::filesystem::path save_metrics;
int threads = -1;
int top_k = 5;
bool save_as_npy = false;

DataProcessing DataProc;
DataProc.setup_data_args(app);
Expand All @@ -101,6 +102,7 @@ int main(int argc, const char** argv) {
app.add_option("--save-metrics", save_metrics, "Target file in which the metric values are saved");
app.add_option("--topk, --top-k", top_k, "Only the top k predictions will be saved. "
"Set to -1 if you need all predictions. (Warning: This may result in very large files!)");
app.add_flag("--save-as-npy", save_as_npy, "Save the predictions as a numpy file instead of plain text.");
int Verbose = 0;
app.add_flag("-v", Verbose);

Expand Down Expand Up @@ -232,6 +234,10 @@ int main(int argc, const char** argv) {
}
const auto& predictions = task.get_predictions();

io::prediction::save_dense_predictions(result_file, predictions);
if(save_as_npy) {
io::prediction::save_dense_predictions_as_npy(result_file, predictions);
} else {
io::prediction::save_dense_predictions_as_txt(result_file, predictions);
}
}
}
2 changes: 2 additions & 0 deletions src/prediction/prediction.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace dismec::prediction {
const DatasetBase* m_Data; //!< Data on which the prediction is run
std::shared_ptr<const Model> m_Model; //!< Model (possibly partial) for which prediction is run

/// This function resizes the internal thread local feature buffer to correspond to the number of threads.
/// This needs to be called before any call to `init_thread()`.
void make_thread_local_features(long num_threads);

void init_thread(thread_id_t thread_id) final;
Expand Down

0 comments on commit 62cbe70

Please sign in to comment.