Skip to content

Commit

Permalink
Add support for flag predicates
Browse files: browse the repository at this point in the history
Signed-off-by: Matthew Michel <matthew.michel@intel.com>
  • Loading branch information
mmichel11 committed Oct 15, 2024
1 parent 6f80542 commit 0e0d50e
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -1117,6 +1117,7 @@ __parallel_unique_copy(oneapi::dpl::__internal::__device_backend_tag __backend_t
_CopyOp{_ReduceOp{}, _Assign{}});
}
}

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
typename _Range4, typename _BinaryPredicate, typename _BinaryOperator>
auto
Expand All @@ -1128,19 +1129,20 @@ __parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag __bac
auto __gen_reduce_input = [=](const auto& __in_rng, std::size_t __idx) {
auto&& __in_keys = std::get<0>(__in_rng.tuple());
auto&& __in_vals = std::get<1>(__in_rng.tuple());
using _KeyType = oneapi::dpl::__internal::__value_t<decltype(__in_keys)>;
using _ValueType = oneapi::dpl::__internal::__value_t<decltype(__in_vals)>;
if (__idx == 0 || __binary_pred(__in_keys[__idx], __in_keys[__idx - 1]))
return oneapi::dpl::__internal::make_tuple(size_t{0}, _ValueType{__in_vals[__idx]});
return oneapi::dpl::__internal::make_tuple(size_t{1}, _ValueType{__in_vals[__idx]});
if (__idx == 0 || __binary_pred(__in_keys[__idx - 1], __in_keys[__idx]))
return oneapi::dpl::__internal::make_tuple(size_t{0}, _ValueType{__in_vals[__idx]}, _KeyType{__in_keys[__idx]});
return oneapi::dpl::__internal::make_tuple(size_t{1}, _ValueType{__in_vals[__idx]}, _KeyType{__in_keys[__idx]});
};
auto __reduce_op = [=](const auto& __lhs_tup, const auto& __rhs_tup) {
if (std::get<0>(__rhs_tup) == 0)
{
return oneapi::dpl::__internal::make_tuple(std::get<0>(__lhs_tup),
__binary_op(std::get<1>(__lhs_tup), std::get<1>(__rhs_tup)));
__binary_op(std::get<1>(__lhs_tup), std::get<1>(__rhs_tup)), std::get<2>(__lhs_tup));
}
return oneapi::dpl::__internal::make_tuple(std::get<0>(__lhs_tup) + std::get<0>(__rhs_tup),
std::get<1>(__rhs_tup));
std::get<1>(__rhs_tup), std::get<2>(__rhs_tup));
};
auto __gen_scan_input = __gen_reduce_input;
auto __scan_input_transform = oneapi::dpl::__internal::__no_op{};
Expand All @@ -1151,17 +1153,18 @@ __parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag __bac
// Assuming this will be present in L1 cache
if (__idx == __n - 1 || !__binary_pred(__in_keys[__idx], __in_keys[__idx + 1]))
{
__out_keys[std::get<0>(__tup)] = __in_keys[__idx];
__out_keys[std::get<0>(__tup)] = std::get<2>(__tup);
__out_values[std::get<0>(__tup)] = std::get<1>(__tup);
}
};
using _KeyType = oneapi::dpl::__internal::__value_t<_Range1>;
using _ValueType = oneapi::dpl::__internal::__value_t<_Range2>;
return __parallel_transform_reduce_then_scan(
__backend_tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::__ranges::make_zip_view(std::forward<_Range1>(__keys), std::forward<_Range2>(__values)),
oneapi::dpl::__ranges::make_zip_view(std::forward<_Range3>(__out_keys), std::forward<_Range4>(__out_values)),
__gen_reduce_input, __reduce_op, __gen_scan_input, __scan_input_transform,
__write_out, oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<std::size_t, _ValueType>>{},
__write_out, oneapi::dpl::unseq_backend::__no_init_value<oneapi::dpl::__internal::tuple<std::size_t, _ValueType, _KeyType>>{},
/*Inclusive*/std::true_type{}, /*_IsUniquePattern=*/std::false_type{});
}

Expand Down

0 comments on commit 0e0d50e

Please sign in to comment.