Skip to content

Commit

Permalink
Merge pull request #1974 from cwpearson/enhancement/issue-1671
Browse files Browse the repository at this point in the history
Parallel prefix sum can infer view type
  • Loading branch information
lucbv authored Sep 27, 2023
2 parents cad61f4 + 70e391f commit e8469fd
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 55 deletions.
12 changes: 6 additions & 6 deletions common/src/KokkosKernels_SimpleUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct InclusiveParallelPrefixSum {
* \param num_elements: size of the array
* \param arr: the array for which the prefix sum will be performed.
*/
template <typename view_t, typename MyExecSpace>
template <typename MyExecSpace, typename view_t>
inline void kk_exclusive_parallel_prefix_sum(
const MyExecSpace &exec, typename view_t::value_type num_elements,
view_t arr) {
Expand All @@ -100,7 +100,7 @@ inline void kk_exclusive_parallel_prefix_sum(
* \param num_elements: size of the array
* \param arr: the array for which the prefix sum will be performed.
*/
template <typename view_t, typename MyExecSpace>
template <typename MyExecSpace, typename view_t>
inline void kk_exclusive_parallel_prefix_sum(
typename view_t::value_type num_elements, view_t arr) {
kk_exclusive_parallel_prefix_sum(MyExecSpace(), num_elements, arr);
Expand All @@ -116,7 +116,7 @@ inline void kk_exclusive_parallel_prefix_sum(
* \param finalSum: will be set to arr[num_elements - 1] after computing the
* prefix sum.
*/
template <typename view_t, typename MyExecSpace>
template <typename MyExecSpace, typename view_t>
inline void kk_exclusive_parallel_prefix_sum(
const MyExecSpace &exec, typename view_t::value_type num_elements,
view_t arr, typename view_t::non_const_value_type &finalSum) {
Expand All @@ -135,7 +135,7 @@ inline void kk_exclusive_parallel_prefix_sum(
* \param finalSum: will be set to arr[num_elements - 1] after computing the
* prefix sum.
*/
template <typename view_t, typename MyExecSpace>
template <typename MyExecSpace, typename view_t>
inline void kk_exclusive_parallel_prefix_sum(
typename view_t::value_type num_elements, view_t arr,
typename view_t::non_const_value_type &finalSum) {
Expand All @@ -149,7 +149,7 @@ inline void kk_exclusive_parallel_prefix_sum(
/// \param num_elements: size of the array
/// \param arr: the array for which the prefix sum will be performed.
///
template <typename forward_array_type, typename MyExecSpace>
template <typename MyExecSpace, typename forward_array_type>
void kk_inclusive_parallel_prefix_sum(
MyExecSpace my_exec_space,
typename forward_array_type::value_type num_elements,
Expand All @@ -166,7 +166,7 @@ void kk_inclusive_parallel_prefix_sum(
/// \param num_elements: size of the array
/// \param arr: the array for which the prefix sum will be performed.
///
template <typename forward_array_type, typename MyExecSpace>
template <typename MyExecSpace, typename forward_array_type>
void kk_inclusive_parallel_prefix_sum(
typename forward_array_type::value_type num_elements,
forward_array_type arr) {
Expand Down
6 changes: 2 additions & 4 deletions common/src/KokkosKernels_Utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,7 @@ void inclusive_parallel_prefix_sum(
MyExecSpace my_exec_space,
typename forward_array_type::value_type num_elements,
forward_array_type arr) {
return kk_inclusive_parallel_prefix_sum<forward_array_type, MyExecSpace>(
my_exec_space, num_elements, arr);
return kk_inclusive_parallel_prefix_sum(my_exec_space, num_elements, arr);
}

template <typename forward_array_type, typename MyExecSpace>
Expand All @@ -475,8 +474,7 @@ template <typename forward_array_type, typename MyExecSpace>
void exclusive_parallel_prefix_sum(
typename forward_array_type::value_type num_elements,
forward_array_type arr) {
kk_exclusive_parallel_prefix_sum<forward_array_type, MyExecSpace>(
num_elements, arr);
kk_exclusive_parallel_prefix_sum<MyExecSpace>(num_elements, arr);
}

template <typename array_type>
Expand Down
2 changes: 1 addition & 1 deletion common/unit_test/Test_Common_Sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ size_t generateRandomOffsets(OrdView randomCounts, OrdView randomOffsets,
}
Kokkos::deep_copy(randomCounts, countsHost);
Kokkos::deep_copy(randomOffsets, randomCounts);
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<OrdView, ExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<ExecSpace>(
n, randomOffsets);
return total;
}
Expand Down
4 changes: 2 additions & 2 deletions sparse/impl/KokkosSparse_par_ilut_numeric_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ struct IlutWrap {
static size_type prefix_sum(RowMapType& row_map) {
size_type result = 0;
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<
RowMapType, typename IlutHandle::HandleExecSpace>(row_map.extent(0),
row_map, result);
typename IlutHandle::HandleExecSpace>(row_map.extent(0), row_map,
result);
return result;
}

Expand Down
9 changes: 3 additions & 6 deletions sparse/impl/KokkosSparse_spadd_symbolic_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,7 @@ void spadd_symbolic_impl(
runSortedCountEntries<KernelHandle, alno_row_view_t_, alno_nnz_view_t_,
blno_row_view_t_, blno_nnz_view_t_, clno_row_view_t_>(
a_rowmap, a_entries, b_rowmap, b_entries, c_rowmap);
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<clno_row_view_t_,
execution_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<execution_space>(
nrows + 1, c_rowmap);
} else {
// note: scoping individual parts of the process to free views sooner,
Expand All @@ -542,8 +541,7 @@ void spadd_symbolic_impl(
Kokkos::parallel_for(
"KokkosSparse::SpAdd:Symbolic::InputNotSorted::CountEntries",
range_type(0, nrows), countEntries);
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<offset_view_t,
execution_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<execution_space>(
nrows + 1, c_rowmap_upperbound);
Kokkos::deep_copy(c_nnz_upperbound,
Kokkos::subview(c_rowmap_upperbound, nrows));
Expand Down Expand Up @@ -585,8 +583,7 @@ void spadd_symbolic_impl(
"KokkosSparse::SpAdd:Symbolic::InputNotSorted::MergeEntries",
range_type(0, nrows), mergeEntries);
// compute actual c_rowmap
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<clno_row_view_t_,
execution_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<execution_space>(
nrows + 1, c_rowmap);
}
addHandle->set_a_b_pos(a_pos, b_pos);
Expand Down
6 changes: 2 additions & 4 deletions sparse/impl/KokkosSparse_spgemm_impl_symbolic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1672,8 +1672,7 @@ void KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
<< std::endl;
}
typename c_row_view_t::non_const_value_type c_nnz_size = 0;
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<c_row_view_t,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
m + 1, rowmapC, c_nnz_size);
this->handle->get_spgemm_handle()->set_c_nnz(c_nnz_size);
nnz_lno_t c_max_nnz =
Expand Down Expand Up @@ -2188,8 +2187,7 @@ void KokkosSPGEMM<
}
#endif
typename c_row_view_t::non_const_value_type c_nnz_size = 0;
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<c_row_view_t,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
m + 1, rowmapC, c_nnz_size);
this->handle->get_spgemm_handle()->set_c_nnz(c_nnz_size);
nnz_lno_t c_max_nnz =
Expand Down
3 changes: 1 addition & 2 deletions sparse/impl/KokkosSparse_spgemm_impl_triangle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,8 +1818,7 @@ void KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
p_entriesA, bnnz, p_rowmapB_begins, p_rowmapB_ends,
p_set_index_b, p_set_b, p_rowmapC, NULL, dummy);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<c_row_view_t,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
this->a_row_cnt + 1, rowmapC_);
MyExecSpace().fence();

Expand Down
12 changes: 4 additions & 8 deletions sparse/impl/KokkosSparse_twostage_gauss_seidel_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,22 +633,18 @@ class TwostageGaussSeidel {
// shift ptr so that it now contains offsets (combine it with the previous
// functor calls?)
if (direction == GS_FORWARD || direction == GS_SYMMETRIC) {
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<row_map_view_t,
execution_space>(
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<execution_space>(
1 + num_rows, rowmap_viewL);
if (compact_form) {
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<row_map_view_t,
execution_space>(
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<execution_space>(
1 + num_rows, rowmap_viewLa);
}
}
if (direction == GS_BACKWARD || direction == GS_SYMMETRIC) {
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<row_map_view_t,
execution_space>(
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<execution_space>(
1 + num_rows, rowmap_viewU);
if (compact_form) {
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<row_map_view_t,
execution_space>(
KokkosKernels::Impl::kk_inclusive_parallel_prefix_sum<execution_space>(
1 + num_rows, rowmap_viewUa);
}
}
Expand Down
6 changes: 2 additions & 4 deletions sparse/src/KokkosSparse_SortCrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,7 @@ void sort_and_merge_matrix(const exec_space& exec,
auto entries_orig = entries_in;
auto values_orig = values_in;
// Prefix sum to get rowmap
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<nc_rowmap_t,
exec_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(
exec, numRows + 1, nc_rowmap_out);
rowmap_out = nc_rowmap_out;
entries_out = entries_t(Kokkos::view_alloc(exec, Kokkos::WithoutInitializing,
Expand Down Expand Up @@ -761,8 +760,7 @@ void sort_and_merge_graph(const exec_space& exec,
// In the case where the output rowmap is the same as the input, we could just
// assign "rowmap_out = rowmap_in" except that would break const-correctness.
// Can skip filling the entries, however.
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<nc_rowmap_t,
exec_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(
exec, numRows + 1, nc_rowmap_out);
rowmap_out = nc_rowmap_out;
entries_out = entries_t(Kokkos::view_alloc(exec, Kokkos::WithoutInitializing,
Expand Down
27 changes: 9 additions & 18 deletions sparse/src/KokkosSparse_Utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,7 @@ void transpose_matrix(
team_size, thread_size),
tm);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<out_row_view_t,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
num_cols + 1, t_xadj);

Kokkos::deep_copy(tmp_row_view, t_xadj);
Expand Down Expand Up @@ -497,8 +496,7 @@ void transpose_graph(
team_size, thread_size),
tm);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<out_row_view_t,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
num_cols + 1, t_xadj);

Kokkos::deep_copy(tmp_row_view, t_xadj);
Expand Down Expand Up @@ -802,8 +800,7 @@ void kk_create_reverse_map(

// kk_inclusive_parallel_prefix_sum<reverse_array_type,
// MyExecSpace>(tmp_reverse_size + 1, tmp_color_xadj);
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<reverse_array_type,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
tmp_reverse_size + 1, tmp_color_xadj);
MyExecSpace().fence();

Expand Down Expand Up @@ -838,8 +835,7 @@ void kk_create_reverse_map(

// kk_inclusive_parallel_prefix_sum<reverse_array_type,
// MyExecSpace>(num_reverse_elements + 1, reverse_map_xadj);
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<reverse_array_type,
MyExecSpace>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<MyExecSpace>(
num_reverse_elements + 1, tmp_color_xadj);
MyExecSpace().fence();

Expand Down Expand Up @@ -1500,8 +1496,7 @@ crstmat_t kk_get_lower_triangle(
nr, ne, rowmap, entries, new_row_map.data(), new_indices,
use_dynamic_scheduling, chunksize);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<row_map_view_t,
exec_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(
nr + 1, new_row_map);
exec_space().fence();

Expand Down Expand Up @@ -1558,8 +1553,7 @@ crstmat_t kk_get_lower_crs_matrix(
nr, ne, rowmap, entries, new_row_map.data(), new_indices,
use_dynamic_scheduling, chunksize);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<row_map_view_t,
exec_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(
nr + 1, new_row_map);
exec_space().fence();

Expand Down Expand Up @@ -1612,8 +1606,7 @@ graph_t kk_get_lower_crs_graph(graph_t in_crs_matrix,
kk_get_lower_triangle_count<size_type, lno_t, exec_space>(
nr, ne, rowmap, entries, new_row_map.data(), new_indices);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<row_map_view_t,
exec_space>(
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(
nr + 1, new_row_map);
exec_space().fence();

Expand Down Expand Up @@ -1666,8 +1659,7 @@ void kk_get_lower_triangle(typename cols_view_t::non_const_value_type nr,
nr, ne, rowmap, entries, out_rowmap.data(), new_indices.data(),
use_dynamic_scheduling, chunksize, is_lower);

KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<out_row_map_view_t,
exec_space>(nr + 1,
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(nr + 1,
out_rowmap);
exec_space().fence();

Expand Down Expand Up @@ -1775,8 +1767,7 @@ void kk_create_incidence_matrix_from_original_matrix(
permutation.data(), use_dynamic_scheduling, chunksize,
sort_decreasing_order);
exec_space().fence();
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<out_row_map_view_t,
exec_space>(nr + 1,
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<exec_space>(nr + 1,
out_rowmap);

// kk_print_1Dview(out_rowmap, false, 20);
Expand Down

0 comments on commit e8469fd

Please sign in to comment.