Skip to content

Commit

Permalink
Merge Removeing attributes from math functions
Browse files Browse the repository at this point in the history
This merge removes `GKO_ATTRIBUTES` from constexpr functions in math.hpp. Since we build cuda with `--expt-relaxed-constexpr` the `constexpr` is already enough to allow those functions on the device.

Related PR: #1695
  • Loading branch information
MarcelKoch authored Oct 18, 2024
2 parents 532566d + 98f29f6 commit db80c7d
Showing 1 changed file with 22 additions and 153 deletions.
175 changes: 22 additions & 153 deletions include/ginkgo/core/base/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ using is_complex_s = detail::is_complex_impl<T>;
* @return `true` if T is a complex type, `false` otherwise
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr bool is_complex()
GKO_INLINE constexpr bool is_complex()
{
return detail::is_complex_impl<T>::value;
}
Expand All @@ -307,7 +307,7 @@ using is_complex_or_scalar_s = detail::is_complex_or_scalar_impl<T>;
* @return `true` if T is a complex/scalar type, `false` otherwise
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr bool is_complex_or_scalar()
GKO_INLINE constexpr bool is_complex_or_scalar()
{
return detail::is_complex_or_scalar_impl<T>::value;
}
Expand Down Expand Up @@ -511,7 +511,7 @@ using highest_precision =
* @return the rounded down value
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr reduce_precision<T> round_down(T val)
GKO_INLINE constexpr reduce_precision<T> round_down(T val)
{
return static_cast<reduce_precision<T>>(val);
}
Expand All @@ -527,7 +527,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr reduce_precision<T> round_down(T val)
* @return the rounded up value
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr increase_precision<T> round_up(T val)
GKO_INLINE constexpr increase_precision<T> round_up(T val)
{
return static_cast<increase_precision<T>>(val);
}
Expand Down Expand Up @@ -609,141 +609,19 @@ struct default_converter {
*
* @return returns the ceiled quotient.
*/
GKO_INLINE GKO_ATTRIBUTES constexpr int64 ceildiv(int64 num, int64 den)
GKO_INLINE constexpr int64 ceildiv(int64 num, int64 den)
{
return (num + den - 1) / den;
}


#if defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC


/**
* Returns the additive identity for T.
*
* @return additive identity for T
*/
template <typename T>
GKO_INLINE __host__ constexpr T zero()
{
return T{};
}


/**
* Returns the additive identity for T.
*
* @return additive identity for T
*
* @note This version takes an unused reference argument to avoid
* complicated calls like `zero<decltype(x)>()`. Instead, it allows
* `zero(x)`.
*/
template <typename T>
GKO_INLINE __host__ constexpr T zero(const T&)
{
return zero<T>();
}


/**
* Returns the multiplicative identity for T.
*
* @return the multiplicative identity for T
*/
template <typename T>
GKO_INLINE __host__ constexpr T one()
{
return T(1);
}


/**
* Returns the multiplicative identity for T.
*
* @return the multiplicative identity for T
*
* @note This version takes an unused reference argument to avoid
* complicated calls like `one<decltype(x)>()`. Instead, it allows
* `one(x)`.
*/
template <typename T>
GKO_INLINE __host__ constexpr T one(const T&)
{
return one<T>();
}


/**
* Returns the additive identity for T.
*
* @return additive identity for T
*/
template <typename T>
GKO_INLINE __device__ constexpr std::enable_if_t<
!std::is_same<T, std::complex<remove_complex<T>>>::value, T>
zero()
{
return T{};
}


/**
* Returns the additive identity for T.
*
* @return additive identity for T
*
* @note This version takes an unused reference argument to avoid
* complicated calls like `zero<decltype(x)>()`. Instead, it allows
* `zero(x)`.
*/
template <typename T>
GKO_INLINE __device__ constexpr T zero(const T&)
{
return zero<T>();
}


/**
* Returns the multiplicative identity for T.
*
* @return the multiplicative identity for T
*/
template <typename T>
GKO_INLINE __device__ constexpr std::enable_if_t<
!std::is_same<T, std::complex<remove_complex<T>>>::value, T>
one()
{
return T(1);
}


/**
* Returns the multiplicative identity for T.
*
* @return the multiplicative identity for T
*
* @note This version takes an unused reference argument to avoid
* complicated calls like `one<decltype(x)>()`. Instead, it allows
* `one(x)`.
*/
template <typename T>
GKO_INLINE __device__ constexpr T one(const T&)
{
return one<T>();
}


#else


/**
* Returns the additive identity for T.
*
* @return additive identity for T
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T zero()
GKO_INLINE constexpr T zero()
{
return T{};
}
Expand All @@ -759,7 +637,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T zero()
* `zero(x)`.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T zero(const T&)
GKO_INLINE constexpr T zero(const T&)
{
return zero<T>();
}
Expand All @@ -771,7 +649,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T zero(const T&)
* @return the multiplicative identity for T
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T one()
GKO_INLINE constexpr T one()
{
return T(1);
}
Expand All @@ -787,18 +665,12 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T one()
* `one(x)`.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T one(const T&)
GKO_INLINE constexpr T one(const T&)
{
return one<T>();
}


#endif // defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC


#undef GKO_BIND_ZERO_ONE


/**
* Returns true if and only if the given value is zero.
*
Expand All @@ -808,7 +680,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T one(const T&)
* @return true iff the given value is zero, i.e. `value == zero<T>()`
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr bool is_zero(T value)
GKO_INLINE constexpr bool is_zero(T value)
{
return value == zero<T>();
}
Expand All @@ -823,7 +695,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr bool is_zero(T value)
* @return true iff the given value is not zero, i.e. `value != zero<T>()`
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr bool is_nonzero(T value)
GKO_INLINE constexpr bool is_nonzero(T value)
{
return value != zero<T>();
}
Expand All @@ -841,7 +713,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr bool is_nonzero(T value)
*
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T max(const T& x, const T& y)
GKO_INLINE constexpr T max(const T& x, const T& y)
{
return x >= y ? x : y;
}
Expand All @@ -859,7 +731,7 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T max(const T& x, const T& y)
*
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T min(const T& x, const T& y)
GKO_INLINE constexpr T min(const T& x, const T& y)
{
return x <= y ? x : y;
}
Expand Down Expand Up @@ -1053,7 +925,7 @@ GKO_ATTRIBUTES GKO_INLINE constexpr auto conj(const T& x)
* @return The squared norm of the object.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr auto squared_norm(const T& x)
GKO_INLINE constexpr auto squared_norm(const T& x)
-> decltype(real(conj(x) * x))
{
return real(conj(x) * x);
Expand All @@ -1070,16 +942,15 @@ GKO_INLINE GKO_ATTRIBUTES constexpr auto squared_norm(const T& x)
* @return x >= zero<T>() ? x : -x;
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<!is_complex_s<T>::value, T>
abs(const T& x)
GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T> abs(
const T& x)
{
return x >= zero<T>() ? x : -x;
}


template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<is_complex_s<T>::value,
remove_complex<T>>
GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, remove_complex<T>>
abs(const T& x)
{
return sqrt(squared_norm(x));
Expand All @@ -1092,7 +963,7 @@ abs(const T& x)
* @tparam T the value type to return
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr T pi()
GKO_INLINE constexpr T pi()
{
return static_cast<T>(3.1415926535897932384626433);
}
Expand All @@ -1107,8 +978,8 @@ GKO_INLINE GKO_ATTRIBUTES constexpr T pi()
* @tparam T the corresponding real value type.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr std::complex<remove_complex<T>> unit_root(
int64 n, int64 k = 1)
GKO_INLINE constexpr std::complex<remove_complex<T>> unit_root(int64 n,
int64 k = 1)
{
return std::polar(one<remove_complex<T>>(),
remove_complex<T>{2} * pi<remove_complex<T>>() * k / n);
Expand Down Expand Up @@ -1259,8 +1130,7 @@ GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value, bool> is_nan(
* @return NaN.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<!is_complex_s<T>::value, T>
nan()
GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T> nan()
{
return std::numeric_limits<T>::quiet_NaN();
}
Expand All @@ -1274,8 +1144,7 @@ nan()
* @return complex{NaN, NaN}.
*/
template <typename T>
GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<is_complex_s<T>::value, T>
nan()
GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T> nan()
{
return T{nan<remove_complex<T>>(), nan<remove_complex<T>>()};
}
Expand Down

0 comments on commit db80c7d

Please sign in to comment.