Skip to content

Commit

Permalink
Change implementation of map_impl to support policies
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Jul 24, 2024
1 parent 986ca55 commit 3c73971
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 150 deletions.
1 change: 1 addition & 0 deletions docs/guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Guides

guides/introduction.rst
guides/promotion.rst
guides/accuracy.rst
guides/constant.rst
49 changes: 49 additions & 0 deletions docs/guides/accuracy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
Accuracy level
===

Many of the functions in Kernel Float take an additional `Accuracy` option as a template parameter.
This option can be used to increase the performance of certain operations, at the cost of lower accuracy.

There are four possible values for this parameter:

* `accurate_policy`: Use the most accurate version of the function available.
* `fast_policy`: Use the "fast math" version (for example, `__sinf` for sin on CUDA devices). Falls back to `accurate_policy` if such a version is not available.
* `approx_policy<N>`: Rough approximation using a polynomial of degree `N`. Falls back to `fast_policy` if no such polynomial exists.
* `default_policy`: Use a global default policy (see the next section).


For example, consider this code:

```C++

#include "kernel_float.h"
namespace kf = kernel_float;


int main() {
kf::vec<float, 2> input = {1.0f, 2.0f};

// Use the default policy
kf::vec<float, 2> A = kf::cos(input);

// Use the most accuracy policy
kf::vec<float, 2> B = kf::cos<kf::accurate_policy>(input);

// Use the fastest policy
kf::vec<float, 2> C = kf::cos<kf::fast_policy>(input);

printf("A = %f, %f", A[0], A[1]);
printf("B = %f, %f", B[0], B[1]);
printf("C = %f, %f", C[0], C[1]);

return EXIT_SUCCESS;
}

```
Setting `default_policy`
---
By default, the value for `default_policy` is `accurate_policy`.
Set the preprocessor option `KERNEL_FLOAT_FAST_MATH=1` to change the default policy to `fast_policy`.
81 changes: 26 additions & 55 deletions include/kernel_float/apply.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,51 +130,50 @@ struct apply_impl {

template<typename F, size_t N, typename Output, typename... Args>
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
} // namespace detail

template<typename F, size_t N, typename Output, typename... Args>
struct map_impl {
static constexpr size_t packet_size = preferred_vector_size<Output>::value;

KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
if constexpr (N / packet_size > 0) {
#pragma unroll
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
apply_impl<F, packet_size, Output, Args...>::call(fun, output + i, (args + i)...);
}
}
struct accurate_policy {
template<typename F, size_t N, typename Output, typename... Args>
using type = detail::apply_impl<F, N, Output, Args...>;
};

if constexpr (N % packet_size > 0) {
#pragma unroll
for (size_t i = N - N % packet_size; i < N; i++) {
apply_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
}
}
}
struct fast_policy {
template<typename F, size_t N, typename Output, typename... Args>
using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
};

template<typename F, size_t N, typename Output, typename... Args>
struct fast_map_impl {
#ifdef KERNEL_FLOAT_POLICY
using default_policy = KERNEL_FLOAT_POLICY;
#else
using default_policy = accurate_policy;
#endif

namespace detail {

template<typename Policy, typename F, size_t N, typename Output, typename... Args>
struct map_policy_impl {
static constexpr size_t packet_size = preferred_vector_size<Output>::value;

KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
if constexpr (N / packet_size > 0) {
#pragma unroll
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
apply_fastmath_impl<F, packet_size, Output, Args...>::call(
fun,
output + i,
(args + i)...);
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
}
}

if constexpr (N % packet_size > 0) {
#pragma unroll
for (size_t i = N - N % packet_size; i < N; i++) {
apply_fastmath_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
}
}
}
};

template<typename F, size_t N, typename Output, typename... Args>
using map_impl = map_policy_impl<default_policy, F, N, Output, Args...>;

} // namespace detail

template<typename F, typename... Args>
Expand All @@ -191,41 +190,13 @@ using map_type =
* vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
* ```
*/
template<typename F, typename... Args>
template<typename Accuracy = default_policy, typename F, typename... Args>
KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
using Output = result_t<F, vector_value_type<Args>...>;
using E = broadcast_vector_extent_type<Args...>;
vector_storage<Output, extent_size<E>> result;

// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
#if KERNEL_FLOAT_FAST_MATH
using apply_impl =
detail::fast_math_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
#else
using map_impl = detail::map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
#endif

map_impl::call(
fun,
result.data(),
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
into_vector_storage(args))
.data())...);

return result;
}

/**
* Apply the function `F` to each element from the vector `input` and return the results as a new vector. This
* uses fast-math if available for the given function `F`, otherwise this function behaves like `map`.
*/
template<typename F, typename... Args>
KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
using Output = result_t<F, vector_value_type<Args>...>;
using E = broadcast_vector_extent_type<Args...>;
vector_storage<Output, extent_size<E>> result;

detail::fast_map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>::call(
detail::map_policy_impl<Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call(
fun,
result.data(),
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
Expand Down
13 changes: 3 additions & 10 deletions include/kernel_float/binops.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co

vector_storage<O, extent_size<E>> result;

// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
#if KERNEL_FLOAT_FAST_MATH
using map_impl = detail::fast_map_impl<F, extent_size<E>, O, T, T>;
#else
using map_impl = detail::map_impl<F, extent_size<E>, O, T, T>;
#endif

map_impl::call(
detail::map_impl<F, extent_size<E>, O, T, T>::call(
fun,
result.data(),
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
Expand Down Expand Up @@ -304,7 +297,7 @@ struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
T rhs_rcp[N];

// Fast way to perform division is to multiply by the reciprocal
apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call({}, rhs_rcp, rhs);
apply_fastmath_impl<ops::rcp<T>, N, T, T>::call({}, rhs_rcp, rhs);
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
}
};
Expand All @@ -326,7 +319,7 @@ fast_divide(const L& left, const R& right) {
using E = broadcast_vector_extent_type<L, R>;
vector_storage<T, extent_size<E>> result;

detail::fast_map_impl<ops::divide<T>, extent_size<E>, T, T, T>::call(
detail::map_policy_impl<fast_policy, ops::divide<T>, extent_size<E>, T, T, T>::call(
ops::divide<T> {},
result.data(),
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
Expand Down
2 changes: 1 addition & 1 deletion include/kernel_float/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#define KERNEL_FLOAT_MAX_ALIGNMENT (32)

#ifndef KERNEL_FLOAT_FAST_MATH
#define KERNEL_FLOAT_FAST_MATH (0)
#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy;
#endif

#endif //KERNEL_FLOAT_MACROS_H
15 changes: 7 additions & 8 deletions include/kernel_float/unops.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ KERNEL_FLOAT_INLINE vector<R, vector_extent_type<V>> cast(const V& input) {
}

#define KERNEL_FLOAT_DEFINE_UNARY_FUN(NAME) \
template<typename V> \
template<typename Accuracy = default_policy, typename V> \
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> NAME(const V& input) { \
using F = ops::NAME<vector_value_type<V>>; \
return map(F {}, input); \
return ::kernel_float::map<Accuracy>(F {}, input); \
}

#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \
Expand Down Expand Up @@ -193,12 +193,11 @@ KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, 1.0 / input, 1.0f / input)

KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp)

#define KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(NAME) \
template<typename V> \
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> fast_##NAME( \
const V& input) { \
using F = ops::NAME<vector_value_type<V>>; \
return fast_map(F {}, input); \
#define KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(NAME) \
template<typename V> \
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> fast_##NAME( \
const V& input) { \
return ::kernel_float::map<fast_policy>(ops::NAME<vector_value_type<V>> {}, input); \
}

KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp)
Expand Down
Loading

0 comments on commit 3c73971

Please sign in to comment.