Skip to content

Commit

Permalink
Changed the implementation of type_info_of
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Nov 3, 2023
1 parent 0a24b0e commit 5f4115c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 34 deletions.
66 changes: 46 additions & 20 deletions device/include/roughpy/device/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ enum class TypeCode : uint8_t
ArbitraryPrecisionRational = ArbitraryPrecision | Rational,

Polynomial = 16U,
APRationalPolynomial = Polynomial // | ArbitraryPrecisionRational
APRationalPolynomial = Polynomial// | ArbitraryPrecisionRational

};

/**
* @brief Basic information for identifying the type, size, and
* @brief Basic information for identifying the type, size, alignment, and
* configuration of a type.
*
* This was originally based on the DLPack protocol, but actually that proved
Expand All @@ -216,7 +216,8 @@ enum class TypeCode : uint8_t
struct TypeInfo {
TypeCode code;
uint8_t bytes;
uint16_t lanes = 1;
uint8_t alignment;
uint8_t lanes = 1;
};

template <typename I>
Expand Down Expand Up @@ -279,35 +280,60 @@ constexpr bool operator==(const DeviceInfo& lhs, const DeviceInfo& rhs) noexcept
return lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id;
}

constexpr bool operator==(const TypeInfo& lhs, const TypeInfo& rhs) noexcept
{
return lhs.code == rhs.code && lhs.bytes == rhs.bytes
&& lhs.lanes == rhs.lanes;
}

namespace dtl {

template <typename T, typename SFINAE = void>
struct type_code_of_impl;

namespace dtl {
#define RPY_GENERIC_TYPE_CODE_FUNCTION(cond, TC) \
template <typename T> \
struct type_code_of_impl<T, enable_if_t<(cond)>> { \
static constexpr TypeCode value = (TC); \
}

template <typename T>
constexpr TypeInfo type_info();
RPY_GENERIC_TYPE_CODE_FUNCTION(
is_integral<T>::value&& is_signed<T>::value,
TypeCode::Int
);

template <typename T>
constexpr enable_if_t<is_integral<T>::value && is_signed<T>::value, TypeInfo>
type_info() {
return { TypeCode::Int, sizeof(T), 1};
}
RPY_GENERIC_TYPE_CODE_FUNCTION(
is_integral<T>::value&& is_unsigned<T>::value,
TypeCode::UInt
);

template <typename T>
constexpr enable_if_t<is_integral<T>::value && !is_signed<T>::value, TypeInfo>
type_info() {
return { TypeCode::UInt, sizeof(T), 1};
}
RPY_GENERIC_TYPE_CODE_FUNCTION(is_floating_point<T>::value, TypeCode::Float);

#undef RPY_GENERIC_TYPE_CODE_FUNCTION

}// namespace dtl

template <typename T>
constexpr enable_if_t<is_floating_point<T>::value, TypeInfo>
type_info() {
return { TypeCode::Float, sizeof(T), 1};
constexpr TypeCode type_code_of() noexcept
{
return dtl::type_code_of_impl<T>::value;
}

/**
* @brief Get the type info struct relating to the given type
* @tparam T Type to query
* @return TypeInfo struct containing information about the type.
*/
template <typename T>
constexpr TypeInfo type_info() noexcept
{
return {type_code_of<T>(),
static_cast<uint8_t>(sizeof(T)),
static_cast<uint8_t>(alignof(T)),
1U};
}

}// namespace device
}// namespace devices
}// namespace rpy

#endif// ROUGHPY_DEVICE_CORE_H_
4 changes: 2 additions & 2 deletions device/include/roughpy/device/kernel_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ class RPY_EXPORT KernelArgument
explicit KernelArgument(T& data)
: p_data(&data),
m_mode(Pointer),
m_info(dtl::type_info<T>())
m_info(type_info<T>())
{}

template <typename T>
explicit KernelArgument(const T& data)
: p_const_data(&data),
m_mode(ConstPointer),
m_info(dtl::type_info<T>())
m_info(type_info<T>())
{}

RPY_NO_DISCARD constexpr bool is_buffer() const noexcept
Expand Down
27 changes: 15 additions & 12 deletions device/include/roughpy/device/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <complex>

// Maybe replace this, Eigen/Core brings in so much stuff.
#include <Eigen/Core>
#include <libalgebra_lite/coefficients.h>
#include <libalgebra_lite/packed_integer.h>
Expand Down Expand Up @@ -79,24 +80,26 @@ using rational_poly_scalar = lal::polynomial<lal::rational_field>;

namespace dtl {

#define RPY_ENABLE_RETURN(C) enable_if_t<(C), TypeInfo>

template <>
constexpr TypeInfo type_info<half>() {
return { TypeCode::Float, sizeof(half), 1};
}
struct type_code_of_impl<half, void> {
static constexpr TypeCode value = TypeCode::Float;
};

template <>
constexpr TypeInfo type_info<bfloat16>() {
return { TypeCode::BFloat, sizeof(bfloat16), 1};
}
struct type_code_of_impl<bfloat16, void> {
static constexpr TypeCode value = TypeCode::BFloat;
};

template <>
constexpr TypeInfo type_info<rational_scalar_type>() {
return {TypeCode::ArbitraryPrecisionRational, sizeof(rational_scalar_type), 1};
}
struct type_code_of_impl<rational_scalar_type , void> {
static constexpr TypeCode value = TypeCode::ArbitraryPrecisionRational;
};
template <>
constexpr TypeInfo type_info<rational_poly_scalar>() {
return {TypeCode::APRationalPolynomial, sizeof(rational_poly_scalar), 1};
}
struct type_code_of_impl<rational_poly_scalar , void> {
static constexpr TypeCode value = TypeCode::APRationalPolynomial;
};



Expand Down

0 comments on commit 5f4115c

Please sign in to comment.