diff --git a/src/io/numpy.cpp b/src/io/numpy.cpp index c91727c..e1a44fd 100644 --- a/src/io/numpy.cpp +++ b/src/io/numpy.cpp @@ -316,11 +316,27 @@ namespace { THROW_ERROR("Currently, only row-major npy files can be read"); } + // load the matrix row-by-row, to make sure this works even if Eigen decides to include padding Eigen::Matrix target(header.Rows, header.Cols); - io::binary_load(source, target.data(), target.data() + header.Rows * header.Cols); + for(int row = 0; row < target.rows(); ++row) { + auto row_data = target.row(row); + io::binary_load(source, row_data.data(), row_data.data() + row_data.size()); + } + return target; } + template + void save_matrix_to_npy_imp(std::streambuf& target, const Eigen::Matrix& matrix) { + io::write_npy_header(target, io::make_npy_description(matrix)); + + // save the matrix row-by-row, to make sure this works even if Eigen decides to include padding + for(int row = 0; row < matrix.rows(); ++row) { + const auto& row_data = matrix.row(row); + io::binary_dump(target, row_data.data(), row_data.data() + row_data.size()); + } + } + } Eigen::Matrix io::load_matrix_from_npy(std::istream& source) { @@ -335,6 +351,20 @@ Eigen::Matrix io::load_ return load_matrix_from_npy(file); } +void io::save_matrix_to_npy(std::ostream& source, + const Eigen::Matrix& matrix) { + save_matrix_to_npy_imp(*source.rdbuf(), matrix); +} + +void io::save_matrix_to_npy(const std::string& path, + const Eigen::Matrix& matrix) { + std::ofstream file(path); + if(!file.is_open()) { + THROW_ERROR("Could not open file {} for writing.", path) + } + return save_matrix_to_npy(file, matrix); +} + #include "doctest.h" #include @@ -491,4 +521,16 @@ TEST_CASE("make description") { CHECK(io::make_npy_description("i4", true, 17) == "{\"descr\": \">i4\", \"fortran_order\": True, \"shape\": (17,)}"); CHECK(io::make_npy_description(" matrix = types::DenseRowMajor::Random(4, 5); + io::save_matrix_to_npy(save_stream, matrix); + + std::istringstream load_stream; + load_stream.str(save_stream.str()); + auto ref = io::load_matrix_from_npy(load_stream); + + CHECK( matrix == ref ); } \ No newline at end of file diff --git a/src/io/numpy.h b/src/io/numpy.h index 7a3642c..478c611 100644 --- a/src/io/numpy.h +++ b/src/io/numpy.h @@ -89,8 +89,14 @@ namespace dismec::io { /*! * \brief Loads a matrix from a numpy array. */ - Eigen::Matrix load_matrix_from_npy(std::istream& source); - Eigen::Matrix load_matrix_from_npy(const std::string& path); + types::DenseRowMajor load_matrix_from_npy(std::istream& source); + types::DenseRowMajor load_matrix_from_npy(const std::string& path); + + /*! + * \brief Saves a matrix to a numpy array. + */ + void save_matrix_to_npy(std::ostream& source, const types::DenseRowMajor&); + void save_matrix_to_npy(const std::string& path, const types::DenseRowMajor&); } #endif //DISMEC_NUMPY_H diff --git a/src/io/prediction.cpp b/src/io/prediction.cpp index 82db1b5..ba99b55 100644 --- a/src/io/prediction.cpp +++ b/src/io/prediction.cpp @@ -5,6 +5,7 @@ #include "io/prediction.h" #include "io/common.h" +#include "io/numpy.h" #include using namespace dismec; @@ -99,17 +100,25 @@ std::pair prediction::read_sparse_prediction(cons return read_sparse_prediction(stream); } -void prediction::save_dense_predictions(const path& target, const PredictionMatrix & values) { +void prediction::save_dense_predictions_as_txt(const path& target, const PredictionMatrix & values) { std::fstream file(target, std::fstream::out); - save_dense_predictions(file, values); + save_dense_predictions_as_txt(file, values); } -void prediction::save_dense_predictions(std::ostream& target, const PredictionMatrix& values) { +void prediction::save_dense_predictions_as_txt(std::ostream& target, const PredictionMatrix& values) { target << values.rows() << " " << values.cols() << "\n"; for(int row = 0; row < values.rows(); ++row) { io::write_vector_as_text(target, values.row(row)) << '\n'; } } +void prediction::save_dense_predictions_as_npy(const path& target, const PredictionMatrix & values) { + std::fstream file(target, std::fstream::out); + save_dense_predictions_as_npy(file, values); +} + +void prediction::save_dense_predictions_as_npy(std::ostream& target, const PredictionMatrix& values) { + io::save_matrix_to_npy(target, values); +} #include "doctest.h" @@ -220,6 +229,6 @@ TEST_CASE("save_dense_predictions") "1.5 0.9 0.4\n"; std::stringstream target; - prediction::save_dense_predictions(target, values); + prediction::save_dense_predictions_as_txt(target, values); CHECK(target.str() == as_text); } diff --git a/src/io/prediction.h b/src/io/prediction.h index c9d64b7..7eaae9d 100644 --- a/src/io/prediction.h +++ b/src/io/prediction.h @@ -51,10 +51,19 @@ namespace dismec::io::prediction * \param target Path to the file which will be created or overwritten, or output stream. * \param values Matrix with the results. Each row corresponds to an instance and each column to a label. */ - void save_dense_predictions(const path& target, const PredictionMatrix& values); + void save_dense_predictions_as_txt(const path& target, const PredictionMatrix& values); - /// \copydoc save_dense_predictions() - void save_dense_predictions(std::ostream& target, const PredictionMatrix& values); + /// \copydoc save_dense_predictions_as_txt() + void save_dense_predictions_as_txt(std::ostream& target, const PredictionMatrix& values); + /*! + * \brief Saves predictions as a dense npy file. + * \param target Path to the file which will be created or overwritten, or output stream. + * \param values Matrix with the results. Each row corresponds to an instance and each column to a label. + */ + void save_dense_predictions_as_npy(const path& target, const PredictionMatrix& values); + + /// \copydoc save_dense_predictions_as_npy() + void save_dense_predictions_as_npy(std::ostream& target, const PredictionMatrix& values); } #endif //DISMEC_IO_PREDICTION_H diff --git a/src/predict.cpp b/src/predict.cpp index 21edd70..3e994ba 100644 --- a/src/predict.cpp +++ b/src/predict.cpp @@ -91,6 +91,7 @@ int main(int argc, const char** argv) { std::filesystem::path save_metrics; int threads = -1; int top_k = 5; + bool save_as_npy = false; DataProcessing DataProc; DataProc.setup_data_args(app); @@ -101,6 +102,7 @@ int main(int argc, const char** argv) { app.add_option("--save-metrics", save_metrics, "Target file in which the metric values are saved"); app.add_option("--topk, --top-k", top_k, "Only the top k predictions will be saved. " "Set to -1 if you need all predictions. (Warning: This may result in very large files!)"); + app.add_flag("--save-as-npy", save_as_npy, "Save the predictions as a numpy file instead of plain text."); int Verbose = 0; app.add_flag("-v", Verbose); @@ -232,6 +234,10 @@ int main(int argc, const char** argv) { } const auto& predictions = task.get_predictions(); - io::prediction::save_dense_predictions(result_file, predictions); + if(save_as_npy) { + io::prediction::save_dense_predictions_as_npy(result_file, predictions); + } else { + io::prediction::save_dense_predictions_as_txt(result_file, predictions); + } } } diff --git a/src/prediction/prediction.h b/src/prediction/prediction.h index f021112..1ea8c05 100644 --- a/src/prediction/prediction.h +++ b/src/prediction/prediction.h @@ -40,6 +40,8 @@ namespace dismec::prediction { const DatasetBase* m_Data; //!< Data on which the prediction is run std::shared_ptr m_Model; //!< Model (possibly partial) for which prediction is run + /// This function resizes the internal thread local feature buffer to correspond to the number of threads. + /// This needs to be called before any call to `init_thread()`. void make_thread_local_features(long num_threads); void init_thread(thread_id_t thread_id) final;