Skip to content

Commit

Permalink
unify write out, store matrix format as word, read writeGlobal
Browse files Browse the repository at this point in the history
  • Loading branch information
greole committed Oct 10, 2024
1 parent fb657bf commit 23fc1df
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 48 deletions.
11 changes: 8 additions & 3 deletions include/OGL/MatrixWrapper/Distributed.H
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ public:
std::shared_ptr<const Repartitioner> repartitioner,
std::shared_ptr<const HostMatrixWrapper> host_A);

word get_matrix_format() const { return matrix_format_; }

std::shared_ptr<const gko::LinOp> get_dist_mtx() const { return dist_mtx_; }

RepartDistMatrix(
std::shared_ptr<const gko::Executor> exec, communicator comm,
std::shared_ptr<dist_mtx> dist_mtx,
word matrix_format, std::shared_ptr<dist_mtx> dist_mtx,
std::shared_ptr<const SparsityPattern> local_sparsity,
std::shared_ptr<const SparsityPattern> non_local_sparsity,
std::shared_ptr<const CommunicationPattern> src_comm_pattern,
Expand All @@ -127,6 +129,7 @@ public:
std::vector<InterfaceLocality> &local_interfaces)
: gko::experimental::EnableDistributedLinOp<RepartDistMatrix>(exec),
gko::experimental::distributed::DistributedBase(comm),
matrix_format_(matrix_format),
dist_mtx_(std::move(dist_mtx)),
local_sparsity_(local_sparsity),
non_local_sparsity_(non_local_sparsity),
Expand All @@ -140,7 +143,7 @@ public:

template <typename LocalMatrixType>
void write(const ExecutorHandler &exec_handler, const word field_name_,
const objectRegistry &db_) const;
const objectRegistry &db_, bool write_global) const;

// Needed for distributed/polymorphic_object.hpp
RepartDistMatrix(std::shared_ptr<const gko::Executor> exec,
Expand Down Expand Up @@ -193,6 +196,8 @@ protected:


private:
const word matrix_format_;

std::shared_ptr<dist_mtx> dist_mtx_;

std::shared_ptr<const SparsityPattern> local_sparsity_;
Expand All @@ -214,7 +219,7 @@ std::shared_ptr<const gko::LinOp> get_local(
void write_distributed(const ExecutorHandler &exec_handler, word field_name,
const objectRegistry &db,
std::shared_ptr<RepartDistMatrix> dist_A,
word matrix_format);
bool write_global);

void update_distributed(const ExecutorHandler &exec_handler,
std::shared_ptr<const Repartitioner> repartitioner,
Expand Down
4 changes: 3 additions & 1 deletion include/OGL/lduLduBase.H
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,12 @@ public:
bool export_system(
solver_controls_.lookupOrDefault<Switch>("export", false));
if (export_system && db_.time().writeTime()) {
bool write_global(
solver_controls_.lookupOrDefault<Switch>("writeGlobal", false));
LOG_0(verbose_, "Export system")
// dist_b.write();
write_distributed(exec_handler_, this->fieldName(), db_, dist_A_v,
"Coo");
write_global);
}

LOG_1(verbose_, "start create solver")
Expand Down
81 changes: 37 additions & 44 deletions src/MatrixWrapper/Distributed.C
Original file line number Diff line number Diff line change
Expand Up @@ -45,47 +45,38 @@ std::vector<std::shared_ptr<gko::LinOp>> generate_inner_linops(

template <typename LocalMatrixType>
void RepartDistMatrix::write(const ExecutorHandler &exec_handler,
const word field_name,
const objectRegistry &db) const
const word field_name, const objectRegistry &db,
bool write_global) const
{
auto local = gko::share(
gko::matrix::Coo<scalar, label>::create(exec_handler.get_ref_exec()));
auto non_local = gko::share(
gko::matrix::Coo<scalar, label>::create(exec_handler.get_ref_exec()));

if (repartitioner_->get_fused()) {
auto ret_local =
gko::as<LocalMatrixType>(dist_mtx_->get_local_matrix());
auto coo_local = gko::share(gko::matrix::Coo<scalar, label>::create(
exec_handler.get_ref_exec()));
ret_local->convert_to(coo_local.get());
export_mtx(field_name + "_local", coo_local, db);

auto ret_non_local =
gko::as<LocalMatrixType>(dist_mtx_->get_non_local_matrix());
auto coo_non_local = gko::share(gko::matrix::Coo<scalar, label>::create(
exec_handler.get_ref_exec()));
ret_non_local->convert_to(coo_non_local.get());
export_mtx(field_name + "_non_local", coo_non_local, db);
gko::as<LocalMatrixType>(dist_mtx_->get_local_matrix())
->convert_to(local.get());
gko::as<LocalMatrixType>(dist_mtx_->get_non_local_matrix())
->convert_to(non_local.get());
} else {
auto ret = gko::share(gko::matrix::Coo<scalar, label>::create(
exec_handler.get_ref_exec()));
gko::as<CombinationMatrix<LocalMatrixType>>(
dist_mtx_->get_local_matrix())
->convert_to(ret.get());
export_mtx(field_name + "_local", ret, db);

auto non_loc_ret = gko::share(gko::matrix::Coo<scalar, label>::create(
exec_handler.get_ref_exec()));
->convert_to(local.get());
gko::as<CombinationMatrix<LocalMatrixType>>(
dist_mtx_->get_non_local_matrix())
->convert_to(non_loc_ret.get());

// overwrite with global
bool write_global = false;
if (write_global) {
std::copy(non_local_sparsity_->col_idxs.get_const_data(),
non_local_sparsity_->col_idxs.get_const_data() +
non_local_sparsity_->num_nnz,
non_loc_ret->get_col_idxs());
}
export_mtx(field_name + "_non_local", non_loc_ret, db);
->convert_to(non_local.get());
}

// overwrite column indices with global indices
if (write_global) {
std::copy(non_local_sparsity_->col_idxs.get_const_data(),
non_local_sparsity_->col_idxs.get_const_data() +
non_local_sparsity_->num_nnz,
non_local->get_col_idxs());
}

export_mtx(field_name + "_local", local, db);
export_mtx(field_name + "_non_local", non_local, db);
}

template <typename LocalMatrixType>
Expand Down Expand Up @@ -592,7 +583,7 @@ template <typename LocalMatrixType>
std::shared_ptr<RepartDistMatrix> create_impl(
const ExecutorHandler &exec_handler,
std::shared_ptr<const Repartitioner> repartitioner,
std::shared_ptr<const HostMatrixWrapper> host_A)
std::shared_ptr<const HostMatrixWrapper> host_A, word matrix_format)
{
using dist_mtx =
gko::experimental::distributed::Matrix<scalar, label, label>;
Expand Down Expand Up @@ -672,22 +663,24 @@ std::shared_ptr<RepartDistMatrix> create_impl(
}

return std::make_shared<RepartDistMatrix>(
device_exec, comm, dist_A, repart_loc_sparsity, repart_non_loc_sparsity,
src_comm_pattern, repart_comm_pattern, repartitioner, local_interfaces);
device_exec, comm, matrix_format, dist_A, repart_loc_sparsity,
repart_non_loc_sparsity, src_comm_pattern, repart_comm_pattern,
repartitioner, local_interfaces);
}

void write_distributed(const ExecutorHandler &exec_handler, word field_name,
const objectRegistry &db,
std::shared_ptr<RepartDistMatrix> dist_A,
word matrix_format)
bool write_global)
{
auto matrix_format{dist_A->get_matrix_format()};
if (matrix_format == "Coo") {
return dist_A->write<gko::matrix::Coo<scalar, label>>(exec_handler,
field_name, db);
return dist_A->write<gko::matrix::Coo<scalar, label>>(
exec_handler, field_name, db, write_global);
}
if (matrix_format == "Csr") {
return dist_A->write<gko::matrix::Csr<scalar, label>>(exec_handler,
field_name, db);
return dist_A->write<gko::matrix::Csr<scalar, label>>(
exec_handler, field_name, db, write_global);
}
}

Expand Down Expand Up @@ -719,15 +712,15 @@ std::shared_ptr<RepartDistMatrix> create_distributed(
{
if (matrix_format == "Ell") {
return create_impl<gko::matrix::Ell<scalar, label>>(
exec_handler, repartitioner, hostMatrix);
exec_handler, repartitioner, hostMatrix, matrix_format);
}
if (matrix_format == "Coo") {
return create_impl<gko::matrix::Coo<scalar, label>>(
exec_handler, repartitioner, hostMatrix);
exec_handler, repartitioner, hostMatrix, matrix_format);
}
if (matrix_format == "Csr") {
return create_impl<gko::matrix::Csr<scalar, label>>(
exec_handler, repartitioner, hostMatrix);
exec_handler, repartitioner, hostMatrix, matrix_format);
}

FatalErrorInFunction
Expand Down

0 comments on commit 23fc1df

Please sign in to comment.