Skip to content

Commit

Permalink
wip convert_to_global
Browse files Browse the repository at this point in the history
  • Loading branch information
greole committed Oct 15, 2024
1 parent be0e907 commit 07dbaf5
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/MatrixWrapper/Distributed.C
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,28 @@ void RepartDistMatrix::write(const ExecutorHandler &exec_handler,

if (write_global) {
// overwrite non_local column indices with global indices
label rank{exec_handler.get_rank()};
auto partition = get_repartitioner()->get_orig_partition();
std::copy(non_local_sparsity_->col_idxs.get_const_data(),
non_local_sparsity_->col_idxs.get_const_data() +
non_local_sparsity_->num_nnz,
non_local->get_col_idxs());

//
std::vector<gko::span> local_spans {gko::span {0, static_cast<gko::size_type>(local_sparsity_->num_nnz)}};
std::vector<label> local_ranks {rank};
auto ref_exec = exec_handler.get_ref_exec();
auto comm = exec_handler.get_gko_mpi_host_comm();
label rank{exec_handler.get_rank()};
auto partition = gko::share(
gko::experimental::distributed::build_partition_from_local_size<label, label>(ref_exec, *comm.get(), local_sparsity_->dim[0]));

auto global_local_rows =detail::convert_to_global(partition, local->get_row_idxs(), local_spans, local_ranks);
auto global_local_cols =detail::convert_to_global(partition, local->get_col_idxs(), local_spans, local_ranks);
label offset = partition->get_range_bounds()[rank];
label local_nnz = local_sparsity_->num_nnz;

std::copy(global_local_rows.begin(),
global_local_rows.end(),
local->get_row_idxs());
std::copy(global_local_cols.begin(),
global_local_cols.end(),
local->get_col_idxs());
std::transform(local->get_row_idxs(), local->get_row_idxs() + local_nnz, local->get_row_idxs(),
[&](label idx) { return idx + offset; });
std::transform(local->get_col_idxs(), local->get_col_idxs() + local_nnz, local->get_col_idxs(),
[&](label idx) { return idx + offset; });

label non_local_nnz = non_local_sparsity_->num_nnz;
std::transform(non_local->get_row_idxs(), non_local->get_row_idxs() + non_local_nnz, non_local->get_row_idxs(),
[&](label idx) { return idx + offset; });
}

export_mtx(field_name + "_local", local, db);
Expand Down

0 comments on commit 07dbaf5

Please sign in to comment.