diff --git a/library/include/rocwmma/internal/blend.hpp b/library/include/rocwmma/internal/blend.hpp index 79f1ba33..776b0751 100644 --- a/library/include/rocwmma/internal/blend.hpp +++ b/library/include/rocwmma/internal/blend.hpp @@ -99,6 +99,7 @@ namespace rocwmma using Zip4 = Driver; using Zip8 = Driver; using Zip16 = Driver; + using Zip32 = Driver; // Unpack functions using UnpackByteLo = Driver; @@ -107,6 +108,18 @@ namespace rocwmma using UnpackWordHi = Driver; using UnpackByteLoHi = Driver; + // Extract functions + using ExtractByteEven = Driver; + using ExtractByteOdd = Driver; + using ExtractWordEven = Driver; + using ExtractWordOdd = Driver; + + using ExtractByteEvenOdd = Driver; + using ExtractWordEvenOdd = Driver; + + using ExtractByteOddEven = Driver; + using ExtractWordOddEven = Driver; + } // namespace Blend } // namespace rocwmma diff --git a/library/include/rocwmma/internal/blend_impl.hpp b/library/include/rocwmma/internal/blend_impl.hpp index 59f2a563..c0bb5861 100644 --- a/library/include/rocwmma/internal/blend_impl.hpp +++ b/library/include/rocwmma/internal/blend_impl.hpp @@ -49,6 +49,7 @@ namespace rocwmma using Properties::OP_GROUP_SIZE_1; using Properties::OP_GROUP_SIZE_16; using Properties::OP_GROUP_SIZE_2; + using Properties::OP_GROUP_SIZE_32; using Properties::OP_GROUP_SIZE_4; using Properties::OP_GROUP_SIZE_8; @@ -247,6 +248,7 @@ namespace rocwmma using Zip4 = Zip; using Zip8 = Zip; using Zip16 = Zip; + using Zip32 = Zip; // Blend sub-dword elements in regular ordered patterns using UnpackByteLo = PermByte<0u, 4u, 1u, 5u>; @@ -255,6 +257,16 @@ namespace rocwmma using UnpackWordHi = PermWord<1u, 3u>; using UnpackByteLoHi = PermByte<0u, 6u, 1u, 7u>; + using ExtractByteEven = PermByte<0u, 2u, 4u, 6u>; + using ExtractByteOdd = PermByte<1u, 3u, 5u, 7u>; + using ExtractWordEven = UnpackWordLo; + using ExtractWordOdd = UnpackWordHi; + + using ExtractByteEvenOdd = PermByte<0u, 2u, 5u, 7u>; + using ExtractByteOddEven = PermByte<1u, 3u, 4u, 6u>; + 
using ExtractWordEvenOdd = PermWord<0u, 3u>; + using ExtractWordOddEven = PermWord<1u, 2u>; + } // namespace Ops } // namespace BlendImpl diff --git a/library/include/rocwmma/internal/convert.hpp b/library/include/rocwmma/internal/convert.hpp index fe9807fc..1236c33f 100644 --- a/library/include/rocwmma/internal/convert.hpp +++ b/library/include/rocwmma/internal/convert.hpp @@ -27,6 +27,7 @@ #define ROCWMMA_CONVERT_HPP #include "types.hpp" +#include "utility/forward.hpp" namespace rocwmma { @@ -58,7 +59,7 @@ namespace rocwmma template ROCWMMA_DEVICE static inline auto exec(IncomingT&& regsIn) -> IncomingT&& { - return std::forward(regsIn); + return forward(regsIn); } }; diff --git a/library/include/rocwmma/internal/coop_load.hpp b/library/include/rocwmma/internal/coop_load.hpp index 8690a6ba..8a12a5a9 100644 --- a/library/include/rocwmma/internal/coop_load.hpp +++ b/library/include/rocwmma/internal/coop_load.hpp @@ -63,7 +63,7 @@ namespace rocwmma // Outer loop = index 0, // Inner loop = index N-1 - template @@ -73,14 +73,14 @@ namespace rocwmma StrideSpace&& strideSpace, Strides2d&& strides2d) { - static_assert(VecTraits>::size() - == VecTraits>::size(), + static_assert(VecTraits>::size() + == VecTraits>::size(), "Mismatched size"); - auto strideOffset = DataLayout::fromMatrixCoord(std::get(strides2d), ldm); - auto strideCount = std::get(strideSpace); + auto strideOffset = DataLayout::fromMatrixCoord(get(strides2d), ldm); + auto strideCount = get(strideSpace); // Last depth layer will invoke the load - if constexpr(Depth == (VecTraits>::size() - 1u)) + if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll for(int i = 0; i < strideCount; i++) @@ -135,7 +135,7 @@ namespace rocwmma } // Split the reduced stride space. 
- auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u); + auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u); auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u; // Add back in the VW dimension, for the full stride @@ -191,7 +191,7 @@ namespace rocwmma } // Split the reduced stride space. - constexpr auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u); + constexpr auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u); constexpr auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u; diff --git a/library/include/rocwmma/internal/coop_store.hpp b/library/include/rocwmma/internal/coop_store.hpp index b781bd40..0dd6e1d9 100644 --- a/library/include/rocwmma/internal/coop_store.hpp +++ b/library/include/rocwmma/internal/coop_store.hpp @@ -64,7 +64,7 @@ namespace rocwmma // Outer loop = index 0, // Inner loop = index N-1 - template @@ -74,14 +74,14 @@ namespace rocwmma StrideSpace&& strideCounts, Strides2d&& strides2d) { - static_assert(VecTraits>::size() - == VecTraits>::size(), + static_assert(VecTraits>::size() + == VecTraits>::size(), "Mismatched size"); - auto strideOffset = DataLayout::fromMatrixCoord(std::get(strides2d), ldm); - auto strideCount = std::get(strideCounts); + auto strideOffset = DataLayout::fromMatrixCoord(get(strides2d), ldm); + auto strideCount = get(strideCounts); // Last depth layer will invoke the load - if constexpr(Depth == (VecTraits>::size() - 1u)) + if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll for(int i = 0; i < strideCount; i++) @@ -136,7 +136,7 @@ namespace rocwmma } // Split the reduced stride space. 
- auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u); + auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u); auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u; // Add back in the VW dimension, for the full stride @@ -190,7 +190,7 @@ namespace rocwmma } // Split the reduced stride space. - constexpr auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u); + constexpr auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u); constexpr auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u; diff --git a/library/include/rocwmma/internal/float8.h b/library/include/rocwmma/internal/float8.h index 71c1d5fb..b32ac3c9 100644 --- a/library/include/rocwmma/internal/float8.h +++ b/library/include/rocwmma/internal/float8.h @@ -34,12 +34,6 @@ using uint8_t = __hip_internal::uint8_t; using uint16_t = __hip_internal::uint16_t; -namespace std -{ - template - struct conditional; -} - #endif // We are clipping in down conversion by default @@ -771,7 +765,7 @@ inline ROCWMMA_HOST_DEVICE bool operator!=(rocwmma_bf8 a, rocwmma_bf8 b) template {}, int>::type = 0> + typename rocwmma::enable_if{}, int>::type = 0> inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a) { // same type, no conversion @@ -779,20 +773,20 @@ inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a) } // Use h/w intrinsic and optimized version when __gfx940__ -template < - typename T, - typename Ta, - bool stochastic_rounding, - typename std::enable_if<(!(std::is_same{}) - && (std::is_same{} || std::is_same{})), - int>::type - = 0> +template {}) + && (rocwmma::is_same{} + || rocwmma::is_same{})), + int>::type + = 0> inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng) { #if ROCWMMA_ARCH_GFX940 || ROCWMMA_ARCH_GFX941 || ROCWMMA_ARCH_GFX942 // NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize away one runtime branch T val; - if(std::is_same::value) + if(rocwmma::is_same::value) { val.data 
= rocwmma_f8::cast_to_f8_from_f32(float(a), rng); } @@ -811,14 +805,14 @@ inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng) // NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider the quantization // However, if we need HIP-GEMM for fall-back, we would need explicit_cast handles Tacc=f32 to To=f16/bf16 conversion -template < - typename T, - typename Ta, - bool stochastic_rounding, - typename std::enable_if<(!(std::is_same{}) - && !(std::is_same{} || std::is_same{})), - int>::type - = 0> +template {}) + && !(rocwmma::is_same{} + || rocwmma::is_same{})), + int>::type + = 0> inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng) { // the return type is not a F8 types, no SR for those types diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index d07e5b41..42878c0f 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -132,7 +132,7 @@ namespace rocwmma { MaxVW = detail::MaxVWSelector:: Result, - VW = std::is_same::value ? MaxVW : 1u + VW = is_same::value ? MaxVW : 1u }; // Layout mapping for 1d / 2d @@ -140,7 +140,7 @@ namespace rocwmma using MatrixLayout = MatrixLayout::template ColNT; - static_assert(!(std::is_same_v && VW > 1), + static_assert(!(is_same_v && VW > 1), "matrix_a in col_major currently does not support VW > 1"); }; @@ -156,7 +156,7 @@ namespace rocwmma { MaxVW = detail::MaxVWSelector:: Result, - VW = std::is_same::value ? MaxVW : 1u + VW = is_same::value ? MaxVW : 1u }; // Layout mapping for 1d / 2d @@ -164,7 +164,7 @@ namespace rocwmma using MatrixLayout = MatrixLayout::template RowNT; - static_assert(!(std::is_same_v && VW > 1), + static_assert(!(is_same_v && VW > 1), "matrix_b in row_major currently does not support VW > 1"); }; @@ -178,8 +178,8 @@ namespace rocwmma // Vector size properties enum : uint32_t { - MaxVW = (std::is_same::value || ROCWMMA_ARCH_GFX11) ? 
1u : 4u, - VW = std::is_same::value ? MaxVW : 1u + MaxVW = (is_same::value || ROCWMMA_ARCH_GFX11) ? 1u : 4u, + VW = is_same::value ? MaxVW : 1u }; // Layout mapping for 1d / 2d @@ -187,7 +187,7 @@ namespace rocwmma using MatrixLayout = MatrixLayout::template RowNT; - static_assert(!(std::is_same::value && VW > 1), + static_assert(!(is_same::value && VW > 1), "accumulator in row_major currently does not support VW > 1"); }; diff --git a/library/include/rocwmma/internal/layout.hpp b/library/include/rocwmma/internal/layout.hpp index 48462bcb..b02c231f 100644 --- a/library/include/rocwmma/internal/layout.hpp +++ b/library/include/rocwmma/internal/layout.hpp @@ -26,6 +26,7 @@ #ifndef ROCWMMA_LAYOUT_HPP #define ROCWMMA_LAYOUT_HPP +#include "utility/type_traits.hpp" #include "layout_impl.hpp" namespace rocwmma @@ -188,8 +189,8 @@ namespace rocwmma typename DataLayout, uint32_t VectorWidth, uint32_t MaxVectorWidth> - struct ColNT : public std::conditional_t< - std::is_same::value, + struct ColNT : public conditional_t< + is_same::value, detail::ColOrthoVW, detail::ColOrthoVW> { @@ -202,11 +203,11 @@ namespace rocwmma // elements in both row_major or col_major data layouts. // This layout cannot support for VW > 1 in col_major data layout otherwise the // ordering is broken. - static_assert(!(std::is_same_v && VectorWidth > 1), + static_assert(!(is_same_v && VectorWidth > 1), "ColNT in col_major does not support VectorWidth > 1"); // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert(std::is_same_v && (MaxVectorWidth <= BlockK), + static_assert(is_same_v && (MaxVectorWidth <= BlockK), "MaxVectorWidth is larger than BlockK dimension. 
Try reducing MaxVectorWidth"); }; }; @@ -315,8 +316,8 @@ namespace rocwmma typename DataLayout, uint32_t VectorWidth, uint32_t MaxVectorWidth> - struct RowNT : public std::conditional_t< - std::is_same::value, + struct RowNT : public conditional_t< + is_same::value, detail::RowOrthoVW, detail::RowOrthoVW> { @@ -329,11 +330,11 @@ namespace rocwmma // elements in both in row_major or col_major data layouts. // This layout cannot support for VW > 1 in row_major data layout otherwise the // ordering is broken. - static_assert(!(std::is_same_v && VectorWidth > 1), + static_assert(!(is_same_v && VectorWidth > 1), "RowNT in row_major does not support VectorWidth > 1"); // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert(std::is_same_v && (MaxVectorWidth <= BlockK), + static_assert(is_same_v && (MaxVectorWidth <= BlockK), "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); }; }; @@ -498,8 +499,8 @@ namespace rocwmma typename DataLayout, uint32_t VectorWidth, uint32_t MaxVectorWidth = VectorWidth> - struct Col : public std::conditional_t< - std::is_same::value, + struct Col : public conditional_t< + is_same::value, detail::ColInlineVW, detail::ColOrthoVW> { @@ -664,8 +665,8 @@ namespace rocwmma typename DataLayout, uint32_t VectorWidth, uint32_t MaxVectorWidth = VectorWidth> - struct Row : public std::conditional_t< - std::is_same::value, + struct Row : public conditional_t< + is_same::value, detail::RowInlineVW, detail::RowOrthoVW> { diff --git a/library/include/rocwmma/internal/layout_impl.hpp b/library/include/rocwmma/internal/layout_impl.hpp index 7dd08439..0047d394 100644 --- a/library/include/rocwmma/internal/layout_impl.hpp +++ b/library/include/rocwmma/internal/layout_impl.hpp @@ -59,12 +59,12 @@ namespace rocwmma /// Helper to ensure layout types are consistent (same) /// template - struct ConsistencyCheck : public std::false_type + struct ConsistencyCheck : public false_type { }; template - struct 
ConsistencyCheck : public std::true_type + struct ConsistencyCheck : public true_type { }; @@ -72,12 +72,12 @@ namespace rocwmma /// Helper to check if layout types are orthogonal /// template - struct OrthogonalCheck : public std::true_type + struct OrthogonalCheck : public true_type { }; template - struct OrthogonalCheck : public std::false_type + struct OrthogonalCheck : public false_type { }; @@ -208,13 +208,13 @@ namespace rocwmma /// Check for consistency in element ordering between two layouts /// template - struct ConsistencyCheck : public std::false_type + struct ConsistencyCheck : public false_type { }; // Same type is compatible template - struct ConsistencyCheck : public std::true_type + struct ConsistencyCheck : public true_type { }; @@ -229,7 +229,7 @@ namespace rocwmma MatrixLayout::ColNT, MatrixLayout:: ColNT> - : public std::true_type + : public true_type { }; @@ -242,7 +242,7 @@ namespace rocwmma MatrixLayout:: ColNT, MatrixLayout::ColNT> - : public std::true_type + : public true_type { }; @@ -255,7 +255,7 @@ namespace rocwmma MatrixLayout:: RowNT, MatrixLayout::RowNT> - : public std::true_type + : public true_type { }; @@ -268,7 +268,7 @@ namespace rocwmma MatrixLayout::RowNT, MatrixLayout:: RowNT> - : public std::true_type + : public true_type { }; @@ -286,7 +286,7 @@ namespace rocwmma Col, MatrixLayout:: Col> - : public std::true_type + : public true_type { }; @@ -302,7 +302,7 @@ namespace rocwmma Row, MatrixLayout:: Row> - : public std::true_type + : public true_type { }; @@ -311,13 +311,13 @@ namespace rocwmma /// template - struct OrthogonalCheck : public std::false_type + struct OrthogonalCheck : public false_type { }; // Same type is not orthogonal template - struct OrthogonalCheck : public std::false_type + struct OrthogonalCheck : public false_type { }; @@ -325,7 +325,7 @@ namespace rocwmma struct OrthogonalCheck< MatrixLayout::ColNT, MatrixLayout::RowNT> - : public std::true_type + : public true_type { }; @@ -340,7 +340,7 @@ namespace 
rocwmma ColNT, MatrixLayout:: RowNT> - : public std::true_type + : public true_type { }; @@ -348,7 +348,7 @@ namespace rocwmma struct OrthogonalCheck< MatrixLayout::RowNT, MatrixLayout::ColNT> - : public std::true_type + : public true_type { }; @@ -363,7 +363,7 @@ namespace rocwmma RowNT, MatrixLayout:: ColNT> - : public std::true_type + : public true_type { }; @@ -382,7 +382,7 @@ namespace rocwmma DataT, typename DataLayout::template OrthogonalLayout_t, RhsVectorWidth, - MaxVectorWidth>> : public std::true_type + MaxVectorWidth>> : public true_type { }; @@ -401,7 +401,7 @@ namespace rocwmma DataT, typename DataLayout::template OrthogonalLayout_t, RhsVectorWidth, - MaxVectorWidth>> : public std::true_type + MaxVectorWidth>> : public true_type { }; @@ -416,7 +416,7 @@ namespace rocwmma Col, MatrixLayout:: Row> - : public std::true_type + : public true_type { }; @@ -505,7 +505,7 @@ namespace rocwmma WaveSize = IOTraits::ThreadsPerIO, // Number of BlockDim columns gathered per cycle of MaxVW - MaxKPerIO = WaveSize * MaxVectorWidth / std::min(BlockDim, WaveSize), + MaxKPerIO = WaveSize * MaxVectorWidth / min(BlockDim, WaveSize), BlockDimStride_X = WaveSize, BlockDimStride_Y = 0u, @@ -520,7 +520,7 @@ namespace rocwmma LargeDim = BlockDim >= WaveSize, // Number of segments in BlockDim direction - BlockDimSegs = std::max(BlockDim / BlockDimStride_X, 1u), + BlockDimSegs = max(BlockDim / BlockDimStride_X, 1u), // Number of segments in the BlockK direction BlockKSegs = BlockK / BlockKStride_Y, @@ -742,7 +742,7 @@ namespace rocwmma MaxElementsPerIO = WaveSize * MaxVectorWidth, // Number of BlockDim columns gathered per cycle of MaxVW - MaxKPerIO = std::max(1u, MaxElementsPerIO / BlockDim), + MaxKPerIO = max(1u, MaxElementsPerIO / BlockDim), VWStride_X = VectorWidth, VWStride_Y = 0u, @@ -757,13 +757,13 @@ namespace rocwmma LargeDim = BlockDim >= MaxElementsPerIO, // Number of segments in BlockDim direction - BlockDimSegs = std::max(1u, BlockDim / BlockDimStride_X), + 
BlockDimSegs = max(1u, BlockDim / BlockDimStride_X), // Number of segments in the BlockK direction - BlockKSegs = std::max(1u, BlockK / BlockKStride_Y), + BlockKSegs = max(1u, BlockK / BlockKStride_Y), // Number of segments in the MaxVW direction - VWSegs = std::max(1u, MaxVectorWidth / VWStride_X), + VWSegs = max(1u, MaxVectorWidth / VWStride_X), // Log2 Values Log2BlockDim = Log2::value, @@ -933,8 +933,7 @@ namespace rocwmma ROCWMMA_DEVICE constexpr static inline auto strides() { auto t = Traits::OrthoLayout::strides(); - return make_vector( - swap(std::get<0>(t)), swap(std::get<1>(t)), swap(std::get<2>(t))); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT @@ -979,8 +978,7 @@ namespace rocwmma ROCWMMA_DEVICE constexpr static inline auto strides() { auto t = Traits::OrthoLayout::strides(); - return make_vector( - swap(std::get<0>(t)), swap(std::get<1>(t)), swap(std::get<2>(t))); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT diff --git a/library/include/rocwmma/internal/mapping_util.hpp b/library/include/rocwmma/internal/mapping_util.hpp index 8719d31f..b98a201e 100644 --- a/library/include/rocwmma/internal/mapping_util.hpp +++ b/library/include/rocwmma/internal/mapping_util.hpp @@ -27,7 +27,7 @@ #define ROCWMMA_MAPPING_UTIL_HPP #include "types.hpp" - +#include "utility/type_traits.hpp" namespace rocwmma { // Fwd declaration @@ -59,11 +59,11 @@ namespace rocwmma // Size of workgroup, normalized to wave count. 
template 0u && TBlockY > 0u), - typename std::enable_if_t* = nullptr> + enable_if_t* = nullptr> ROCWMMA_DEVICE constexpr static inline WorkgroupDimT workgroupDim(); template 0u && TBlockY > 0u), - typename std::enable_if_t* = nullptr> + enable_if_t* = nullptr> ROCWMMA_DEVICE static inline WorkgroupDimT workgroupDim(); }; @@ -93,8 +93,8 @@ namespace rocwmma enum : uint32_t { - MajorIndex = std::is_same::value ? 0 : 1, - MinorIndex = std::is_same::value ? 1 : 0 + MajorIndex = is_same::value ? 0 : 1, + MinorIndex = is_same::value ? 1 : 0 }; // Determine the leading dimension of a matrix. diff --git a/library/include/rocwmma/internal/mapping_util_impl.hpp b/library/include/rocwmma/internal/mapping_util_impl.hpp index e13e690a..106f755f 100644 --- a/library/include/rocwmma/internal/mapping_util_impl.hpp +++ b/library/include/rocwmma/internal/mapping_util_impl.hpp @@ -105,7 +105,7 @@ namespace rocwmma template template 0u && TBlockY > 0u) */, - typename std::enable_if_t* /* = nullptr */> + enable_if_t* /* = nullptr */> ROCWMMA_DEVICE constexpr inline auto WaveSpace::workgroupDim() -> WorkgroupDimT { @@ -114,7 +114,7 @@ namespace rocwmma template template 0u && TBlockY > 0u) */, - typename std::enable_if_t* /* = nullptr */> + enable_if_t* /* = nullptr */> ROCWMMA_DEVICE inline auto WaveSpace::workgroupDim() -> WorkgroupDimT { return waveCount(make_coord2d(blockDim.x, blockDim.y)); @@ -251,7 +251,7 @@ namespace rocwmma ROCWMMA_DEVICE inline auto MappingUtil::matrixCoord( BlockCoordT const& blockCoord) -> MatrixCoordT { - return MatrixSpace::fromBlockCoord(std::forward(blockCoord)); + return MatrixSpace::fromBlockCoord(forward(blockCoord)); } template @@ -259,7 +259,7 @@ namespace rocwmma MappingUtil::dataOffset( MatrixCoordT const& matrixCoord, uint32_t ldm) { - return DataSpace::fromMatrixCoord(std::forward(matrixCoord), ldm); + return DataSpace::fromMatrixCoord(forward(matrixCoord), ldm); } template @@ -268,7 +268,7 @@ namespace rocwmma DataT const* baseAddr, 
MatrixCoordT const& matrixCoord, uint32_t ldm) { return baseAddr - + DataSpace::fromMatrixCoord(std::forward(matrixCoord), ldm); + + DataSpace::fromMatrixCoord(forward(matrixCoord), ldm); } template @@ -276,7 +276,7 @@ namespace rocwmma DataT* baseAddr, MatrixCoordT const& matrixCoord, uint32_t ldm) { return baseAddr - + DataSpace::fromMatrixCoord(std::forward(matrixCoord), ldm); + + DataSpace::fromMatrixCoord(forward(matrixCoord), ldm); } } // namespace rocwmma diff --git a/library/include/rocwmma/internal/mfma.hpp b/library/include/rocwmma/internal/mfma.hpp index 6747eef2..794bea81 100644 --- a/library/include/rocwmma/internal/mfma.hpp +++ b/library/include/rocwmma/internal/mfma.hpp @@ -52,7 +52,7 @@ namespace rocwmma BlockM, BlockN, BlockK, - typename std::enable_if_t> + enable_if_t> { // Full-fragment IO traits using IOTraitsA = IOTraits; @@ -90,10 +90,10 @@ namespace rocwmma // A / B and C / D types must match static_assert( - std::is_same::value, + is_same::value, "A and B registers must be of same type"); static_assert( - std::is_same::value, + is_same::value, "C and D registers must be of same type"); // Full fragment counts must match packed IO counts diff --git a/library/include/rocwmma/internal/opaque_load.hpp b/library/include/rocwmma/internal/opaque_load.hpp index 1bdd47cd..c14e1978 100644 --- a/library/include/rocwmma/internal/opaque_load.hpp +++ b/library/include/rocwmma/internal/opaque_load.hpp @@ -78,7 +78,7 @@ namespace rocwmma // Outer loop = index 0, // Inner loop = index N-1 - template @@ -92,7 +92,7 @@ namespace rocwmma auto strideCount = get(strideCounts); // Last depth layer will invoke the load - if constexpr(Depth == (VecTraits>::size() - 1u)) + if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll for(int i = 0; i < strideCount; i++) diff --git a/library/include/rocwmma/internal/opaque_store.hpp b/library/include/rocwmma/internal/opaque_store.hpp index 7c89b12d..1f1f9990 100644 --- 
a/library/include/rocwmma/internal/opaque_store.hpp +++ b/library/include/rocwmma/internal/opaque_store.hpp @@ -73,7 +73,7 @@ namespace rocwmma using StoreVecTraits = VecTraits; - template @@ -87,7 +87,7 @@ namespace rocwmma auto strideCount = get(strideCounts); // Last depth layer will invoke the load - if constexpr(Depth == (VecTraits>::size() - 1u)) + if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll for(int i = 0; i < strideCount; i++) diff --git a/library/include/rocwmma/internal/pack_util_impl.hpp b/library/include/rocwmma/internal/pack_util_impl.hpp index dac20823..22866437 100644 --- a/library/include/rocwmma/internal/pack_util_impl.hpp +++ b/library/include/rocwmma/internal/pack_util_impl.hpp @@ -29,6 +29,7 @@ #include "pack_util.hpp" #include "types.hpp" #include "utils.hpp" +#include "vector_util.hpp" namespace rocwmma { @@ -104,6 +105,18 @@ namespace rocwmma using PackedT = int32_t; }; + template <> + struct PackTraits + { + enum : uint32_t + { + PackRatio = 1 // No pack + }; + + using UnpackedT = int64_t; + using PackedT = int64_t; + }; + template <> struct PackTraits { @@ -251,12 +264,13 @@ namespace rocwmma ROCWMMA_DEVICE /*static*/ inline auto& PackUtil::packHelper(VecT const& v) { - static_assert(VecSize % Traits::PackRatio == 0, "Use paddedPack32 instead."); + static_assert(VecSize % Traits::PackRatio == 0, + "Cannot pack partial b32 vector. Use paddedPack instead."); // NOTE: Assumes that there is NO padding... using PackedVecT = VecT; - using UnpackedVecT = std::decay_t; - return *reinterpret_cast(&(const_cast(v))); + using UnpackedVecT = decay_t; + return reinterpret_cast(v); } template @@ -264,10 +278,16 @@ namespace rocwmma ROCWMMA_DEVICE /*static*/ inline auto& PackUtil::unpackHelper(VecT const& v) { + if constexpr(is_same_v) + { + static_assert(Traits::PackRatio == 1, "Input vector must be packed"); + } + // NOTE: Assumes that there is NO padding... 
- using PackedVecT = std::decay_t; + using PackedVecT = decay_t; using UnpackedVecT = VecT; - return *reinterpret_cast(&(const_cast(v))); + + return reinterpret_cast(v); } template @@ -353,7 +373,8 @@ namespace rocwmma // Duplicate the inputs for padding else if constexpr((VecSize * 2u) == Traits::PackRatio) { - return packHelper(concat(v, v)); + // Make sure to return by value here as concat produces rval + return VecT(packHelper(concat(v, v))); } // Pad single element data to b32 else if constexpr(VecSize == 1u) @@ -375,7 +396,7 @@ namespace rocwmma // Take lower half of vector else if constexpr((UnpaddedSize * 2u) == Traits::PackRatio) { - return extractLo(v); + return extractLo(unpackHelper(v)); } // Pad single element data to b32 else if constexpr(UnpaddedSize == 1u) diff --git a/library/include/rocwmma/internal/permute.hpp b/library/include/rocwmma/internal/permute.hpp index 765d7670..3e258259 100644 --- a/library/include/rocwmma/internal/permute.hpp +++ b/library/include/rocwmma/internal/permute.hpp @@ -56,9 +56,9 @@ namespace rocwmma static_assert((PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_BLOCK_BCAST) || (PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_SHUFFLE) || (PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_GATHER) - || (PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_SCATTER), + || (PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_SCATTER) + || (PermuteOp::opId() == CrossLaneOps::Properties::OP_ID_ROTATE), "PermuteOp is unsupported"); - template ROCWMMA_DEVICE static inline auto exec(DataT const& src) { diff --git a/library/include/rocwmma/internal/rocwmma_hip_f8_impl.h b/library/include/rocwmma/internal/rocwmma_hip_f8_impl.h index bea6f035..4ef850a7 100644 --- a/library/include/rocwmma/internal/rocwmma_hip_f8_impl.h +++ b/library/include/rocwmma/internal/rocwmma_hip_f8_impl.h @@ -27,8 +27,12 @@ #ifndef ROCWMMA_HIP_FP8_IMPL_H #define ROCWMMA_HIP_FP8_IMPL_H +#include "utility/type_traits.hpp" + namespace 
rocwmma_hip_f8_impl { + using rocwmma::is_same; + using rocwmma::conditional; ROCWMMA_HOST inline int clz(uint32_t x) { @@ -42,8 +46,8 @@ namespace rocwmma_hip_f8_impl template ROCWMMA_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) { - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = is_same::value; + constexpr bool is_float = is_same::value; static_assert(wm + we == 7, "wm+we==7"); static_assert(is_half || is_float, "Only half and float can be cast to f8"); @@ -239,8 +243,8 @@ namespace rocwmma_hip_f8_impl template ROCWMMA_HOST_DEVICE T cast_from_f8(uint8_t x) { - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = is_same::value; + constexpr bool is_float = is_same::value; static_assert(is_half || is_float, "only half and float are supported"); constexpr int weo = is_half ? 5 : 8; @@ -296,7 +300,7 @@ namespace rocwmma_hip_f8_impl return (mantissa == 0) ? (sign ? 
fNegInf : fInf) : fNaN; } } - typename std::conditional::type retval; + typename conditional::type retval; if(we == 5 && is_half && !negative_zero_nan) { retval = x << 8; diff --git a/library/include/rocwmma/internal/rocwmma_xfloat32.hpp b/library/include/rocwmma/internal/rocwmma_xfloat32.hpp index b5ad57c8..f47ac775 100644 --- a/library/include/rocwmma/internal/rocwmma_xfloat32.hpp +++ b/library/include/rocwmma/internal/rocwmma_xfloat32.hpp @@ -179,6 +179,7 @@ typedef struct float data; } rocwmma_xfloat32_public; +#if !defined(__HIPCC_RTC__) static_assert(std::is_standard_layout{}, "rocwmma_xfloat32 is not a standard layout type, and thus is " "incompatible with C."); @@ -187,7 +188,6 @@ static_assert(std::is_trivial{}, "rocwmma_xfloat32 is not a trivial type, and thus is " "incompatible with C."); -#if !defined(__HIPCC_RTC__) static_assert(sizeof(rocwmma_xfloat32) == sizeof(rocwmma_xfloat32_public) && offsetof(rocwmma_xfloat32, data) == offsetof(rocwmma_xfloat32_public, data), "internal rocwmma_xfloat32 does not match public rocwmma_xfloat32"); diff --git a/library/include/rocwmma/internal/transforms_impl.hpp b/library/include/rocwmma/internal/transforms_impl.hpp index e2762410..f234b045 100644 --- a/library/include/rocwmma/internal/transforms_impl.hpp +++ b/library/include/rocwmma/internal/transforms_impl.hpp @@ -33,6 +33,7 @@ #include "pack_util.hpp" #include "permute.hpp" #include "utils.hpp" +#include "vector_util.hpp" namespace rocwmma { @@ -148,6 +149,41 @@ namespace rocwmma return PackUtil::template paddedUnpack(concat(lo, hi)); } + template + ROCWMMA_DEVICE static inline auto unpackLoHi16(VecT const& v) + { + static_assert(VecSize % 2 == 0, "VecSize must be a multiple of 2"); + using PackUtil = PackUtil; + + auto lo = PackUtil::paddedPack(extractEven(v)); + auto hi = PackUtil::paddedPack(extractOdd(v)); + auto rot_lo = Swizzle::RotateR32<16>::exec(lo); + auto rot_hi = Swizzle::RotateR32<16>::exec(hi); + lo = Blend::Zip16::exec(lo, rot_hi); + hi = 
Blend::Zip16::exec(rot_lo, hi); + + return PackUtil::template paddedUnpack(concat(lo, hi)); + } + + // TODO: Wave64 only? + template + ROCWMMA_DEVICE static inline auto unpackLoHi32(VecT const& v) + { + static_assert(VecSize % 2 == 0, "VecSize must be a multiple of 2"); + using PackUtil = PackUtil; + + auto lo = PackUtil::paddedPack(extractEven(v)); + auto hi = PackUtil::paddedPack(extractOdd(v)); + + // TODO: label as rotateR64 for consistency? + auto rot_lo = Permute::RotateWaveR<32>::exec(lo); + auto rot_hi = Permute::RotateWaveR<32>::exec(hi); + lo = Blend::Zip32::exec(lo, rot_hi); + hi = Blend::Zip32::exec(rot_lo, hi); + + return PackUtil::template paddedUnpack(concat(lo, hi)); + } + template ROCWMMA_DEVICE static inline auto aos_soa_16xk_b32(VecT const& v) { @@ -194,8 +230,8 @@ namespace rocwmma // In order to save some operations, we can // rotate the odds components only and make up the // offset later in gather. - auto evens = PackUtil::paddedPack(extractEven(v)); - auto odds = PackUtil::paddedPack(extractOdd(v)); + auto evens = PackUtil::paddedPack(extractEven(result)); + auto odds = PackUtil::paddedPack(extractOdd(result)); auto rot = Swizzle::RotateR32<16>::exec(odds); auto lo = Blend::Zip16::exec(evens, rot); @@ -224,7 +260,32 @@ namespace rocwmma template ROCWMMA_DEVICE static inline auto aos_soa_64xk_b32(VecT const& v) { - return 0; + using PackUtil = PackUtil; + + // Step 1 : Unpack groups of 8 + auto result = unpackLoHi8(v); + + // Step 2 : Unpack groups of 16 + result = unpackLoHi16(result); + + // Step 3 : Unpack groups of 32 + // In order to save some operations, we can + // rotate the odds components only and make up the + // offset later in gather. + auto lo = PackUtil::paddedPack(extractEven(result)); + auto hi = PackUtil::paddedPack(extractOdd(result)); + + // TODO: label as rotateR64 for consistency? 
+ auto rot_hi = Permute::RotateWaveR<32>::exec(hi); + hi = Blend::Zip32::exec(rot_hi, lo); + lo = Blend::Zip32::exec(lo, rot_hi); + + // Step 4 : Gather + // Note the offset of 32 in hi + lo = Permute::GatherWave<8, 0>::exec(lo); + hi = Permute::GatherWave<8, 32>::exec(hi); + + return PackUtil::template paddedUnpack<8>(concat(lo, hi)); } template @@ -275,6 +336,197 @@ namespace rocwmma return 0; } + template + struct AosToSoa; + + template <> + struct AosToSoa<16, 8> + { + constexpr static uint32_t VW = 8; + constexpr static uint32_t VecSize = 8; + + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + using PackUtil = PackUtil; + + // Step 1 : Unpack groups of 2 + auto result = unpackLoHi2(v); + + // Step 2 : Unpack groups of 4 + result = unpackLoHi4(result); + + // Step 3 : Unpack groups of 8 + result = unpackLoHi8(result); + + // Step 4 : Gather + return PackUtil::template paddedUnpack( + Permute::Gather16::exec(PackUtil::paddedPack(result))); + } + }; + + template <> + struct AosToSoa<32, 8> + { + constexpr static uint32_t VW = 8; + constexpr static uint32_t VecSize = 8; + + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + using PackUtil = PackUtil; + + // Step 1 : Unpack groups of 4 + auto result = unpackLoHi4(v); + + // Step 2 : Unpack groups of 8 + result = unpackLoHi8(result); + + // Step 3 : Unpack groups of 16 + // In order to save some operations, we can + // rotate the odds components only and make up the + // offset later in gather. 
+ auto evens = PackUtil::paddedPack(extractEven(result)); + auto odds = PackUtil::paddedPack(extractOdd(result)); + + auto rot = Swizzle::RotateR32<16>::exec(odds); + auto lo = Blend::Zip16::exec(evens, rot); + auto hi = Blend::Zip16::exec(rot, evens); + + // Step 4 : Gather + // Note the offset of 16 in hi + lo = Permute::Gather32::exec(lo); + hi = Permute::Gather32::exec(hi); + + return PackUtil::template paddedUnpack(concat(lo, hi)); + } + }; + + template <> + struct AosToSoa<64, 8> + { + constexpr static uint32_t VW = 8; + constexpr static uint32_t VecSize = 8; + + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + using PackUtil = PackUtil; + + // Step 1 : Unpack groups of 8 + auto result = unpackLoHi8(v); + + // Step 2 : Unpack groups of 16 + result = unpackLoHi16(result); + + // Step 3 : Unpack groups of 32 + // In order to save some operations, we can + // rotate the odds components only and make up the + // offset later in gather. + auto lo = PackUtil::paddedPack(extractEven(result)); + auto hi = PackUtil::paddedPack(extractOdd(result)); + + // TODO: label as rotateR64 for consistency? + auto rot_hi = Permute::RotateWaveR<32>::exec(hi); + hi = Blend::Zip32::exec(rot_hi, lo); + lo = Blend::Zip32::exec(lo, rot_hi); + + // Step 4 : Gather + // Note the offset of 32 in hi + lo = Permute::GatherWave::exec(lo); + hi = Permute::GatherWave::exec(hi); + + return PackUtil::template paddedUnpack(concat(lo, hi)); + } + }; + + template <> + struct AosToSoa<128, 8> + { + constexpr static uint32_t VW = 8; + constexpr static uint32_t VecSize = 16; + + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + using PackUtil = PackUtil; + + // Data comes in as AOS format: + // There are TWO sets of VW = 8 registers (because this case BlockDim / 64 = 2): + // 1. Vecs 0-7 + // 2. Vecs 8-15 + // + // Register/ | VW = 8 | + // Tidx |___0___|___1___|___...___|___7___| + // 0 | 0 | 1 | ... | 7 | + // 1 | 8 | 9 | ... 
| 15 | + // ... | ... | ... | ... | ... | + // 63 |__504__|__505__|___...___|__511__| + // + // Register/ | VW = 8 | + // Tidx |___8___|___9___|___...___|___15__| + // 0 | 512 | 513 | ... | 519 | + // 1 | 520 | 521 | ... | 527 | + // ... | ... | ... | ... | ... | + // 63 |__1016_|__1017_|___...___|__1023_| + + // For each batch of VW registers + auto v0 = extractLo(v); + auto v1 = extractHi(v); + + // Step 1 : Unpack groups of 8 + auto r0 = unpackLoHi8(v0); + auto r1 = unpackLoHi8(v1); + + // Step 2 : isolate data for upper 64 dim from lower 64 dim + v0 = concat(extractLo(r0), extractLo(r1)); + v1 = concat(extractHi(r0), extractHi(r1)); + + // Continue from here as if v0 and v1 are independent 64 dim. + + // Step 3 : Unpack groups of 16 + v0 = unpackLoHi16(v0); + v1 = unpackLoHi16(v1); + + // Step 4 : Unpack groups of 32 + // In order to save some operations, we can + // rotate the odds components only and make up the + // offset later in gather. + auto lo0 = PackUtil::paddedPack(extractEven(v0)); + auto hi0 = PackUtil::paddedPack(extractOdd(v0)); + + auto lo1 = PackUtil::paddedPack(extractEven(v1)); + auto hi1 = PackUtil::paddedPack(extractOdd(v1)); + + // TODO: label as rotateR64 for consistency? + auto rot_hi0 = Permute::RotateWaveR<32>::exec(hi0); + hi0 = Blend::Zip32::exec(rot_hi0, lo0); + lo0 = Blend::Zip32::exec(lo0, rot_hi0); + + auto rot_hi1 = Permute::RotateWaveR<32>::exec(hi1); + hi1 = Blend::Zip32::exec(rot_hi1, lo1); + lo1 = Blend::Zip32::exec(lo1, rot_hi1); + + // Step 5 : Gather + // Note the offset of 32 in hi + lo0 = Permute::GatherWave::exec(lo0); + hi0 = Permute::GatherWave::exec(hi0); + + lo1 = Permute::GatherWave::exec(lo1); + hi1 = Permute::GatherWave::exec(hi1); + + // Step 6 : Unpack and re-order. 
+ auto c0 = PackUtil::template paddedUnpack(concat(lo0, hi0)); + //c0 = reorderEvenOdd(c0); + c0 = concat(extractEven(c0), extractOdd(c0)); + auto c1 = PackUtil::template paddedUnpack(concat(lo1, hi1)); + //c1 = reorderEvenOdd(c1); + c1 = concat(extractEven(c1), extractOdd(c1)); + + return concat(c0, c1); + } + }; + // SOA -> AOS // Transform from ortho VW to inline VW template diff --git a/library/include/rocwmma/internal/tuple.hpp b/library/include/rocwmma/internal/tuple.hpp index ca9c1c48..54fafc3b 100644 --- a/library/include/rocwmma/internal/tuple.hpp +++ b/library/include/rocwmma/internal/tuple.hpp @@ -33,13 +33,12 @@ #endif // !defined(__HIPCC_RTC__) +#include "utility/forward.hpp" +#include "utility/sequence.hpp" #include "utils.hpp" namespace rocwmma { - using detail::index_sequence; - using detail::make_index_sequence; - template ROCWMMA_HOST_DEVICE inline constexpr non_native_vector_base operator+(non_native_vector_base const& x, U y) noexcept @@ -98,64 +97,63 @@ namespace rocwmma namespace detail { - template + template constexpr static auto copy_impl(VecT&& t, index_sequence&&) { - return make_vector(std::get(std::forward(t))...); + return make_vector(get(forward(t))...); } } template constexpr static auto pop_right(VecT&& t) { - return detail::copy_impl(std::forward(t), - make_index_sequence>::size() - 1>{}); + return detail::copy_impl(forward(t), + make_index_sequence>::size() - 1>{}); } template constexpr static auto pop_left(VecT&& t) { auto pop_front = [](auto front, auto... 
rest) { return make_vector(rest...); }; - return apply(pop_front, std::forward(t)); + return apply(pop_front, forward(t)); } template constexpr static decltype(auto) get_first(VecT&& t) { - return std::get<0>(std::forward(t)); + return get<0>(forward(t)); } template constexpr static decltype(auto) get_last(VecT&& t) { - return std::get>::size() - 1u>(std::forward(t)); + return get>::size() - 1u>(forward(t)); } namespace detail { - template + template constexpr static decltype(auto) reverse_impl(VecT&& t, index_sequence) { - return make_vector( - std::get(std::forward(t))...); + return make_vector(get(forward(t))...); } } template constexpr static decltype(auto) reverse(VecT&& t) { - return detail::reverse_impl(std::forward(t), - make_index_sequence>::size()>{}); + return detail::reverse_impl(forward(t), + make_index_sequence>::size()>{}); } namespace detail { - template + template constexpr static decltype(auto) flatten_coord_right_impl(Vec0&& coord, Vec1&& dims, index_sequence) { - static_assert(VecTraits>::size() == sizeof...(Indices) - && VecTraits>::size() == sizeof...(Indices), + static_assert(VecTraits>::size() == sizeof...(Indices) + && VecTraits>::size() == sizeof...(Indices), "coord and dims vectors must be the same size"); auto flatten = [](auto&& c, auto&& d, auto& mul) { @@ -164,10 +162,10 @@ namespace rocwmma return result; }; - auto mult = typename VecTraits>::DataT{1}; - return (flatten(std::get(std::forward(coord)), - std::get(std::forward(dims)), - std::forward(mult)) + auto mult = typename VecTraits>::DataT{1}; + return (flatten(get(forward(coord)), + get(forward(dims)), + forward(mult)) + ...); } } @@ -176,19 +174,19 @@ namespace rocwmma constexpr static decltype(auto) flatten_coord_right(Vec0&& coord, Vec1&& dims) { return detail::flatten_coord_right_impl( - std::forward(coord), - std::forward(dims), - make_index_sequence>::size()>{}); + forward(coord), + forward(dims), + make_index_sequence>::size()>{}); } namespace detail { - template + 
template constexpr static decltype(auto) flatten_coord_left_impl(Vec0&& coord, Vec1&& dims, index_sequence) { - static_assert(VecTraits>::size() == sizeof...(Indices) - && VecTraits>::size() == sizeof...(Indices), + static_assert(VecTraits>::size() == sizeof...(Indices) + && VecTraits>::size() == sizeof...(Indices), "coord and dims vectors must be the same size"); auto flatten = [](auto&& c, auto&& d, auto& mul) { @@ -197,10 +195,10 @@ namespace rocwmma return result; }; - auto mult = typename VecTraits>::DataT{1}; - return (flatten(std::get(std::forward(coord)), - std::get(std::forward(dims)), - std::forward(mult)) + auto mult = typename VecTraits>::DataT{1}; + return (flatten(get(forward(coord)), + get(forward(dims)), + forward(mult)) + ...); } } @@ -209,14 +207,14 @@ namespace rocwmma constexpr static decltype(auto) flatten_coord_left(Vec0&& coord, Vec1&& dims) { return detail::flatten_coord_left_impl( - std::forward(coord), - std::forward(dims), - make_index_sequence>::size()>{}); + forward(coord), + forward(dims), + make_index_sequence>::size()>{}); } namespace detail { - template + template constexpr static inline decltype(auto) inflate_coord_right_impl(Coord1d&& flatCoord, VecT&& dims, index_sequence) { @@ -226,10 +224,10 @@ namespace rocwmma return result; }; - auto div = std::decay_t{1}; - return make_vector(inflate(std::forward(flatCoord), - std::get(std::forward(dims)), - std::forward(div), + auto div = decay_t{1}; + return make_vector(inflate(forward(flatCoord), + get(forward(dims)), + forward(div), Indices == sizeof...(Indices) - 1)...); } } @@ -238,14 +236,14 @@ namespace rocwmma constexpr static inline decltype(auto) inflate_coord_right(Coord1d&& flatCoord, VecT&& dims) { return detail::inflate_coord_right_impl( - std::forward(flatCoord), - std::forward(dims), - make_index_sequence>::size()>{}); + forward(flatCoord), + forward(dims), + make_index_sequence>::size()>{}); } namespace detail { - template + template constexpr static inline decltype(auto) 
inflate_coord_left_impl(Coord1d&& flatCoord, VecT&& dims, index_sequence) { @@ -255,13 +253,12 @@ namespace rocwmma return result; }; - auto div = std::decay_t{1}; - return reverse( - make_vector(inflate(std::forward(flatCoord), - std::get>::size() - 1 - Indices>( - std::forward(dims)), - std::forward(div), - Indices == sizeof...(Indices) - 1)...)); + auto div = decay_t{1}; + return reverse(make_vector( + inflate(forward(flatCoord), + get>::size() - 1 - Indices>(forward(dims)), + forward(div), + Indices == sizeof...(Indices) - 1)...)); } } @@ -269,25 +266,25 @@ namespace rocwmma constexpr static inline decltype(auto) inflate_coord_left(Coord1d&& flatCoord, VecT&& dims) { return detail::inflate_coord_left_impl( - std::forward(flatCoord), - std::forward(dims), - make_index_sequence>::size()>{}); + forward(flatCoord), + forward(dims), + make_index_sequence>::size()>{}); } namespace detail { - template + template constexpr static inline decltype(auto) to_matrix_space_impl(Vec0&& strides, Vec1&& strideSpace, index_sequence) { - static_assert(VecTraits>::size() == sizeof...(Indices) - && VecTraits>::size() == sizeof...(Indices), + static_assert(VecTraits>::size() == sizeof...(Indices) + && VecTraits>::size() == sizeof...(Indices), "strides and strideSpace vectors must be the same size"); auto inflate = [](auto&& stride, auto&& dim) { return stride * dim; }; - return (inflate(std::get(std::forward(strides)), - std::get(std::forward(strideSpace))) + return (inflate(get(forward(strides)), + get(forward(strideSpace))) + ...); } } @@ -296,25 +293,25 @@ namespace rocwmma constexpr static inline decltype(auto) to_matrix_space(Vec0&& strides, Vec1&& strideSpace) { return detail::to_matrix_space_impl( - std::forward(strides), - std::forward(strideSpace), - make_index_sequence>::size()>{}); + forward(strides), + forward(strideSpace), + make_index_sequence>::size()>{}); } #if !defined(__HIPCC_RTC__) template - auto& print(std::ostream& os, T&& t, std::index_sequence&&) + auto& 
print(std::ostream& os, T&& t, index_sequence&&) { os << "("; - (..., (os << (I == 0 ? "" : ", ") << std::get(std::forward(t)))); + (..., (os << (I == 0 ? "" : ", ") << get(forward(t)))); return os << ")\n"; } template auto& print(std::ostream& os, std::tuple const& t) { - return print(os, t, std::make_index_sequence()); + return print(os, t, make_index_sequence()); } #endif // !defined(__HIPCC_RTC__) diff --git a/library/include/rocwmma/internal/type_traits.hpp b/library/include/rocwmma/internal/type_traits.hpp index 8f1e7f08..404cea52 100644 --- a/library/include/rocwmma/internal/type_traits.hpp +++ b/library/include/rocwmma/internal/type_traits.hpp @@ -124,356 +124,22 @@ namespace rocwmma } // namespace detail } // namespace rocwmma -/////////////////////////////////////////////////////////// -///////////// std replacements for hipRTC /////////////// -/////////////////////////////////////////////////////////// -#if defined(__HIPCC_RTC__) -namespace std -{ - template - class numeric_limits - { - public: - ROCWMMA_HOST_DEVICE static constexpr T min() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T lowest() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T max() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T epsilon() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T round_error() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T infinity() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T quiet_NaN() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T signaling_NaN() noexcept; - ROCWMMA_HOST_DEVICE static constexpr T denorm_min() noexcept; - }; - - template - using enable_if_t = typename enable_if::type; - - template - struct conditional - { - }; - - template - struct conditional - { - using type = T; - }; - - template - struct conditional - { - using type = F; - }; - - template - using conditional_t = typename conditional::type; - - template - ROCWMMA_HOST_DEVICE constexpr const T& max(const T& a, const T& b) - { - return (a < b) ? 
b : a; - } - - template - ROCWMMA_HOST_DEVICE constexpr const T& min(const T& a, const T& b) - { - return (a < b) ? a : b; - } - - // Meta programming helper types. - - template - struct conditional; - - template - struct __or_; - - template <> - struct __or_<> : public false_type - { - }; - - template - struct __or_<_B1> : public _B1 - { - }; - - template - struct __or_<_B1, _B2> : public conditional<_B1::value, _B1, _B2>::type - { - }; - - template - struct __or_<_B1, _B2, _B3, _Bn...> - : public conditional<_B1::value, _B1, __or_<_B2, _B3, _Bn...>>::type - { - }; - - template - struct __and_; - - template <> - struct __and_<> : public true_type - { - }; - - template - struct __and_<_B1> : public _B1 - { - }; - - template - struct __and_<_B1, _B2> : public conditional<_B1::value, _B2, _B1>::type - { - }; - - template - struct __and_<_B1, _B2, _B3, _Bn...> - : public conditional<_B1::value, __and_<_B2, _B3, _Bn...>, _B1>::type - { - }; - - template - using __bool_constant = integral_constant; - - template - struct __not_ : public __bool_constant - { - }; - - // remove_reference - template - struct remove_reference - { - typedef T type; - }; - - template - struct remove_reference - { - typedef T type; - }; - - template - struct remove_reference - { - typedef T type; - }; - - // is_lvalue_reference - template - struct is_lvalue_reference : public false_type - { - }; - - template - struct is_lvalue_reference : public true_type - { - }; - - // is_rvalue_reference - template - struct is_rvalue_reference : public false_type - { - }; - - template - struct is_rvalue_reference : public true_type - { - }; - - // lvalue forwarding - template - constexpr T&& forward(typename remove_reference::type& __t) noexcept - { - return static_cast(__t); - } - - // rvalue forwarding - template - constexpr T&& forward(typename remove_reference::type&& __t) noexcept - { - static_assert(!is_lvalue_reference::value, - "template argument" - " substituting T is an lvalue reference type"); - 
return static_cast(__t); - } - - // remove_const - template - struct remove_const - { - typedef T type; - }; - - template - struct remove_const - { - typedef T type; - }; - - // remove_volatile - template - struct remove_volatile - { - typedef T type; - }; - - template - struct remove_volatile - { - typedef T type; - }; - - // remove_cv - template - struct remove_cv - { - typedef typename remove_const::type>::type type; - }; - - // remove_extent - template - struct remove_extent - { - typedef T type; - }; - - template - struct remove_extent - { - typedef T type; - }; - - template - struct remove_extent - { - typedef T type; - }; - - // is_void - template - struct __is_void_helper : public false_type - { - }; - - template <> - struct __is_void_helper : public true_type - { - }; - - template - struct is_void : public __is_void_helper::type>::type - { - }; - - // is_reference - template - struct is_reference : public __or_, is_rvalue_reference>::type - { - }; - - // is_function - template - struct is_function : public false_type - { - }; - - // is_object - template - struct is_object : public __not_<__or_, is_reference, is_void>>::type - { - }; - - // __is_referenceable - template - struct __is_referenceable : public __or_, is_reference>::type{}; - - // add_pointer - template , is_void>::value> - struct __add_pointer_helper - { - typedef T type; - }; - - template - struct __add_pointer_helper - { - typedef typename remove_reference::type* type; - }; - - template - struct add_pointer : public __add_pointer_helper - { - }; - - // is_array - template - struct is_array : public false_type - { - }; - - template - struct is_array : public true_type - { - }; - - template - struct is_array : public true_type - { - }; - - // decay selectors - template ::value, - bool _IsFunction = is_function<_Up>::value> - struct __decay_selector; +#include "utility/numeric_limits.hpp" - template - struct __decay_selector<_Up, false, false> - { - typedef typename remove_cv<_Up>::type __type; 
- }; - - template - struct __decay_selector<_Up, true, false> - { - typedef typename remove_extent<_Up>::type* __type; - }; - - template - struct __decay_selector<_Up, false, true> - { - typedef typename add_pointer<_Up>::type __type; - }; - - // decay - template - class decay - { - typedef typename remove_reference::type __remove_type; - - public: - typedef typename __decay_selector<__remove_type>::__type type; - }; - - template - using decay_t = typename decay::type; - - template - inline constexpr bool is_same_v = is_same::value; - -} // namespace std +#if defined(__HIPCC_RTC__) +#define NUMERIC_LIMITS_NAMESPACE rocwmma::detail +#else +#define NUMERIC_LIMITS_NAMESPACE std #endif -namespace std +namespace NUMERIC_LIMITS_NAMESPACE { #if defined(__HIPCC_RTC__) using uint16_t = rocwmma::uint16_t; #endif /////////////////////////////////////////////////////////// - /////////// std::numeric_limits ////////////// + /////////// numeric_limits ////////////// /////////////////////////////////////////////////////////// // @cond template <> @@ -533,7 +199,7 @@ namespace std } /////////////////////////////////////////////////////////// - /////////// std::numeric_limits ////////////// + /////////// numeric_limits ////////////// /////////////////////////////////////////////////////////// template <> @@ -593,7 +259,7 @@ namespace std } /////////////////////////////////////////////////////////// - /////////// std::numeric_limits ////////////// + /////////// numeric_limits ////////////// /////////////////////////////////////////////////////////// template <> @@ -653,7 +319,7 @@ namespace std } /////////////////////////////////////////////////////////// - /////////// std::numeric_limits ///////////// + /////////// numeric_limits ///////////// /////////////////////////////////////////////////////////// #if !ROCWMMA_NO_HALF template <> @@ -715,7 +381,7 @@ namespace std #endif // !ROCWMMA_NO_HALF /////////////////////////////////////////////////////////// - /////////// 
std::numeric_limits ///////////// + /////////// numeric_limits ///////////// /////////////////////////////////////////////////////////// template <> @@ -775,7 +441,7 @@ namespace std } /////////////////////////////////////////////////////////// - /////////// std::numeric_limits ////////////// + /////////// numeric_limits ////////////// /////////////////////////////////////////////////////////// template <> @@ -835,36 +501,32 @@ namespace std } // @endcond -} // namespace std +} // namespace NUMERIC_LIMITS_NAMESPACE namespace rocwmma { #if !defined(__HIPCC_RTC__) - template ::value, int> = 0> - constexpr auto maxExactInteger() -> decltype(std::numeric_limits::max()) + template ::value, int> = 0> + constexpr auto maxExactInteger() -> decltype(numeric_limits::max()) { - return std::numeric_limits::max(); + return numeric_limits::max(); } template ::value - && std::numeric_limits::digits, - int> - = 0> - constexpr auto maxExactInteger() -> - typename std::conditional_t::value, int64_t, int32_t> + enable_if_t::value && numeric_limits::digits, int> = 0> + constexpr auto maxExactInteger() + -> conditional_t::value, int64_t, int32_t> { - using RetT = - typename std::conditional_t::value, int64_t, int32_t>; - return ((RetT)1 << std::numeric_limits::digits); + using RetT = conditional_t::value, int64_t, int32_t>; + return ((RetT)1 << numeric_limits::digits); } template ::value || + is_same::value || #endif // !ROCWMMA_NO_HALF - std::is_same::value, + is_same::value, int> = 0> constexpr auto maxExactInteger() -> int32_t @@ -873,28 +535,28 @@ namespace rocwmma return ((int32_t)1 << 11); } - template ::value, int> = 0> + template ::value, int> = 0> constexpr auto maxExactInteger() -> int32_t { // b16 mantissa is 7 bits return ((int32_t)1 << 8); } - template ::value, int> = 0> + template ::value, int> = 0> constexpr auto maxExactInteger() -> int32_t { // f8 mantissa is 3 bits return ((int32_t)1 << 4); } - template ::value, int> = 0> + template ::value, int> = 0> constexpr auto maxExactInteger() 
-> int32_t { // bf8 mantissa is 2 bits return ((int32_t)1 << 3); } - template ::value, int> = 0> + template ::value, int> = 0> constexpr auto maxExactInteger() -> int32_t { // xf32 mantissa is 7 bits diff --git a/library/include/rocwmma/internal/types_ext.hpp b/library/include/rocwmma/internal/types_ext.hpp index 692d3ba8..e5d4771e 100644 --- a/library/include/rocwmma/internal/types_ext.hpp +++ b/library/include/rocwmma/internal/types_ext.hpp @@ -48,11 +48,11 @@ namespace rocwmma //////////////////////////////////////////////////////////////////////// template , int> = 0> + enable_if_t, int> = 0> __host__ __device__ inline Outgoing convert(const Incoming& value) { #if !ROCWMMA_NO_HALF - if constexpr(std::is_same_v) + if constexpr(is_same_v) { #if defined(__HIP_NO_HALF_CONVERSIONS__) @@ -62,7 +62,7 @@ namespace rocwmma return static_cast(value); #endif // defined(__HIP_NO_HALF_CONVERSIONS__) } - else if constexpr(std::is_same_v) + else if constexpr(is_same_v) { #if defined(__HIP_NO_HALF_CONVERSIONS__) @@ -81,7 +81,7 @@ namespace rocwmma template , int> = 0> + enable_if_t, int> = 0> __host__ __device__ inline Outgoing const& convert(const Incoming& value) { return value; @@ -105,8 +105,8 @@ namespace rocwmma { auto absDiff = std::fabs(__half2float(x) - __half2float(y)); auto absAdd = std::fabs(__half2float(x) + __half2float(y)); - return absDiff <= __half2float(std::numeric_limits::epsilon()) * absAdd * 2.0f - || absDiff < __half2float(std::numeric_limits::min()); + return absDiff <= __half2float(numeric_limits::epsilon()) * absAdd * 2.0f + || absDiff < __half2float(numeric_limits::min()); } ROCWMMA_HALF_OP_ATTR inline bool operator!=(const hfloat16_t& x, const hfloat16_t& y) diff --git a/library/include/rocwmma/internal/utility/forward.hpp b/library/include/rocwmma/internal/utility/forward.hpp new file mode 100644 index 00000000..7dce49ef --- /dev/null +++ b/library/include/rocwmma/internal/utility/forward.hpp @@ -0,0 +1,52 @@ 
+/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_FORWARD_HPP +#define ROCWMMA_UTILITY_FORWARD_HPP + +#if defined(__HIPCC_RTC__) || defined(__clang__) + +#include "forward_impl.hpp" +namespace rocwmma +{ + // Use drop-in replacement + using detail::forward; + +} // namespace rocwmma + +#else + +#include +namespace rocwmma +{ + // Use STL + using std::forward; + +} // namespace rocwmma + +#endif // defined(__HIPCC_RTC__) || defined(__clang__) + +#endif // ROCWMMA_UTILITY_FORWARD_HPP diff --git a/library/include/rocwmma/internal/utility/forward_impl.hpp b/library/include/rocwmma/internal/utility/forward_impl.hpp new file mode 100644 index 00000000..59d69b02 --- /dev/null +++ b/library/include/rocwmma/internal/utility/forward_impl.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_FORWARD_IMPL_HPP +#define ROCWMMA_UTILITY_FORWARD_IMPL_HPP + +#include "type_traits.hpp" + +namespace rocwmma +{ + namespace detail + { + + template + ROCWMMA_HOST_DEVICE constexpr T&& forward(typename remove_reference::type& t) noexcept + { + return static_cast(t); + } + + template + ROCWMMA_HOST_DEVICE constexpr T&& forward(typename remove_reference::type&& t) noexcept + { + static_assert(!is_lvalue_reference::value, + "template argument substituting T is an lvalue reference type"); + return static_cast(t); + } + + } // namespace detail + +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_FORWARD_IMPL_HPP diff --git a/library/include/rocwmma/internal/utility/get.hpp b/library/include/rocwmma/internal/utility/get.hpp new file mode 100644 index 00000000..773794a0 --- /dev/null +++ b/library/include/rocwmma/internal/utility/get.hpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_GET_HPP +#define ROCWMMA_UTILITY_GET_HPP + +#include "get_impl.hpp" + +namespace rocwmma +{ + // get overloads + using detail::get; +} + +#if !defined(__HIPCC_RTC__) + +#include +namespace rocwmma +{ + // Use STL + using std::get; + +} // namespace rocwmma + +#endif // !defined(__HIPCC_RTC__) + +#endif // ROCWMMA_UTILITY_GET_HPP diff --git a/library/include/rocwmma/internal/utility/get_impl.hpp b/library/include/rocwmma/internal/utility/get_impl.hpp new file mode 100644 index 00000000..88e833cb --- /dev/null +++ b/library/include/rocwmma/internal/utility/get_impl.hpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_GET_IMPL_HPP +#define ROCWMMA_UTILITY_GET_IMPL_HPP + +#include "../vector.hpp" + +namespace rocwmma +{ + namespace detail + { + // HIP_vector_type overloads + template + ROCWMMA_HOST_DEVICE constexpr inline DataT get(HIP_vector_type&& v) + { + return v.data[Idx]; + } + + template + ROCWMMA_HOST_DEVICE constexpr inline DataT& get(HIP_vector_type& v) + { + return reinterpret_cast(&v.data)[Idx]; + } + + template + ROCWMMA_HOST_DEVICE constexpr inline DataT get(HIP_vector_type const& v) + { + return v.data[Idx]; + } + + // non_native_vector_base overloads + template + ROCWMMA_HOST_DEVICE constexpr static inline DataT + get(non_native_vector_base&& v) + { + return v[Idx]; + } + + template + ROCWMMA_HOST_DEVICE constexpr static inline DataT& + get(non_native_vector_base& v) + { + return v[Idx]; + } + + template + ROCWMMA_HOST_DEVICE constexpr static inline DataT + get(non_native_vector_base const& v) + { + return v[Idx]; + } + + } // namespace detail + +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_GET_IMPL_HPP diff --git a/library/include/rocwmma/internal/utility/numeric_limits.hpp b/library/include/rocwmma/internal/utility/numeric_limits.hpp new file mode 100644 index 00000000..76a96c08 --- /dev/null +++ b/library/include/rocwmma/internal/utility/numeric_limits.hpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_NUMERIC_LIMITS_HPP +#define ROCWMMA_UTILITY_NUMERIC_LIMITS_HPP + +#if defined(__HIPCC_RTC__) + +#include "numeric_limits_impl.hpp" +namespace rocwmma +{ + // Use drop-in replacement + using detail::numeric_limits; + +} // namespace rocwmma + +#else + +#include <limits> +namespace rocwmma +{ + // Use STL + using std::numeric_limits; + +} // namespace rocwmma + +#endif // defined(__HIPCC_RTC__) + +#endif // ROCWMMA_UTILITY_NUMERIC_LIMITS_HPP diff --git a/library/include/rocwmma/internal/utility/numeric_limits_impl.hpp b/library/include/rocwmma/internal/utility/numeric_limits_impl.hpp new file mode 100644 index 00000000..0349bf14 --- /dev/null +++ b/library/include/rocwmma/internal/utility/numeric_limits_impl.hpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_NUMERIC_LIMITS_IMPL_HPP +#define ROCWMMA_UTILITY_NUMERIC_LIMITS_IMPL_HPP + +namespace rocwmma +{ + namespace detail + { + // Currently has no implementation, as the library has no current + // need for regular arithmetic types. + // Specializations do exist for f8, bf8 and xf32 types where they + // are currently defined. + template + class numeric_limits + { + public: + ROCWMMA_HOST_DEVICE static constexpr T min() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T lowest() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T max() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T epsilon() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T round_error() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T infinity() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T quiet_NaN() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T signaling_NaN() noexcept; + ROCWMMA_HOST_DEVICE static constexpr T denorm_min() noexcept; + }; + + } // namespace detail + +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_NUMERIC_LIMITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/utility/sequence.hpp b/library/include/rocwmma/internal/utility/sequence.hpp new file mode 100644 index 00000000..de5b26e2 --- /dev/null +++ b/library/include/rocwmma/internal/utility/sequence.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_SEQUENCE_HPP +#define ROCWMMA_UTILITY_SEQUENCE_HPP + +#include "sequence_impl.hpp" +namespace rocwmma +{ + using detail::index_sequence; + using detail::make_index_sequence; +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_SEQUENCE_HPP diff --git a/library/include/rocwmma/internal/utility/sequence_impl.hpp b/library/include/rocwmma/internal/utility/sequence_impl.hpp new file mode 100644 index 00000000..10b16460 --- /dev/null +++ b/library/include/rocwmma/internal/utility/sequence_impl.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_SEQUENCE_IMPL_HPP +#define ROCWMMA_UTILITY_SEQUENCE_IMPL_HPP + +#include "type_traits.hpp" + +namespace rocwmma +{ + namespace detail + { + template + struct integer_sequence + { + using value_type = Int; + constexpr integer_sequence() {} + static constexpr size_t size() noexcept + { + return sizeof...(Ints); + } + }; + + template + using index_sequence = integer_sequence; + + namespace + { + // Merge two integer sequences, adding an offset to the right-hand side. 
+ template + struct merge; + + template + struct merge, + integer_sequence, + integer_sequence> + { + using type = integer_sequence; + }; + + template + struct log_make_sequence + { + using L = integral_constant; + using R = integral_constant; + using type = typename merge::type, + typename log_make_sequence::type>::type; + }; + + // An empty sequence. + template + struct log_make_sequence> + { + using type = integer_sequence; + }; + + // A single-element sequence. + template + struct log_make_sequence> + { + using type = integer_sequence; + }; + } + + template + using make_integer_sequence = + typename log_make_sequence>::type; + + template + using make_index_sequence = make_integer_sequence; + } // namespace detail +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_SEQUENCE_IMPL_HPP diff --git a/library/include/rocwmma/internal/utility/type_traits.hpp b/library/include/rocwmma/internal/utility/type_traits.hpp new file mode 100644 index 00000000..fdeb5f9d --- /dev/null +++ b/library/include/rocwmma/internal/utility/type_traits.hpp @@ -0,0 +1,148 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_TYPE_TRAITS_HPP +#define ROCWMMA_UTILITY_TYPE_TRAITS_HPP + +#if defined(__HIPCC_RTC__) + +#include "type_traits_impl.hpp" +namespace rocwmma +{ + // Use drop-in replacement + using detail::add_pointer; + using detail::add_pointer_t; + using detail::bool_constant; + using detail::conditional; + using detail::conditional_t; + using detail::decay; + using detail::decay_t; + using detail::enable_if; + using detail::enable_if_t; + using detail::false_type; + using detail::integral_constant; + using detail::is_arithmetic; + using detail::is_arithmetic_v; + using detail::is_array; + using detail::is_array_v; + using detail::is_convertible; + using detail::is_convertible_v; + using detail::is_floating_point; + using detail::is_floating_point_v; + using detail::is_function; + using detail::is_function_v; + using detail::is_integral; + using detail::is_integral_v; + using detail::is_lvalue_reference; + using detail::is_lvalue_reference_v; + using detail::is_reference; + using detail::is_reference_v; + using detail::is_rvalue_reference; + using detail::is_rvalue_reference_v; + using detail::is_same; + using detail::is_same_v; + using detail::is_signed; + using detail::is_signed_v; + using detail::is_void; + using detail::is_void_v; + using detail::remove_const; + using detail::remove_const_t; + using detail::remove_cv; + using detail::remove_cv_t; + using detail::remove_extent; + 
using detail::remove_extent_t; + using detail::remove_reference; + using detail::remove_reference_t; + using detail::remove_volatile; + using detail::remove_volatile_t; + using detail::true_type; + + using detail::max; + using detail::min; + +} // namespace rocwmma + +#else + +#include <type_traits> +namespace rocwmma +{ + // std implementations + using std::add_pointer; + using std::add_pointer_t; + using std::bool_constant; + using std::conditional; + using std::conditional_t; + using std::decay; + using std::decay_t; + using std::enable_if; + using std::enable_if_t; + using std::false_type; + using std::integral_constant; + using std::is_arithmetic; + using std::is_arithmetic_v; + using std::is_array; + using std::is_array_v; + using std::is_convertible; + using std::is_convertible_v; + using std::is_floating_point; + using std::is_floating_point_v; + using std::is_function; + using std::is_function_v; + using std::is_integral; + using std::is_integral_v; + using std::is_lvalue_reference; + using std::is_lvalue_reference_v; + using std::is_reference; + using std::is_reference_v; + using std::is_rvalue_reference; + using std::is_rvalue_reference_v; + using std::is_same; + using std::is_same_v; + using std::is_signed; + using std::is_signed_v; + using std::is_void; + using std::is_void_v; + using std::remove_const; + using std::remove_const_t; + using std::remove_cv; + using std::remove_cv_t; + using std::remove_extent; + using std::remove_extent_t; + using std::remove_reference; + using std::remove_reference_t; + using std::remove_volatile; + using std::remove_volatile_t; + using std::true_type; + + using std::max; + using std::min; + +} // namespace rocwmma + +#endif // defined(__HIPCC_RTC__) + +#endif // ROCWMMA_UTILITY_TYPE_TRAITS_HPP diff --git a/library/include/rocwmma/internal/utility/type_traits_impl.hpp b/library/include/rocwmma/internal/utility/type_traits_impl.hpp new file mode 100644 index 00000000..02b7a426 --- /dev/null +++ 
b/library/include/rocwmma/internal/utility/type_traits_impl.hpp @@ -0,0 +1,631 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_TYPE_TRAITS_IMPL_HPP +#define ROCWMMA_UTILITY_TYPE_TRAITS_IMPL_HPP + +namespace rocwmma +{ + namespace detail + { + // TODO: Separate file? + template + ROCWMMA_HOST_DEVICE constexpr const T& max(const T& a, const T& b) + { + return (a < b) ? b : a; + } + + template + ROCWMMA_HOST_DEVICE constexpr const T& min(const T& a, const T& b) + { + return (a < b) ? 
a : b; + } + + using ::size_t; + + template + struct integral_constant + { + static constexpr const T value = Val; + using value_type = T; + using type = integral_constant; + constexpr operator value_type() const + { + return value; + } + constexpr value_type operator()() const + { + return value; + } + }; + + template + constexpr const T integral_constant::value; + + using true_type = integral_constant; + using false_type = integral_constant; + + template + using bool_constant = integral_constant; + + using true_type = bool_constant; + using false_type = bool_constant; + + template + struct true_or_false_type : public false_type + { + }; + template <> + struct true_or_false_type : public true_type + { + }; + + // Static conditional + template + struct conditional + { + }; + + template + struct conditional + { + using type = T; + }; + + template + struct conditional + { + using type = F; + }; + + template + using conditional_t = typename conditional::type; + + // Logical ops + template + struct logical_or; + + template <> + struct logical_or<> : public false_type + { + }; + + template + struct logical_or : public T + { + }; + + template + struct logical_or : public conditional_t + { + }; + + template + struct logical_or + : public conditional_t> + { + }; + + template + using logical_or_t = typename logical_or::type; + + template + struct logical_and; + + template <> + struct logical_and<> : public true_type + { + }; + + template + struct logical_and : public B1 + { + }; + + template + struct logical_and : public conditional_t + { + }; + + template + struct logical_and + : public conditional_t, B1> + { + }; + + template + using logical_and_t = typename logical_and::type; + + template + struct logical_not : public bool_constant + { + }; + + template + using logical_not_t = typename logical_not::type; + + // remove_reference + template + struct remove_reference + { + using type = T; + }; + + template + struct remove_reference + { + using type = T; + }; + + template + 
struct remove_reference + { + using type = T; + }; + + template + using remove_reference_t = typename remove_reference::type; + + // remove_const + template + struct remove_const + { + using type = T; + }; + + template + struct remove_const + { + using type = T; + }; + + template + using remove_const_t = typename remove_const::type; + + // remove_volatile + template + struct remove_volatile + { + using type = T; + }; + + template + struct remove_volatile + { + using type = T; + }; + + template + using remove_volatile_t = typename remove_volatile::type; + + // remove_cv + template + struct remove_cv + { + using type = remove_const_t>; + }; + + template + using remove_cv_t = typename remove_cv::type; + + // remove_extent + template + struct remove_extent + { + using type = T; + }; + + template + struct remove_extent + { + using type = T; + }; + + template + struct remove_extent + { + using type = T; + }; + + template + using remove_extent_t = typename remove_extent::type; + + // add_pointer + template + struct is_referenceable; + + template + struct is_void; + + template , is_void>::value> + struct add_pointer_helper + { + using type = T; + }; + + template + struct add_pointer_helper + { + using type = remove_reference_t*; + }; + + template + struct add_pointer : public add_pointer_helper + { + }; + + template + using add_pointer_t = typename add_pointer::type; + + // is_lvalue_reference + template + struct is_lvalue_reference : public false_type + { + }; + + template + struct is_lvalue_reference : public true_type + { + }; + + template + inline constexpr bool is_lvalue_reference_v = is_lvalue_reference::value; + + // is_rvalue_reference + template + struct is_rvalue_reference : public false_type + { + }; + + template + struct is_rvalue_reference : public true_type + { + }; + + template + inline constexpr bool is_rvalue_reference_v = is_rvalue_reference::value; + + // is_void + template + struct is_void_helper : public false_type + { + }; + + template <> + struct 
is_void_helper : public true_type + { + }; + + template + struct is_void : public is_void_helper>::type + { + }; + + template + inline constexpr bool is_void_v = is_void::value; + + // is_reference + template + struct is_reference : public logical_or_t, is_rvalue_reference> + { + }; + + template + inline constexpr bool is_reference_v = is_reference::value; + + // is_function + template + struct is_function : public false_type + { + }; + + template + inline constexpr bool is_function_v = is_function::value; + + // is_object + template + struct is_object + : public logical_not_t, is_reference, is_void>> + { + }; + + template + inline constexpr bool is_object_v = is_object::value; + + // __is_referenceable + template + struct is_referenceable : public logical_or_t, is_reference> + { + }; + + template + inline constexpr bool is_referenceable_v = is_referenceable::value; + + // is_array + template + struct is_array : public false_type + { + }; + + template + struct is_array : public true_type + { + }; + + template + struct is_array : public true_type + { + }; + + template + inline constexpr bool is_array_v = is_array::value; + + // is_integral + template + struct is_integral : public false_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct is_integral : public true_type + { + }; + template <> + struct 
is_integral : public true_type + { + }; + + template + inline constexpr bool is_integral_v = is_integral::value; + + // is_arithmetic + template + struct is_arithmetic : public false_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + template <> + struct is_arithmetic : public true_type + { + }; + + template + inline constexpr bool is_arithmetic_v = is_arithmetic::value; + + // is_floating_point + template + struct is_floating_point : public false_type + { + }; + template <> + struct is_floating_point : public true_type + { + }; + template <> + struct is_floating_point : public true_type + { + }; + template <> + struct is_floating_point : public true_type + { + }; + + template + inline constexpr bool is_floating_point_v = is_floating_point::value; + + // is_signed + template ::value> + struct is_signed : public false_type + { + }; + + template + struct is_signed : public true_or_false_type + { + }; + + template + inline constexpr bool is_signed_v = is_signed::value; + + // is_same + template + struct is_same : public false_type + { + }; + template + struct is_same : public true_type + { + }; + + template + 
inline constexpr bool is_same_v = is_same::value; + + // is_convertible + template + struct is_convertible : public true_or_false_type<__is_convertible_to(T1, T2)> + { + }; + + template + inline constexpr bool is_convertible_v = is_convertible::value; + + // decay selectors + template ::value, + bool IsFunction = is_function::value> + struct decay_selector; + + template + struct decay_selector + { + using type = remove_cv_t; + }; + + template + struct decay_selector + { + using type = remove_extent_t*; + }; + + template + struct decay_selector + { + using type = add_pointer_t; + }; + + template + using decay_selector_t = typename decay_selector::type; + + // decay + template + class decay + { + using remove_type = remove_reference_t; + + public: + using type = decay_selector_t; + }; + + template + using decay_t = typename decay::type; + + // SFINAE enable_if + template + struct enable_if + { + }; + template + struct enable_if + { + using type = T; + }; + + template + using enable_if_t = typename enable_if::type; + + } // namespace detail + +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_TYPE_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/utility/vector.hpp b/library/include/rocwmma/internal/utility/vector.hpp new file mode 100644 index 00000000..060c8d59 --- /dev/null +++ b/library/include/rocwmma/internal/utility/vector.hpp @@ -0,0 +1,51 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_VECTOR_HPP +#define ROCWMMA_UTILITY_VECTOR_HPP + +#include "vector_impl.hpp" + +namespace rocwmma +{ + template + ROCWMMA_HOST_DEVICE constexpr inline auto vector_size(VecT const& v); + + template + ROCWMMA_HOST_DEVICE constexpr decltype(auto) make_vector(Ts&&... 
ts); + + template + ROCWMMA_HOST_DEVICE constexpr decltype(auto) vector_cat(Lhs&& lhs, Rhs&& rhs); + + template + ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) + vector_reduce_and(VecT&& lhs) noexcept; + + template + ROCWMMA_HOST_DEVICE constexpr inline auto swap(HIP_vector_type const& v); +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_VECTOR_HPP diff --git a/library/include/rocwmma/internal/utility/vector_impl.hpp b/library/include/rocwmma/internal/utility/vector_impl.hpp new file mode 100644 index 00000000..6e73289d --- /dev/null +++ b/library/include/rocwmma/internal/utility/vector_impl.hpp @@ -0,0 +1,223 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef ROCWMMA_UTILITY_VECTOR_IMPL_HPP +#define ROCWMMA_UTILITY_VECTOR_IMPL_HPP + +#include "type_traits.hpp" +#include "get.hpp" + +namespace rocwmma +{ + template + ROCWMMA_HOST_DEVICE constexpr inline auto vector_size(VecT const& v) + { + return VecTraits::size(); + } + + namespace detail + { + template + struct first_type; + + template + struct first_type + { + using type = T; + }; + + template + using first_type_t = typename first_type::type; + + template + struct is_same_type; + + template + struct is_same_type : true_type + { + }; + + template + struct is_same_type + : conditional_t{}, is_same_type, false_type> + { + }; + + template + constexpr bool is_same_type_v = is_same_type::value; + } + + template + ROCWMMA_HOST_DEVICE constexpr inline auto swap(HIP_vector_type const& v) + { + return HIP_vector_type{get<1>(v), get<0>(v)}; + } + + // temporary apply impl + namespace detail + { + template + constexpr decltype(auto) + apply_impl(F fn, HIP_vector_type const& v, index_sequence) + { + return fn(get(v)...); + } + + } // namespace detail + + template + constexpr decltype(auto) apply(F fn, HIP_vector_type& v) + { + constexpr size_t size = VecTraits>::size(); + return detail::apply_impl(fn, v, detail::make_index_sequence()); + } + + namespace detail + { + template + constexpr decltype(auto) + apply_impl(F fn, non_native_vector_base const& v, index_sequence) + { + return fn(get(v)...); + } + + } // namespace detail + + template + constexpr decltype(auto) apply(F fn, non_native_vector_base const& v) + { + constexpr size_t size = VecTraits>::size(); + return detail::apply_impl(fn, v, detail::make_index_sequence()); + } + + template + ROCWMMA_HOST_DEVICE constexpr decltype(auto) make_vector(Ts&&... ts) + { + // TODO: When HIP_vector_type becomes constexpr replace with non_native_vector type. 
+ + // Ensure that all the arguments are the same type + static_assert(detail::is_same_type_v...>, + "Vector arguments must all be the same type"); + + using DataT = typename detail::first_type_t...>; + return non_native_vector_base{forward(ts)...}; + } + + namespace detail + { + template + constexpr static inline decltype(auto) + vector_cat_impl(non_native_vector_base const& lhs, + index_sequence, + non_native_vector_base const& rhs, + index_sequence) + { + return make_vector(get(lhs)..., get(rhs)...); + } + + } // namespace detail + + template + ROCWMMA_HOST_DEVICE constexpr decltype(auto) vector_cat(Lhs&& lhs, Rhs&& rhs) + { + constexpr size_t Size0 = VecTraits>::size(); + constexpr size_t Size1 = VecTraits>::size(); + + return detail::vector_cat_impl(forward(lhs), + detail::make_index_sequence(), + forward(rhs), + detail::make_index_sequence()); + } + + namespace detail + { + template + constexpr static inline decltype(auto) + mult_poly_vec_impl(non_native_vector_base const& lhs, + non_native_vector_base const& rhs, + index_sequence) + { + return make_vector((get(lhs) * get(rhs))...); + } + + } // namespace detail + + template + constexpr decltype(auto) operator*(non_native_vector_base const& lhs, + non_native_vector_base const& rhs) + { + return detail::mult_poly_vec_impl(lhs, rhs, detail::make_index_sequence()); + } + + namespace detail + { + template + ROCWMMA_HOST_DEVICE constexpr static inline decay_t reduceOp_impl(T&& t, + Ts&&... 
ts) noexcept + { + using CastT = decay_t; + if constexpr(sizeof...(Ts) >= 1) + { + return BinOp::exec(static_cast(t), reduceOp_impl(forward(ts)...)); + } + else + { + return static_cast(t); + } + } + + template + ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) + vector_reduce_impl(VecT&& v, index_sequence) noexcept + { + return reduceOp_impl(get(forward(v))...); + } + + // Use with operations that have 1 operands + template + ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) + vector_reduce(VecT&& lhs) noexcept + { + return vector_reduce_impl( + forward(lhs), + detail::make_index_sequence>::size()>{}); + } + } + + template + ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) + vector_reduce_and(VecT&& lhs) noexcept + { + return detail::vector_reduce(forward(lhs)); + } +} // namespace rocwmma + +#endif // ROCWMMA_UTILITY_VECTOR_IMPL_HPP diff --git a/library/include/rocwmma/internal/utils.hpp b/library/include/rocwmma/internal/utils.hpp index 51edddba..c06e874a 100644 --- a/library/include/rocwmma/internal/utils.hpp +++ b/library/include/rocwmma/internal/utils.hpp @@ -27,6 +27,8 @@ #define ROCWMMA_UTILS_HPP #include "types.hpp" + +#include "utility/get.hpp" #include "vector.hpp" namespace rocwmma @@ -37,137 +39,6 @@ namespace rocwmma /// Note: performs static unroll /// /////////////////////////////////////////////////////////////////// - namespace detail - { - template - ROCWMMA_DEVICE constexpr static inline auto extractEven(VecT const& v, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize / 2u, - "Index count must be half the vector size"); - return VecT{get(v)...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractOdd(VecT const& v, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize / 2u, - "Index count must be half the vector size"); - return VecT{get(v)...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractLo(VecT const& v, - detail::SeqT) - { - 
static_assert(sizeof...(Idx) == VecSize / 2u, - "Index count must be half the vector size"); - return VecT{get(v)...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractHi(VecT const& v, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize / 2u, - "Index count must be half the vector size"); - return VecT{get(v)...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto concat(VecT const& v0, - VecT const& v1, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize, "Index count must equal the vector size"); - return VecT{get(v0)..., get(v1)...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto zip(VecT const& v0, - VecT const& v1, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize, "Index count must equal the vector size"); - return VecT{((Idx % 2 == 0) ? get(v0) : get(v1))...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto unpackLo(VecT const& v0, - VecT const& v1, - detail::SeqT) - { - static_assert(sizeof...(Idx) == VecSize, "Index count must equal the vector size"); - return VecT{ - ((Idx % 2 == 0) ? get(v0) : get(v1))...}; - } - - template - ROCWMMA_DEVICE constexpr static inline auto unpackHi(VecT const& v0, - VecT const& v1, - detail::SeqT) - { - constexpr auto startIdx = VecSize / 2u; - static_assert(sizeof...(Idx) == VecSize, "Index count must equal the vector size"); - return VecT{ - ((Idx % 2 == 0) ? 
get(v0) : get(v1))...}; - } - - } // namespace detail - - template - ROCWMMA_DEVICE constexpr static inline auto extractEven(VecT const& v) - { - return detail::extractEven(v, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractLo(VecT const& v) - { - return detail::extractLo(v, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractHi(VecT const& v) - { - return detail::extractHi(v, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto extractOdd(VecT const& v) - { - return detail::extractOdd(v, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto concat(VecT const& v0, - VecT const& v1) - { - return detail::concat(v0, v1, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto zip(VecT const& v0, - VecT const& v1) - { - return detail::zip(v0, v1, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto unpackLo(VecT const& v0, - VecT const& v1) - { - return detail::unpackLo(v0, v1, detail::Seq{}); - } - - template - ROCWMMA_DEVICE constexpr static inline auto unpackHi(VecT const& v0, - VecT const& v1) - { - return detail::unpackHi(v0, v1, detail::Seq{}); - } - // Unary swap only considered in 2d vectors. template ROCWMMA_HOST_DEVICE constexpr static inline auto swap(non_native_vector_base const& v) @@ -222,7 +93,7 @@ namespace std template auto apply_impl(F fn, Tuple t, std::index_sequence) { - return fn(std::get(t)...); + return fn(get(t)...); } template auto apply(F fn, Tuple t) @@ -296,9 +167,9 @@ namespace rocwmma { // Computes ceil(numerator/divisor) for integer types. 
template ::value>::type, + class = typename enable_if::value>::type, typename intT2, - class = typename std::enable_if::value>::type> + class = typename enable_if::value>::type> static constexpr intT1 ceilDiv(const intT1 numerator, const intT2 divisor) { return (numerator + divisor - 1) / divisor; diff --git a/library/include/rocwmma/internal/vector.hpp b/library/include/rocwmma/internal/vector.hpp index 2b2b93ce..bf6ef234 100644 --- a/library/include/rocwmma/internal/vector.hpp +++ b/library/include/rocwmma/internal/vector.hpp @@ -30,10 +30,15 @@ // #include "types.hpp" // #include "types_ext.hpp" #if !defined(__HIPCC_RTC__) + #include #include + #endif +#include "utility/forward.hpp" +#include "utility/type_traits.hpp" + /** * rocWMMA vectors are implemented as HIP_vector_type objects, which will ultimately * serve as the backend storage for fragment objects. The intention is to be compatible @@ -147,13 +152,13 @@ namespace rocwmma ROCWMMA_HOST_DEVICE inline VecT& operator=(VecT&&) = default; - template {}) && (Rank > 1)>::type* = nullptr> + template {}) && (Rank > 1)>::type* = nullptr> ROCWMMA_HOST_DEVICE explicit constexpr non_native_vector_base(T x_) noexcept; template ::type* = nullptr> + typename U = T, + typename enable_if<(sizeof...(Ts) == Rank)>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr non_native_vector_base(Ts... 
args) noexcept; ROCWMMA_HOST_DEVICE @@ -186,28 +191,28 @@ namespace rocwmma ROCWMMA_HOST_DEVICE constexpr inline VecT operator/(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator%=(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT operator-() const noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator&=(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator|=(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT operator~() const noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator^=(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator>>=(const VecT& x_) noexcept; - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE inline VecT& operator<<=(const VecT& x_) noexcept; ROCWMMA_HOST_DEVICE @@ -395,6 +400,7 @@ ROCWMMA_REGISTER_HIP_NON_NATIVE_VECTOR_TYPE_WITH_INC_DEC_OPS_AS_FLOAT(rocwmma::b ROCWMMA_REGISTER_HIP_NON_NATIVE_VECTOR_TYPE_WITH_INC_DEC_OPS_AS_FLOAT(rocwmma::bfloat16_t, 512); #include "type_traits.hpp" +#include "utility/get.hpp" namespace rocwmma { @@ -434,223 +440,8 @@ namespace rocwmma return VecSize; } }; +} - namespace detail - { - template - struct first_type; - - template - struct first_type - { - using type = T; - }; - - template - using first_type_t = typename first_type::type; - - template - struct is_same_type; - - template - struct is_same_type : std::true_type - { - }; - - template - struct is_same_type - : std::conditional_t{}, is_same_type, std::false_type> - { - }; - - template - constexpr bool 
is_same_type_v = is_same_type::value; - } - - /////////////////////////////////////////////////////////////////// - /// HIP_vector_type utility overrides /// - /// /// - /// Note: HIP_vector_type uses vector extensions. /// - /// Element-wise access of vectors in constexpr is forbidden. /// - /////////////////////////////////////////////////////////////////// - template - ROCWMMA_HOST_DEVICE constexpr inline DataT& get(HIP_vector_type& v) - { - return reinterpret_cast(&v.data)[Idx]; - } - - template - ROCWMMA_HOST_DEVICE constexpr inline DataT get(HIP_vector_type const& v) - { - return v.data[Idx]; - } - - template - ROCWMMA_HOST_DEVICE constexpr inline auto swap(HIP_vector_type const& v) - { - return HIP_vector_type{get<1>(v), get<0>(v)}; - } - - namespace detail - { - template - constexpr decltype(auto) - apply_impl(F fn, HIP_vector_type const& v, index_sequence) - { - return fn(get(v)...); - } - - } // namespace detail - - template - constexpr decltype(auto) apply(F fn, HIP_vector_type& v) - { - constexpr std::size_t size = VecTraits>::size(); - return detail::apply_impl(fn, v, detail::make_index_sequence()); - } - - /////////////////////////////////////////////////////////////////// - /// non_native_vector_base utility overrides /// - /////////////////////////////////////////////////////////////////// - template - ROCWMMA_HOST_DEVICE constexpr static inline DataT& - get(non_native_vector_base& v) - { - return v[Idx]; - } - - template - ROCWMMA_HOST_DEVICE constexpr static inline DataT - get(non_native_vector_base const& v) - { - return v[Idx]; - } - - namespace detail - { - template - constexpr decltype(auto) - apply_impl(F fn, non_native_vector_base const& v, index_sequence) - { - return fn(get(v)...); - } - - } // namespace detail - - template - constexpr decltype(auto) apply(F fn, non_native_vector_base const& v) - { - constexpr std::size_t size = VecTraits>::size(); - return detail::apply_impl(fn, v, detail::make_index_sequence()); - } - - template - 
constexpr decltype(auto) make_vector(Ts&&... ts) - { - // TODO: When HIP_vector_type becomes constexpr replace with non_native_vector type. - - // Ensure that all the arguments are the same type - static_assert(detail::is_same_type_v...>, - "Vector arguments must all be the same type"); - - using DataT = typename detail::first_type_t...>; - return non_native_vector_base{std::forward(ts)...}; - } - - namespace detail - { - template - constexpr static inline decltype(auto) - vector_cat_impl(non_native_vector_base const& lhs, - index_sequence, - non_native_vector_base const& rhs, - index_sequence) - { - return make_vector(get(lhs)..., get(rhs)...); - } - - } // namespace detail - - template - constexpr decltype(auto) vector_cat(Lhs&& lhs, Rhs&& rhs) - { - constexpr std::size_t Size0 = VecTraits>::size(); - constexpr std::size_t Size1 = VecTraits>::size(); - - return detail::vector_cat_impl(std::forward(lhs), - detail::make_index_sequence(), - std::forward(rhs), - detail::make_index_sequence()); - } - - namespace detail - { - template - constexpr static inline decltype(auto) - mult_poly_vec_impl(non_native_vector_base const& lhs, - non_native_vector_base const& rhs, - index_sequence) - { - return make_vector((get(lhs) * get(rhs))...); - } - - } // namespace detail - - template - constexpr decltype(auto) operator*(non_native_vector_base const& lhs, - non_native_vector_base const& rhs) - { - return detail::mult_poly_vec_impl(lhs, rhs, detail::make_index_sequence()); - } - - namespace detail - { - template - ROCWMMA_HOST_DEVICE constexpr static inline std::decay_t - reduceOp_impl(T&& t, Ts&&... 
ts) noexcept - { - using CastT = std::decay_t; - if constexpr(sizeof...(Ts) >= 1) - { - return BinOp::exec(static_cast(t), - reduceOp_impl(std::forward(ts)...)); - } - else - { - return static_cast(t); - } - } - - template - ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) - vector_reduce_impl(VecT&& v, index_sequence) noexcept - { - return reduceOp_impl(get(std::forward(v))...); - } - - // Use with operations that have 1 operands - template - ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) - vector_reduce(VecT&& lhs) noexcept - { - return vector_reduce_impl( - std::forward(lhs), - detail::make_index_sequence>::size()>{}); - } - } - - template - ROCWMMA_HOST_DEVICE constexpr static inline decltype(auto) - vector_reduce_and(VecT&& lhs) noexcept - { - return detail::vector_reduce(std::forward(lhs)); - } - -} // namespace rocwmma +#include "utility/vector.hpp" #endif // ROCWMMA_VECTOR_HPP diff --git a/library/include/rocwmma/internal/vector_impl.hpp b/library/include/rocwmma/internal/vector_impl.hpp index d082ee4e..e5cc1692 100644 --- a/library/include/rocwmma/internal/vector_impl.hpp +++ b/library/include/rocwmma/internal/vector_impl.hpp @@ -27,6 +27,7 @@ #ifndef ROCWMMA_VECTOR_IMPL_HPP #define ROCWMMA_VECTOR_IMPL_HPP +#include "utility/sequence.hpp" #include "vector.hpp" namespace rocwmma @@ -69,8 +70,7 @@ namespace rocwmma }; struct Mod { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline auto exec(TT lhs, TT rhs) { return lhs % rhs; @@ -78,8 +78,7 @@ namespace rocwmma }; struct Minus { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline auto exec(TT lhs) { return -lhs; @@ -92,8 +91,7 @@ namespace rocwmma { struct And { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs, TT rhs) { return lhs & rhs; @@ -102,8 +100,7 @@ namespace rocwmma struct Or { - 
template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs, TT rhs) { return lhs | rhs; @@ -112,8 +109,7 @@ namespace rocwmma struct Not { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs) { return ~lhs; @@ -122,8 +118,7 @@ namespace rocwmma struct Xor { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs, TT rhs) { return lhs ^ rhs; @@ -132,8 +127,7 @@ namespace rocwmma struct ShiftR { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs, TT rhs) { return lhs >> rhs; @@ -142,8 +136,7 @@ namespace rocwmma struct ShiftL { - template {}>::type* = nullptr> + template {}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline TT exec(TT lhs, TT rhs) { return lhs >> rhs; @@ -157,7 +150,7 @@ namespace rocwmma struct And { template {}>::type* = nullptr> + typename enable_if{}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline auto exec(TT lhs, TT rhs) { return lhs && rhs; @@ -167,7 +160,7 @@ namespace rocwmma struct Or { template {}>::type* = nullptr> + typename enable_if{}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline auto exec(TT lhs, TT rhs) { return lhs || rhs; @@ -177,7 +170,7 @@ namespace rocwmma struct Not { template {}>::type* = nullptr> + typename enable_if{}>::type* = nullptr> ROCWMMA_HOST_DEVICE constexpr static inline auto exec(TT lhs) { return !lhs; @@ -244,82 +237,6 @@ namespace rocwmma } // namespace RelationalOp - template - struct integral_constant - { - static constexpr IntT value = val; - using value_type = IntT; - using type = integral_constant; - constexpr operator value_type() const noexcept - { - return value; - } - constexpr value_type operator()() const noexcept - { - return value; - } - }; - - template - struct integer_sequence - { 
- using value_type = Int; - constexpr integer_sequence() {} - static constexpr std::size_t size() noexcept - { - return sizeof...(Ints); - } - }; - - template - using index_sequence = integer_sequence; - - namespace - { - // Merge two integer sequences, adding an offset to the right-hand side. - template - struct merge; - - template - struct merge, - integer_sequence, - integer_sequence> - { - using type = integer_sequence; - }; - - template - struct log_make_sequence - { - using L = integral_constant; - using R = integral_constant; - using type = typename merge::type, - typename log_make_sequence::type>::type; - }; - - // An empty sequence. - template - struct log_make_sequence> - { - using type = integer_sequence; - }; - - // A single-element sequence. - template - struct log_make_sequence> - { - using type = integer_sequence; - }; - } - - template - using make_integer_sequence = - typename log_make_sequence>::type; - - template - using make_index_sequence = make_integer_sequence; - // Helpers for expression expansion, specific to non_native_vector_base template using SeqT = integer_sequence; @@ -367,7 +284,7 @@ namespace rocwmma // As a solution, Rank == 1 should fall into the ctor(Ts... args) for initializer // list construction, and NOT bCast initialization. template - template {}) && (Rank > 1)>::type*> + template {}) && (Rank > 1)>::type*> ROCWMMA_HOST_DEVICE constexpr non_native_vector_base::non_native_vector_base( T x_) noexcept : non_native_vector_base(detail::template bCast(x_, detail::Seq{})) @@ -378,7 +295,7 @@ namespace rocwmma // Default template depth is currently not deep enough to // support vector sizes of 512 template - template ::type*> + template ::type*> ROCWMMA_HOST_DEVICE constexpr non_native_vector_base::non_native_vector_base( Ts... 
args) noexcept : d{static_cast(args)...} @@ -460,7 +377,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator%=(const VecT& x_) noexcept -> VecT& { @@ -468,7 +385,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator-() const noexcept -> VecT { @@ -477,7 +394,7 @@ namespace rocwmma // @cond template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator&=(const VecT& x_) noexcept -> VecT& { @@ -486,7 +403,7 @@ namespace rocwmma // @endcond template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator|=(const VecT& x_) noexcept -> VecT& { @@ -494,7 +411,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator~() const noexcept -> VecT { @@ -502,7 +419,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator^=(const VecT& x_) noexcept -> VecT& { @@ -510,7 +427,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator>>=(const VecT& x_) noexcept -> VecT& { @@ -518,7 +435,7 @@ namespace rocwmma } template - template {}>::type*> + template {}>::type*> ROCWMMA_HOST_DEVICE inline auto non_native_vector_base::operator<<=(const VecT& x_) noexcept -> VecT& { @@ -660,44 +577,43 @@ namespace rocwmma /// OR native vector extension. The latter doesn't have the required built-in broadcast. 
/// //////////////////////////////////////////////////////////////////////////////////////////////// -#define ROCWMMA_REGISTER_HIP_VECTOR_BASE(TYPE, RANK, STORAGE_IMPL) \ - template <> \ - struct HIP_vector_base \ - { \ - STORAGE_IMPL(TYPE, RANK); \ - \ - using value_type = TYPE; \ - \ - ROCWMMA_HOST_DEVICE \ - HIP_vector_base() = default; \ - template ::type* = nullptr> \ - ROCWMMA_HOST_DEVICE constexpr HIP_vector_base(ArgsT... args) noexcept \ - : data{args...} \ - { \ - } \ - \ - template < \ - typename U = TYPE, \ - typename std::enable_if<(std::is_same{}) && (RANK > 1)>::type* = nullptr> \ - ROCWMMA_HOST_DEVICE constexpr explicit HIP_vector_base(TYPE val) noexcept \ - : HIP_vector_base(rocwmma::detail::template bCast( \ - val, rocwmma::detail::Seq{})) \ - { \ - } \ - \ - ROCWMMA_HOST_DEVICE \ - constexpr HIP_vector_base(const HIP_vector_base&) = default; \ - \ - ROCWMMA_HOST_DEVICE \ - constexpr HIP_vector_base(HIP_vector_base&&) = default; \ - \ - ROCWMMA_HOST_DEVICE \ - ~HIP_vector_base() = default; \ - \ - ROCWMMA_HOST_DEVICE \ - HIP_vector_base& operator=(const HIP_vector_base& x_) noexcept = default; \ +#define ROCWMMA_REGISTER_HIP_VECTOR_BASE(TYPE, RANK, STORAGE_IMPL) \ + template <> \ + struct HIP_vector_base \ + { \ + STORAGE_IMPL(TYPE, RANK); \ + \ + using value_type = TYPE; \ + \ + ROCWMMA_HOST_DEVICE \ + HIP_vector_base() = default; \ + template * = nullptr> \ + ROCWMMA_HOST_DEVICE constexpr HIP_vector_base(ArgsT... 
args) noexcept \ + : data{args...} \ + { \ + } \ + \ + template {}) && (RANK > 1)>* = nullptr> \ + ROCWMMA_HOST_DEVICE constexpr explicit HIP_vector_base(TYPE val) noexcept \ + : HIP_vector_base(rocwmma::detail::template bCast( \ + val, rocwmma::detail::Seq{})) \ + { \ + } \ + \ + ROCWMMA_HOST_DEVICE \ + constexpr HIP_vector_base(const HIP_vector_base&) = default; \ + \ + ROCWMMA_HOST_DEVICE \ + constexpr HIP_vector_base(HIP_vector_base&&) = default; \ + \ + ROCWMMA_HOST_DEVICE \ + ~HIP_vector_base() = default; \ + \ + ROCWMMA_HOST_DEVICE \ + HIP_vector_base& operator=(const HIP_vector_base& x_) noexcept = default; \ }; /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/library/include/rocwmma/internal/vector_util.hpp b/library/include/rocwmma/internal/vector_util.hpp new file mode 100644 index 00000000..abc56c4f --- /dev/null +++ b/library/include/rocwmma/internal/vector_util.hpp @@ -0,0 +1,129 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_VECTOR_UTIL_HPP +#define ROCWMMA_VECTOR_UTIL_HPP + +#include "types.hpp" +#include "vector.hpp" + +namespace rocwmma +{ + //! Extracts the first (lo) half of elements from a given vector + /*! + \param v Vector to extract the lo elements from. + */ + template + ROCWMMA_DEVICE constexpr static inline auto extractLo(VecT const& v); + + //! Extracts the second (hi) half of elements from a given vector + /*! + \param v Vector to extract the hi elements from. + */ + template + ROCWMMA_DEVICE constexpr static inline auto extractHi(VecT const& v); + + //! Extracts the even elements from a given vector + /*! + \param v Vector to extract the even elements from. + */ + template + ROCWMMA_HOST_DEVICE constexpr static inline auto extractEven(VecT const& v); + + //! Extracts the odd elements from a given vector + /*! + \param v Vector to extract the odd elements from. + */ + template + ROCWMMA_DEVICE constexpr static inline auto extractOdd(VecT const& v); + + //! Re-orders vector elements such that even elements are concatenated with odd elements. + /*! + \param v Vector to reorder elements from. + */ + template + ROCWMMA_DEVICE constexpr static inline auto reorderEvenOdd(VecT const& v); + + //! Re-orders vector elements such that odd elements are concatenated with even elements. + /*! + \param v Vector to reorder elements from. + */ + template + ROCWMMA_DEVICE constexpr static inline auto reorderOddEven(VecT const& v); + + //! Concatenates the contents of two vectors together in order. + /*! 
+ \param v0 First vector to concatenate + \param v1 Second vector to concatenate + */ + template + ROCWMMA_DEVICE constexpr static inline auto concat(VecT const& v0, + VecT const& v1); + + //! Alternates selecting even elements from the first vector and odd elements from the second vector. + //! Analogous to a zipper. + //! E.g. + //! v0 = [0, 1, 2, 3] + //! v1 = [4, 5, 6, 7] + //! result = [0, 5, 2, 7] + /*! + \param v0 Vector from which even elements are alternately selected + \param v1 Vector from which odd elements are alternately selected + */ + template + ROCWMMA_DEVICE constexpr static inline auto zip(VecT const& v0, + VecT const& v1); + + //! Alternates selecting the first (lo) half of elements from each vector + //! E.g. + //! v0 = [0, 1, 2, 3] + //! v1 = [4, 5, 6, 7] + //! result = [0, 4, 1, 5] + /*! + \param v0 Vector from which lo elements are alternately selected + \param v1 Vector from which lo elements are alternately selected + */ + template + ROCWMMA_DEVICE constexpr static inline auto unpackLo(VecT const& v0, + VecT const& v1); + + //! Alternates selecting the second (hi) half of elements from each vector + //! E.g. + //! v0 = [0, 1, 2, 3] + //! v1 = [4, 5, 6, 7] + //! result = [2, 6, 3, 7] + /*! + \param v0 Vector from which hi elements are alternately selected + \param v1 Vector from which hi elements are alternately selected + */ + template + ROCWMMA_DEVICE constexpr static inline auto unpackHi(VecT const& v0, + VecT const& v1); +} // namespace rocwmma + +#include "vector_util_impl.hpp" + +#endif // ROCWMMA_VECTOR_UTIL_HPP diff --git a/library/include/rocwmma/internal/vector_util_impl.hpp b/library/include/rocwmma/internal/vector_util_impl.hpp new file mode 100644 index 00000000..7d5d8396 --- /dev/null +++ b/library/include/rocwmma/internal/vector_util_impl.hpp @@ -0,0 +1,441 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2022-2024 Advanced Micro Devices, Inc. 
All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_VECTOR_UTIL_IMPL_HPP +#define ROCWMMA_VECTOR_UTIL_IMPL_HPP + +#include "blend.hpp" +#include "types.hpp" +#include "vector.hpp" + +namespace rocwmma +{ + namespace detail + { + template + using Number = integral_constant; + + // Can be used to build any vector class of + // Either VecT or non_native_vector_base. + // Class acts as a static for_each style generator: + // Incoming functor F will be called with each index + args in sequence. + // Results of functor calls are used to construct a new vector. + template