From 5f4115caadba3dae9c82f407814d0f52917ec5c4 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 3 Nov 2023 15:20:25 +0000 Subject: [PATCH] Changed the implementation of type_info_of --- device/include/roughpy/device/core.h | 66 +++++++++++++++------- device/include/roughpy/device/kernel_arg.h | 4 +- device/include/roughpy/device/types.h | 27 +++++---- 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/device/include/roughpy/device/core.h b/device/include/roughpy/device/core.h index 95ade2280..2aa2f59fc 100644 --- a/device/include/roughpy/device/core.h +++ b/device/include/roughpy/device/core.h @@ -197,12 +197,12 @@ enum class TypeCode : uint8_t ArbitraryPrecisionRational = ArbitraryPrecision | Rational, Polynomial = 16U, - APRationalPolynomial = Polynomial // | ArbitraryPrecisionRational + APRationalPolynomial = Polynomial// | ArbitraryPrecisionRational }; /** - * @brief Basic information for identifying the type, size, and + * @brief Basic information for identifying the type, size, alignment, and * configuration of a type. * * This was originally based on the DLPack protocol, but actually that proved @@ -216,7 +216,8 @@ enum class TypeCode : uint8_t struct TypeInfo { TypeCode code; uint8_t bytes; - uint16_t lanes = 1; + uint8_t alignment; + uint8_t lanes = 1; }; template @@ -279,35 +280,60 @@ constexpr bool operator==(const DeviceInfo& lhs, const DeviceInfo& rhs) noexcept return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id; } +constexpr bool operator==(const TypeInfo& lhs, const TypeInfo& rhs) noexcept +{ + return lhs.code == rhs.code && lhs.bytes == rhs.bytes + && lhs.lanes == rhs.lanes; +} +namespace dtl { +template +struct type_code_of_impl; -namespace dtl { +#define RPY_GENERIC_TYPE_CODE_FUNCTION(cond, TC) \ + template \ + struct type_code_of_impl> { \ + static constexpr TypeCode value = (TC); \ + } -template -constexpr TypeInfo type_info(); +RPY_GENERIC_TYPE_CODE_FUNCTION( + is_integral::value&& is_signed::value, + TypeCode::Int +); -template -constexpr enable_if_t::value && is_signed::value, TypeInfo> -type_info() { - return { TypeCode::Int, sizeof(T), 1}; -} +RPY_GENERIC_TYPE_CODE_FUNCTION( + is_integral::value&& is_unsigned::value, + TypeCode::UInt +); -template -constexpr enable_if_t::value && !is_signed::value, TypeInfo> -type_info() { - return { TypeCode::UInt, sizeof(T), 1}; -} +RPY_GENERIC_TYPE_CODE_FUNCTION(is_floating_point::value, TypeCode::Float); + +#undef RPY_GENERIC_TYPE_CODE_FUNCTION + +}// namespace dtl template -constexpr enable_if_t::value, TypeInfo> -type_info() { - return { TypeCode::Float, sizeof(T), 1}; +constexpr TypeCode type_code_of() noexcept +{ + return dtl::type_code_of_impl::value; } +/** + * @brief Get the type info struct relating to the given type + * @tparam T Type to query + * @return TypeInfo struct containing information about the type. + */ +template +constexpr TypeInfo type_info() noexcept +{ + return {type_code_of(), + static_cast(sizeof(T)), + static_cast(alignof(T)), + 1U}; } -}// namespace device +}// namespace devices }// namespace rpy #endif// ROUGHPY_DEVICE_CORE_H_ diff --git a/device/include/roughpy/device/kernel_arg.h b/device/include/roughpy/device/kernel_arg.h index 3114d3e75..4c54cb952 100644 --- a/device/include/roughpy/device/kernel_arg.h +++ b/device/include/roughpy/device/kernel_arg.h @@ -73,14 +73,14 @@ class RPY_EXPORT KernelArgument explicit KernelArgument(T& data) : p_data(&data), m_mode(Pointer), - m_info(dtl::type_info()) + m_info(type_info()) {} template explicit KernelArgument(const T& data) : p_const_data(&data), m_mode(ConstPointer), - m_info(dtl::type_info()) + m_info(type_info()) {} RPY_NO_DISCARD constexpr bool is_buffer() const noexcept diff --git a/device/include/roughpy/device/types.h b/device/include/roughpy/device/types.h index 99d28385f..5de5ca327 100644 --- a/device/include/roughpy/device/types.h +++ b/device/include/roughpy/device/types.h @@ -30,6 +30,7 @@ #include +// Maybe replace this, Eigen/Core brings in so much stuff. #include #include #include @@ -79,24 +80,26 @@ using rational_poly_scalar = lal::polynomial; namespace dtl { +#define RPY_ENABLE_RETURN(C) enable_if_t<(C), TypeInfo> + template <> -constexpr TypeInfo type_info() { - return { TypeCode::Float, sizeof(half), 1}; -} +struct type_code_of_impl { + static constexpr TypeCode value = TypeCode::Float; +}; template <> -constexpr TypeInfo type_info() { - return { TypeCode::BFloat, sizeof(bfloat16), 1}; -} +struct type_code_of_impl { + static constexpr TypeCode value = TypeCode::BFloat; +}; template <> -constexpr TypeInfo type_info() { - return {TypeCode::ArbitraryPrecisionRational, sizeof(rational_scalar_type), 1}; -} +struct type_code_of_impl { + static constexpr TypeCode value = TypeCode::ArbitraryPrecisionRational; +}; template <> -constexpr TypeInfo type_info() { - return {TypeCode::APRationalPolynomial, sizeof(rational_poly_scalar), 1}; -} +struct type_code_of_impl { + static constexpr TypeCode value = TypeCode::APRationalPolynomial; +};