Skip to content

Commit

Permalink
Merge pull request ROCm#341 from dlangbe/utils
Browse files Browse the repository at this point in the history
rocWMMA Utilities
  • Loading branch information
dlangbe authored Feb 20, 2024
2 parents e9048ae + 3ea39d9 commit 69694c8
Show file tree
Hide file tree
Showing 54 changed files with 3,641 additions and 1,123 deletions.
13 changes: 13 additions & 0 deletions library/include/rocwmma/internal/blend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ namespace rocwmma
using Zip4 = Driver<BlendImpl::Ops::Zip4>;
using Zip8 = Driver<BlendImpl::Ops::Zip8>;
using Zip16 = Driver<BlendImpl::Ops::Zip16>;
using Zip32 = Driver<BlendImpl::Ops::Zip32>;

// Unpack functions
using UnpackByteLo = Driver<BlendImpl::Ops::UnpackByteLo>;
Expand All @@ -107,6 +108,18 @@ namespace rocwmma
using UnpackWordHi = Driver<BlendImpl::Ops::UnpackWordHi>;
using UnpackByteLoHi = Driver<BlendImpl::Ops::UnpackByteLoHi>;

// Extract functions
// Each alias below wraps the BlendImpl::Ops operation of the same name in Driver.
// Single-parity variants (Even / Odd) for byte- and word-sized elements.
using ExtractByteEven = Driver<BlendImpl::Ops::ExtractByteEven>;
using ExtractByteOdd = Driver<BlendImpl::Ops::ExtractByteOdd>;
using ExtractWordEven = Driver<BlendImpl::Ops::ExtractWordEven>;
using ExtractWordOdd = Driver<BlendImpl::Ops::ExtractWordOdd>;

// Mixed-parity variants: Even followed by Odd.
using ExtractByteEvenOdd = Driver<BlendImpl::Ops::ExtractByteEvenOdd>;
using ExtractWordEvenOdd = Driver<BlendImpl::Ops::ExtractWordEvenOdd>;

// Mixed-parity variants: Odd followed by Even.
using ExtractByteOddEven = Driver<BlendImpl::Ops::ExtractByteOddEven>;
using ExtractWordOddEven = Driver<BlendImpl::Ops::ExtractWordOddEven>;

} // namespace Blend

} // namespace rocwmma
Expand Down
12 changes: 12 additions & 0 deletions library/include/rocwmma/internal/blend_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace rocwmma
using Properties::OP_GROUP_SIZE_1;
using Properties::OP_GROUP_SIZE_16;
using Properties::OP_GROUP_SIZE_2;
using Properties::OP_GROUP_SIZE_32;
using Properties::OP_GROUP_SIZE_4;
using Properties::OP_GROUP_SIZE_8;

Expand Down Expand Up @@ -247,6 +248,7 @@ namespace rocwmma
using Zip4 = Zip<OP_GROUP_SIZE_4>;
using Zip8 = Zip<OP_GROUP_SIZE_8>;
using Zip16 = Zip<OP_GROUP_SIZE_16>;
using Zip32 = Zip<OP_GROUP_SIZE_32>;

// Blend sub-dword elements in regular ordered patterns
using UnpackByteLo = PermByte<0u, 4u, 1u, 5u>;
Expand All @@ -255,6 +257,16 @@ namespace rocwmma
using UnpackWordHi = PermWord<1u, 3u>;
using UnpackByteLoHi = PermByte<0u, 6u, 1u, 7u>;

// Extract patterns: select the even-numbered (0, 2, 4, 6) or odd-numbered
// (1, 3, 5, 7) byte indices.
// NOTE(review): indices 0-3 vs. 4-7 presumably address the first and second
// source operand respectively - confirm against the PermByte definition.
using ExtractByteEven = PermByte<0u, 2u, 4u, 6u>;
using ExtractByteOdd = PermByte<1u, 3u, 5u, 7u>;
// Word extraction reuses the existing word unpack patterns
// (UnpackWordHi is PermWord<1u, 3u>; UnpackWordLo is presumably PermWord<0u, 2u> - TODO confirm).
using ExtractWordEven = UnpackWordLo;
using ExtractWordOdd = UnpackWordHi;

// Mixed patterns: the first half of the indices has one parity and the second
// half has the other (e.g. EvenOdd = even indices 0, 2 then odd indices 5, 7).
using ExtractByteEvenOdd = PermByte<0u, 2u, 5u, 7u>;
using ExtractByteOddEven = PermByte<1u, 3u, 4u, 6u>;
using ExtractWordEvenOdd = PermWord<0u, 3u>;
using ExtractWordOddEven = PermWord<1u, 2u>;

} // namespace Ops

} // namespace BlendImpl
Expand Down
3 changes: 2 additions & 1 deletion library/include/rocwmma/internal/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define ROCWMMA_CONVERT_HPP

#include "types.hpp"
#include "utility/forward.hpp"

namespace rocwmma
{
Expand Down Expand Up @@ -58,7 +59,7 @@ namespace rocwmma
// Pass-through exec: returns the incoming argument unchanged, preserving its
// value category by perfect-forwarding the forwarding reference.
// NOTE(review): the two consecutive return statements below are the removed
// (std::forward) and added (unqualified forward) lines of this diff hunk shown
// together; if compiled as displayed, only the first return is reachable.
// The unqualified forward presumably comes from "utility/forward.hpp" - confirm.
template <typename IncomingT>
ROCWMMA_DEVICE static inline auto exec(IncomingT&& regsIn) -> IncomingT&&
{
return std::forward<IncomingT>(regsIn);
return forward<IncomingT>(regsIn);
}
};

Expand Down
16 changes: 8 additions & 8 deletions library/include/rocwmma/internal/coop_load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace rocwmma

// Outer loop = index 0,
// Inner loop = index N-1
template <std::size_t Depth = 0,
template <size_t Depth = 0,
typename Iterator,
typename StrideSpace,
typename Strides2d>
Expand All @@ -73,14 +73,14 @@ namespace rocwmma
StrideSpace&& strideSpace,
Strides2d&& strides2d)
{
static_assert(VecTraits<std::decay_t<StrideSpace>>::size()
== VecTraits<std::decay_t<Strides2d>>::size(),
static_assert(VecTraits<decay_t<StrideSpace>>::size()
== VecTraits<decay_t<Strides2d>>::size(),
"Mismatched size");
auto strideOffset = DataLayout::fromMatrixCoord(std::get<Depth>(strides2d), ldm);
auto strideCount = std::get<Depth>(strideSpace);
auto strideOffset = DataLayout::fromMatrixCoord(get<Depth>(strides2d), ldm);
auto strideCount = get<Depth>(strideSpace);

// Last depth layer will invoke the load
if constexpr(Depth == (VecTraits<std::decay_t<StrideSpace>>::size() - 1u))
if constexpr(Depth == (VecTraits<decay_t<StrideSpace>>::size() - 1u))
{
#pragma unroll
for(int i = 0; i < strideCount; i++)
Expand Down Expand Up @@ -135,7 +135,7 @@ namespace rocwmma
}

// Split the reduced stride space.
auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u);
auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u);
auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u;

// Add back in the VW dimension, for the full stride
Expand Down Expand Up @@ -191,7 +191,7 @@ namespace rocwmma
}

// Split the reduced stride space.
constexpr auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u);
constexpr auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u);
constexpr auto strideSpaceS
= inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u;

Expand Down
16 changes: 8 additions & 8 deletions library/include/rocwmma/internal/coop_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace rocwmma

// Outer loop = index 0,
// Inner loop = index N-1
template <std::size_t Depth = 0,
template <size_t Depth = 0,
typename Iterator,
typename StrideSpace,
typename Strides2d>
Expand All @@ -74,14 +74,14 @@ namespace rocwmma
StrideSpace&& strideCounts,
Strides2d&& strides2d)
{
static_assert(VecTraits<std::decay_t<StrideSpace>>::size()
== VecTraits<std::decay_t<Strides2d>>::size(),
static_assert(VecTraits<decay_t<StrideSpace>>::size()
== VecTraits<decay_t<Strides2d>>::size(),
"Mismatched size");
auto strideOffset = DataLayout::fromMatrixCoord(std::get<Depth>(strides2d), ldm);
auto strideCount = std::get<Depth>(strideCounts);
auto strideOffset = DataLayout::fromMatrixCoord(get<Depth>(strides2d), ldm);
auto strideCount = get<Depth>(strideCounts);

// Last depth layer will invoke the store
if constexpr(Depth == (VecTraits<std::decay_t<StrideSpace>>::size() - 1u))
if constexpr(Depth == (VecTraits<decay_t<StrideSpace>>::size() - 1u))
{
#pragma unroll
for(int i = 0; i < strideCount; i++)
Expand Down Expand Up @@ -136,7 +136,7 @@ namespace rocwmma
}

// Split the reduced stride space.
auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u);
auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u);
auto strideSpaceS = inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u;

// Add back in the VW dimension, for the full stride
Expand Down Expand Up @@ -190,7 +190,7 @@ namespace rocwmma
}

// Split the reduced stride space.
constexpr auto workItemsPerWave = std::max(totalWorkItems / maxWaves, 1u);
constexpr auto workItemsPerWave = max(totalWorkItems / maxWaves, 1u);
constexpr auto strideSpaceS
= inflate_coord_left(workItemsPerWave - 1u, strideSpaceR) + 1u;

Expand Down
42 changes: 18 additions & 24 deletions library/include/rocwmma/internal/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@
using uint8_t = __hip_internal::uint8_t;
using uint16_t = __hip_internal::uint16_t;

namespace std
{
template <bool B, class T, class F>
struct conditional;
}

#endif

// We are clipping in down conversion by default
Expand Down Expand Up @@ -771,28 +765,28 @@ inline ROCWMMA_HOST_DEVICE bool operator!=(rocwmma_bf8 a, rocwmma_bf8 b)
// Same-type overload of explicit_downcast: participates in overload resolution
// only when T and Ta are the same type (enable_if on is_same<T, Ta>), and
// returns the input value unchanged.
// The stochastic_rounding template parameter is not used by this overload.
// NOTE(review): the two "typename ...enable_if<...>::type = 0" lines below are
// the removed (std::) and added (rocwmma::) lines of this diff hunk shown
// together; only one of them belongs in the actual template parameter list.
template <typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<std::is_same<T, Ta>{}, int>::type = 0>
typename rocwmma::enable_if<rocwmma::is_same<T, Ta>{}, int>::type = 0>
inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a)
{
// same type, no conversion
return a;
}

// Use h/w intrinsic and optimized version when __gfx940__
template <
typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<(!(std::is_same<T, Ta>{})
&& (std::is_same<T, rocwmma_f8>{} || std::is_same<T, rocwmma_bf8>{})),
int>::type
= 0>
template <typename T,
typename Ta,
bool stochastic_rounding,
typename rocwmma::enable_if<(!(rocwmma::is_same<T, Ta>{})
&& (rocwmma::is_same<T, rocwmma_f8>{}
|| rocwmma::is_same<T, rocwmma_bf8>{})),
int>::type
= 0>
inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng)
{
#if ROCWMMA_ARCH_GFX940 || ROCWMMA_ARCH_GFX941 || ROCWMMA_ARCH_GFX942
// NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize away one runtime branch
T val;
if(std::is_same<T, rocwmma_f8>::value)
if(rocwmma::is_same<T, rocwmma_f8>::value)
{
val.data = rocwmma_f8::cast_to_f8_from_f32<stochastic_rounding>(float(a), rng);
}
Expand All @@ -811,14 +805,14 @@ inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng)

// NOTE: The above code is sufficient if we only consider quantization and do not consider HIP-GEMM code.
// However, if we need HIP-GEMM as a fall-back, we would need explicit_cast to handle the Tacc=f32 to To=f16/bf16 conversion.
template <
typename T,
typename Ta,
bool stochastic_rounding,
typename std::enable_if<(!(std::is_same<T, Ta>{})
&& !(std::is_same<T, rocwmma_f8>{} || std::is_same<T, rocwmma_bf8>{})),
int>::type
= 0>
template <typename T,
typename Ta,
bool stochastic_rounding,
typename rocwmma::enable_if<(!(rocwmma::is_same<T, Ta>{})
&& !(rocwmma::is_same<T, rocwmma_f8>{}
|| rocwmma::is_same<T, rocwmma_bf8>{})),
int>::type
= 0>
inline ROCWMMA_HOST_DEVICE T explicit_downcast(Ta a, uint32_t rng)
{
// the return type is not a F8 types, no SR for those types
Expand Down
14 changes: 7 additions & 7 deletions library/include/rocwmma/internal/io_layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ namespace rocwmma
{
MaxVW = detail::MaxVWSelector<matrix_a, BlockDim, KDim, DataT, DataLayoutT, WaveCount>::
Result,
VW = std::is_same<DataLayoutT, row_major>::value ? MaxVW : 1u
VW = is_same<DataLayoutT, row_major>::value ? MaxVW : 1u
};

// Layout mapping for 1d / 2d
using DataLayout = DataLayout::template Array1d<DataLayoutT>;
using MatrixLayout
= MatrixLayout::template ColNT<BlockDim, KDim, DataT, DataLayoutT, VW, MaxVW>;

static_assert(!(std::is_same_v<DataLayoutT, col_major> && VW > 1),
static_assert(!(is_same_v<DataLayoutT, col_major> && VW > 1),
"matrix_a in col_major currently does not support VW > 1");
};

Expand All @@ -156,15 +156,15 @@ namespace rocwmma
{
MaxVW = detail::MaxVWSelector<matrix_b, BlockDim, KDim, DataT, DataLayoutT, WaveCount>::
Result,
VW = std::is_same<DataLayoutT, col_major>::value ? MaxVW : 1u
VW = is_same<DataLayoutT, col_major>::value ? MaxVW : 1u
};

// Layout mapping for 1d / 2d
using DataLayout = DataLayout::template Array1d<DataLayoutT>;
using MatrixLayout
= MatrixLayout::template RowNT<BlockDim, KDim, DataT, DataLayoutT, VW, MaxVW>;

static_assert(!(std::is_same_v<DataLayoutT, row_major> && VW > 1),
static_assert(!(is_same_v<DataLayoutT, row_major> && VW > 1),
"matrix_b in row_major currently does not support VW > 1");
};

Expand All @@ -178,16 +178,16 @@ namespace rocwmma
// Vector size properties
enum : uint32_t
{
MaxVW = (std::is_same<DataT, float64_t>::value || ROCWMMA_ARCH_GFX11) ? 1u : 4u,
VW = std::is_same<DataLayoutT, col_major>::value ? MaxVW : 1u
MaxVW = (is_same<DataT, float64_t>::value || ROCWMMA_ARCH_GFX11) ? 1u : 4u,
VW = is_same<DataLayoutT, col_major>::value ? MaxVW : 1u
};

// Layout mapping for 1d / 2d
using DataLayout = DataLayout::template Array1d<DataLayoutT>;
using MatrixLayout
= MatrixLayout::template RowNT<BlockDim, KDim, DataT, DataLayoutT, VW, MaxVW>;

static_assert(!(std::is_same<DataLayoutT, row_major>::value && VW > 1),
static_assert(!(is_same<DataLayoutT, row_major>::value && VW > 1),
"accumulator in row_major currently does not support VW > 1");
};

Expand Down
25 changes: 13 additions & 12 deletions library/include/rocwmma/internal/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef ROCWMMA_LAYOUT_HPP
#define ROCWMMA_LAYOUT_HPP

#include "utility/type_traits.hpp"
#include "layout_impl.hpp"

namespace rocwmma
Expand Down Expand Up @@ -188,8 +189,8 @@ namespace rocwmma
typename DataLayout,
uint32_t VectorWidth,
uint32_t MaxVectorWidth>
struct ColNT : public std::conditional_t<
std::is_same<DataLayout, col_major>::value,
struct ColNT : public conditional_t<
is_same<DataLayout, col_major>::value,
detail::ColOrthoVW<BlockDim, BlockK, DataT, 1, MaxVectorWidth>,
detail::ColOrthoVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>>
{
Expand All @@ -202,11 +203,11 @@ namespace rocwmma
// elements in both row_major and col_major data layouts.
// This layout cannot support VW > 1 in the col_major data layout, otherwise the
// ordering is broken.
static_assert(!(std::is_same_v<DataLayout, col_major> && VectorWidth > 1),
static_assert(!(is_same_v<DataLayout, col_major> && VectorWidth > 1),
"ColNT in col_major does not support VectorWidth > 1");

// Must ensure that MaxVectorWidth fits inside the leading dimension
static_assert(std::is_same_v<DataLayout, row_major> && (MaxVectorWidth <= BlockK),
static_assert(is_same_v<DataLayout, row_major> && (MaxVectorWidth <= BlockK),
"MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth");
};
};
Expand Down Expand Up @@ -315,8 +316,8 @@ namespace rocwmma
typename DataLayout,
uint32_t VectorWidth,
uint32_t MaxVectorWidth>
struct RowNT : public std::conditional_t<
std::is_same<DataLayout, col_major>::value,
struct RowNT : public conditional_t<
is_same<DataLayout, col_major>::value,
detail::RowOrthoVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>,
detail::RowOrthoVW<BlockDim, BlockK, DataT, 1, MaxVectorWidth>>
{
Expand All @@ -329,11 +330,11 @@ namespace rocwmma
// elements in both row_major and col_major data layouts.
// This layout cannot support VW > 1 in the row_major data layout, otherwise the
// ordering is broken.
static_assert(!(std::is_same_v<DataLayout, row_major> && VectorWidth > 1),
static_assert(!(is_same_v<DataLayout, row_major> && VectorWidth > 1),
"RowNT in row_major does not support VectorWidth > 1");

// Must ensure that MaxVectorWidth fits inside the leading dimension
static_assert(std::is_same_v<DataLayout, col_major> && (MaxVectorWidth <= BlockK),
static_assert(is_same_v<DataLayout, col_major> && (MaxVectorWidth <= BlockK),
"MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth");
};
};
Expand Down Expand Up @@ -498,8 +499,8 @@ namespace rocwmma
typename DataLayout,
uint32_t VectorWidth,
uint32_t MaxVectorWidth = VectorWidth>
struct Col : public std::conditional_t<
std::is_same<DataLayout, col_major>::value,
struct Col : public conditional_t<
is_same<DataLayout, col_major>::value,
detail::ColInlineVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>,
detail::ColOrthoVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>>
{
Expand Down Expand Up @@ -664,8 +665,8 @@ namespace rocwmma
typename DataLayout,
uint32_t VectorWidth,
uint32_t MaxVectorWidth = VectorWidth>
struct Row : public std::conditional_t<
std::is_same<DataLayout, row_major>::value,
struct Row : public conditional_t<
is_same<DataLayout, row_major>::value,
detail::RowInlineVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>,
detail::RowOrthoVW<BlockDim, BlockK, DataT, VectorWidth, MaxVectorWidth>>
{
Expand Down
Loading

0 comments on commit 69694c8

Please sign in to comment.