Skip to content

Commit

Permalink
Implement correct return for reduce_by_segment
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Michel <matthew.michel@intel.com>
  • Loading branch information
mmichel11 committed Oct 15, 2024
1 parent cfcf9db commit 6f80542
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
13 changes: 8 additions & 5 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -907,11 +907,14 @@ __pattern_reduce_by_segment(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&&
_Range2&& __values, _Range3&& __out_keys, _Range4&& __out_values,
_BinaryPredicate __binary_pred, _BinaryOperator __binary_op)
{
oneapi::dpl::__par_backend_hetero::__parallel_reduce_by_segment(_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__keys),
std::forward<_Range2>(__values), std::forward<_Range3>(__out_keys),
std::forward<_Range4>(__out_values), __binary_pred, __binary_op)
.wait();
return 1;
auto __res = oneapi::dpl::__par_backend_hetero::__parallel_reduce_by_segment(_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__keys),
std::forward<_Range2>(__values), std::forward<_Range3>(__out_keys),
std::forward<_Range4>(__out_values), __binary_pred, __binary_op);
__res.wait();
// The init type is tuple<std::size_t, ValType>; its first component is the zero-based write index of the last output segment.
// Add 1 to convert that index into the number of segments written, i.e. the past-the-end offset for both output ranges.
return std::get<0>(__res.get()) + 1;
// TODO: this needs to be enabled if reduce then scan cannot be satisfied.
#if 0
// The algorithm reduces values in __values where the
// associated keys for the values are equal to the adjacent key.
Expand Down
10 changes: 4 additions & 6 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1129,11 +1129,9 @@ __parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag __bac
auto&& __in_keys = std::get<0>(__in_rng.tuple());
auto&& __in_vals = std::get<1>(__in_rng.tuple());
using _ValueType = oneapi::dpl::__internal::__value_t<decltype(__in_vals)>;
if (__idx == 0)
if (__idx == 0 || __binary_pred(__in_keys[__idx], __in_keys[__idx - 1]))
return oneapi::dpl::__internal::make_tuple(size_t{0}, _ValueType{__in_vals[__idx]});
if (!__binary_pred(__in_keys[__idx], __in_keys[__idx - 1]))
return oneapi::dpl::__internal::make_tuple(size_t{1}, _ValueType{__in_vals[__idx]});
return oneapi::dpl::__internal::make_tuple(size_t{0}, _ValueType{__in_vals[__idx]});
return oneapi::dpl::__internal::make_tuple(size_t{1}, _ValueType{__in_vals[__idx]});
};
auto __reduce_op = [=](const auto& __lhs_tup, const auto& __rhs_tup) {
if (std::get<0>(__rhs_tup) == 0)
Expand Down Expand Up @@ -1163,8 +1161,8 @@ __parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag __bac
oneapi::dpl::__ranges::make_zip_view(std::forward<_Range1>(__keys), std::forward<_Range2>(__values)),
oneapi::dpl::__ranges::make_zip_view(std::forward<_Range3>(__out_keys), std::forward<_Range4>(__out_values)),
__gen_reduce_input, __reduce_op, __gen_scan_input, __scan_input_transform,
__write_out, oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<std::size_t, _ValueType>>{}, /*Inclusive*/std::true_type{}, /*_IsUniquePattern=*/std::false_type{}
);
__write_out, oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<std::size_t, _ValueType>>{},
/*Inclusive*/std::true_type{}, /*_IsUniquePattern=*/std::false_type{});
}

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryPredicate>
Expand Down

0 comments on commit 6f80542

Please sign in to comment.