diff --git a/include/OGL/MatrixWrapper/Distributed.H b/include/OGL/MatrixWrapper/Distributed.H index 78e30a81..053bf9e2 100644 --- a/include/OGL/MatrixWrapper/Distributed.H +++ b/include/OGL/MatrixWrapper/Distributed.H @@ -114,11 +114,13 @@ public: std::shared_ptr repartitioner, std::shared_ptr host_A); + word get_matrix_format() const { return matrix_format_; } + std::shared_ptr get_dist_mtx() const { return dist_mtx_; } RepartDistMatrix( std::shared_ptr exec, communicator comm, - std::shared_ptr dist_mtx, + word matrix_format, std::shared_ptr dist_mtx, std::shared_ptr local_sparsity, std::shared_ptr non_local_sparsity, std::shared_ptr src_comm_pattern, @@ -127,6 +129,7 @@ public: std::vector &local_interfaces) : gko::experimental::EnableDistributedLinOp(exec), gko::experimental::distributed::DistributedBase(comm), + matrix_format_(matrix_format), dist_mtx_(std::move(dist_mtx)), local_sparsity_(local_sparsity), non_local_sparsity_(non_local_sparsity), @@ -140,7 +143,7 @@ public: template void write(const ExecutorHandler &exec_handler, const word field_name_, - const objectRegistry &db_) const; + const objectRegistry &db_, bool write_global) const; // Needed for distributed/polymorphic_object.hpp RepartDistMatrix(std::shared_ptr exec, @@ -193,6 +196,8 @@ protected: private: + const word matrix_format_; + std::shared_ptr dist_mtx_; std::shared_ptr local_sparsity_; @@ -214,7 +219,7 @@ std::shared_ptr get_local( void write_distributed(const ExecutorHandler &exec_handler, word field_name, const objectRegistry &db, std::shared_ptr dist_A, - word matrix_format); + bool write_global); void update_distributed(const ExecutorHandler &exec_handler, std::shared_ptr repartitioner, diff --git a/include/OGL/lduLduBase.H b/include/OGL/lduLduBase.H index 6df75694..237fdd6d 100644 --- a/include/OGL/lduLduBase.H +++ b/include/OGL/lduLduBase.H @@ -264,10 +264,12 @@ public: bool export_system( solver_controls_.lookupOrDefault("export", false)); if (export_system && db_.time().writeTime()) { + bool write_global( + solver_controls_.lookupOrDefault("writeGlobal", false)); LOG_0(verbose_, "Export system") // dist_b.write(); write_distributed(exec_handler_, this->fieldName(), db_, dist_A_v, - "Coo"); + write_global); } LOG_1(verbose_, "start create solver") diff --git a/src/MatrixWrapper/Distributed.C b/src/MatrixWrapper/Distributed.C index f9f6512a..e388876f 100644 --- a/src/MatrixWrapper/Distributed.C +++ b/src/MatrixWrapper/Distributed.C @@ -45,47 +45,38 @@ std::vector> generate_inner_linops( template void RepartDistMatrix::write(const ExecutorHandler &exec_handler, - const word field_name, - const objectRegistry &db) const + const word field_name, const objectRegistry &db, + bool write_global) const { + auto local = gko::share( + gko::matrix::Coo::create(exec_handler.get_ref_exec())); + auto non_local = gko::share( + gko::matrix::Coo::create(exec_handler.get_ref_exec())); + if (repartitioner_->get_fused()) { - auto ret_local = - gko::as(dist_mtx_->get_local_matrix()); - auto coo_local = gko::share(gko::matrix::Coo::create( - exec_handler.get_ref_exec())); - ret_local->convert_to(coo_local.get()); - export_mtx(field_name + "_local", coo_local, db); - - auto ret_non_local = - gko::as(dist_mtx_->get_non_local_matrix()); - auto coo_non_local = gko::share(gko::matrix::Coo::create( - exec_handler.get_ref_exec())); - ret_non_local->convert_to(coo_non_local.get()); - export_mtx(field_name + "_non_local", coo_non_local, db); + gko::as(dist_mtx_->get_local_matrix()) + ->convert_to(local.get()); + gko::as(dist_mtx_->get_non_local_matrix()) + ->convert_to(non_local.get()); } else { - auto ret = gko::share(gko::matrix::Coo::create( - exec_handler.get_ref_exec())); gko::as>( dist_mtx_->get_local_matrix()) - ->convert_to(ret.get()); - export_mtx(field_name + "_local", ret, db); - - auto non_loc_ret = gko::share(gko::matrix::Coo::create( - exec_handler.get_ref_exec())); + ->convert_to(local.get()); gko::as>( dist_mtx_->get_non_local_matrix()) - ->convert_to(non_loc_ret.get()); - - // overwrite with global - bool write_global = false; - if (write_global) { - std::copy(non_local_sparsity_->col_idxs.get_const_data(), - non_local_sparsity_->col_idxs.get_const_data() + - non_local_sparsity_->num_nnz, - non_loc_ret->get_col_idxs()); - } - export_mtx(field_name + "_non_local", non_loc_ret, db); + ->convert_to(non_local.get()); } + + // overwrite column indices with global indices + if (write_global) { + std::copy(non_local_sparsity_->col_idxs.get_const_data(), + non_local_sparsity_->col_idxs.get_const_data() + + non_local_sparsity_->num_nnz, + non_local->get_col_idxs()); + } + + export_mtx(field_name + "_local", local, db); + export_mtx(field_name + "_non_local", non_local, db); } template @@ -592,7 +583,7 @@ template std::shared_ptr create_impl( const ExecutorHandler &exec_handler, std::shared_ptr repartitioner, - std::shared_ptr host_A) + std::shared_ptr host_A, word matrix_format) { using dist_mtx = gko::experimental::distributed::Matrix; @@ -672,22 +663,24 @@ std::shared_ptr create_impl( } return std::make_shared( - device_exec, comm, dist_A, repart_loc_sparsity, repart_non_loc_sparsity, - src_comm_pattern, repart_comm_pattern, repartitioner, local_interfaces); + device_exec, comm, matrix_format, dist_A, repart_loc_sparsity, + repart_non_loc_sparsity, src_comm_pattern, repart_comm_pattern, + repartitioner, local_interfaces); } void write_distributed(const ExecutorHandler &exec_handler, word field_name, const objectRegistry &db, std::shared_ptr dist_A, - word matrix_format) + bool write_global) { + auto matrix_format{dist_A->get_matrix_format()}; if (matrix_format == "Coo") { - return dist_A->write>(exec_handler, - field_name, db); + return dist_A->write>( + exec_handler, field_name, db, write_global); } if (matrix_format == "Csr") { - return dist_A->write>(exec_handler, - field_name, db); + return dist_A->write>( + exec_handler, field_name, db, write_global); } } @@ -719,15 +712,15 @@ std::shared_ptr create_distributed( { if (matrix_format == "Ell") { return create_impl>( - exec_handler, repartitioner, hostMatrix); + exec_handler, repartitioner, hostMatrix, matrix_format); } if (matrix_format == "Coo") { return create_impl>( - exec_handler, repartitioner, hostMatrix); + exec_handler, repartitioner, hostMatrix, matrix_format); } if (matrix_format == "Csr") { return create_impl>( - exec_handler, repartitioner, hostMatrix); + exec_handler, repartitioner, hostMatrix, matrix_format); } FatalErrorInFunction