Skip to content

Commit

Permalink
Merge branch '184-add-support-for-u-int128_t-in-deviceradixsort' into…
Browse files Browse the repository at this point in the history
… 'develop_stream'

make device_radix_sort compatible with compiler provided __int128_t and __uint128_t

See merge request amd/libraries/rocPRIM!549
  • Loading branch information
parbenc authored and Naraenda committed Oct 18, 2023
2 parents 759a6f4 + 8461645 commit 38ba018
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 11 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ Full documentation for rocPRIM is available at [https://rocprim.readthedocs.io/e
## [Unreleased rocPRIM-3.0.0 for ROCm 6.0.0]
### Added
### Changed
- Removed deprecated functionality: `reduce_by_key_config`, `MatchAny`, `scan_config`, `scan_by_key_config` and `radix_sort_config`.
- Renamed `scan_config_v2` to `scan_config`, `scan_by_key_config_v2` to `scan_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`, `reduce_by_key_config_v2` to `reduce_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`.
- Removed support for custom config types for device algorithms.
- Removed deprecated functionality: `reduce_by_key_config`, `MatchAny`, `scan_config`, `scan_by_key_config` and `radix_sort_config`.
- Renamed `scan_config_v2` to `scan_config`, `scan_by_key_config_v2` to `scan_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`, `reduce_by_key_config_v2` to `reduce_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`.
- Removed support for custom config types for device algorithms.
- `host_warp_size()` was moved into `rocprim/device/config_types.hpp`, and now uses either a `device_id` or a `stream` parameter to query the proper device and a `device_id` out parameter. The return type is `hipError_t`.
- Added support for __int128_t in `device_radix_sort` and `block_radix_sort`.
### Fixed

## [Unreleased rocPRIM-2.13.1 for ROCm 5.7.0]
Expand Down
71 changes: 70 additions & 1 deletion rocprim/include/rocprim/detail/radix_sort.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -67,6 +67,33 @@ struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::
}
};

template<class Key, class BitKey>
struct radix_key_codec_integral<
Key,
BitKey,
typename std::enable_if<std::is_same<Key, __uint128_t>::value>::type>
{
using bit_key_type = BitKey;

ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key)
{
return __builtin_bit_cast(bit_key_type, key);
}

ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key)
{
return __builtin_bit_cast(Key, bit_key);
}

template<bool Descending>
ROCPRIM_DEVICE static unsigned int
extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
return static_cast<unsigned int>(bit_key >> start) & mask;
}
};

template<class Key, class BitKey>
struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::is_signed<Key>::value>::type>
{
Expand Down Expand Up @@ -97,6 +124,36 @@ struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::
}
};

template<class Key, class BitKey>
struct radix_key_codec_integral<Key,
BitKey,
typename std::enable_if<std::is_same<Key, __int128_t>::value>::type>
{
using bit_key_type = BitKey;

static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1);

ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key)
{
const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
return sign_bit ^ bit_key;
}

ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key)
{
bit_key ^= sign_bit;
return __builtin_bit_cast(Key, bit_key);
}

template<bool Descending>
ROCPRIM_DEVICE static unsigned int
extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
return static_cast<unsigned int>(bit_key >> start) & mask;
}
};

template<class Key>
struct float_bit_mask;

Expand Down Expand Up @@ -199,6 +256,18 @@ struct radix_key_codec_base<
typename std::enable_if<::rocprim::is_integral<Key>::value>::type
> : radix_key_codec_integral<Key, typename std::make_unsigned<Key>::type> { };

template<class Key>
struct radix_key_codec_base<Key,
typename std::enable_if<std::is_same<Key, __int128_t>::value>::type>
: radix_key_codec_integral<Key, __uint128_t>
{};

template<class Key>
struct radix_key_codec_base<Key,
typename std::enable_if<std::is_same<Key, __uint128_t>::value>::type>
: radix_key_codec_integral<Key, __uint128_t>
{};

template<>
struct radix_key_codec_base<bool>
{
Expand Down
4 changes: 2 additions & 2 deletions test/rocprim/test_block_radix_sort.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIT License
//
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -40,7 +40,7 @@ struct RocprimBlockRadixSort;

struct Integral;
#define suite_name RocprimBlockRadixSort
#define warp_params BlockParamsIntegral
#define warp_params BlockParamsIntegralExtended
#define name_suffix Integral

#include "test_block_radix_sort.hpp"
Expand Down
6 changes: 5 additions & 1 deletion test/rocprim/test_device_radix_sort.cpp.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIT License
//
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -49,6 +49,10 @@
#endif

#if ROCPRIM_TEST_TYPE_SLICE == 0
#if defined(__GNUC__) || defined(__clang__)
INSTANTIATE(params<__int128_t, __int128_t>)
INSTANTIATE(params<__uint128_t, __uint128_t>)
#endif
INSTANTIATE(params<signed char, double, true>)
INSTANTIATE(params<int, short>)
INSTANTIATE(params<short, int, true>)
Expand Down
77 changes: 76 additions & 1 deletion test/rocprim/test_utils_assertions.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -240,6 +240,81 @@ void assert_bit_eq(const std::vector<T>& result, const std::vector<T>& expected)
}
}

#if defined(__GNUC__) || defined(__clang__)
inline void assert_bit_eq(const std::vector<__int128_t>& result,
const std::vector<__int128_t>& expected)
{
ASSERT_EQ(result.size(), expected.size());

auto to_string = [](__int128_t value)
{
static const char* charmap = "0123456789";

std::string result;
result.reserve(41); // max. 40 digits possible ( uint64_t has 20) plus sign
__uint128_t helper = (value < 0) ? -value : value;

do
{
result += charmap[helper % 10];
helper /= 10;
}
while(helper);
if(value < 0)
{
result += "-";
}
std::reverse(result.begin(), result.end());
return result;
};

for(size_t i = 0; i < result.size(); i++)
{
if(!bit_equal(result[i], expected[i]))
{
FAIL() << "Expected strict/bitwise equality of these values: " << std::endl
<< " result[i]: " << to_string(result[i]) << std::endl
<< " expected[i]: " << to_string(expected[i]) << std::endl
<< "where index = " << i;
}
}
}

inline void assert_bit_eq(const std::vector<__uint128_t>& result,
const std::vector<__uint128_t>& expected)
{
ASSERT_EQ(result.size(), expected.size());

auto to_string = [](__uint128_t value)
{
static const char* charmap = "0123456789";

std::string result;
result.reserve(40); // max. 40 digits possible ( uint64_t has 20)
__uint128_t helper = value;

do
{
result += charmap[helper % 10];
helper /= 10;
}
while(helper);
std::reverse(result.begin(), result.end());
return result;
};

for(size_t i = 0; i < result.size(); i++)
{
if(!bit_equal(result[i], expected[i]))
{
FAIL() << "Expected strict/bitwise equality of these values: " << std::endl
<< " result[i]: " << to_string(result[i]) << std::endl
<< " expected[i]: " << to_string(expected[i]) << std::endl
<< "where index = " << i;
}
}
}
#endif
}

#endif //ROCPRIM_TEST_UTILS_ASSERTIONS_HPP
86 changes: 85 additions & 1 deletion test/rocprim/test_utils_data_generation.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -159,6 +159,90 @@ void add_special_values(std::vector<T>& source, seed_type seed_value)
}
}

template<class T, class U, class V>
inline auto get_random_data(size_t size, U min, V max, seed_type seed_value) ->
typename std::enable_if<std::is_same<T, __int128_t>::value, std::vector<T>>::type
{
engine_type gen{seed_value};
using dis_type = typename std::conditional<
is_valid_for_int_distribution<T>::value,
T,
typename std::conditional<std::is_signed<T>::value, int, unsigned int>::type>::type;
std::uniform_int_distribution<dis_type> distribution(static_cast<dis_type>(min),
static_cast<dis_type>(max));
std::vector<T> data(size);
size_t segment_size = size / random_data_generation_segments;
if(segment_size != 0)
{
for(uint32_t segment_index = 0; segment_index < random_data_generation_segments;
segment_index++)
{
if(segment_index % random_data_generation_repeat_strides == 0)
{
T repeated_value = static_cast<T>(distribution(gen));
std::fill(data.begin() + segment_size * segment_index,
data.begin() + segment_size * (segment_index + 1),
repeated_value);
}
else
{
std::generate(data.begin() + segment_size * segment_index,
data.begin() + segment_size * (segment_index + 1),
[&]() { return static_cast<T>(distribution(gen)); });
}
}
}
else
{
std::generate(data.begin(),
data.end(),
[&]() { return static_cast<T>(distribution(gen)); });
}
return data;
}

template<class T, class U, class V>
inline auto get_random_data(size_t size, U min, V max, seed_type seed_value) ->
typename std::enable_if<std::is_same<T, __uint128_t>::value, std::vector<T>>::type
{
engine_type gen{seed_value};
using dis_type = typename std::conditional<
is_valid_for_int_distribution<T>::value,
T,
typename std::conditional<std::is_signed<T>::value, int, unsigned int>::type>::type;
std::uniform_int_distribution<dis_type> distribution(static_cast<dis_type>(min),
static_cast<dis_type>(max));
std::vector<T> data(size);
size_t segment_size = size / random_data_generation_segments;
if(segment_size != 0)
{
for(uint32_t segment_index = 0; segment_index < random_data_generation_segments;
segment_index++)
{
if(segment_index % random_data_generation_repeat_strides == 0)
{
T repeated_value = static_cast<T>(distribution(gen));
std::fill(data.begin() + segment_size * segment_index,
data.begin() + segment_size * (segment_index + 1),
repeated_value);
}
else
{
std::generate(data.begin() + segment_size * segment_index,
data.begin() + segment_size * (segment_index + 1),
[&]() { return static_cast<T>(distribution(gen)); });
}
}
}
else
{
std::generate(data.begin(),
data.end(),
[&]() { return static_cast<T>(distribution(gen)); });
}
return data;
}

template<class T, class U, class V>
inline auto get_random_data(size_t size, U min, V max, seed_type seed_value)
-> typename std::enable_if<rocprim::is_integral<T>::value, std::vector<T>>::type
Expand Down
42 changes: 41 additions & 1 deletion test/rocprim/test_utils_sort_comparator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIT License
//
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -65,6 +65,46 @@ struct key_comparator<Key,
}
};

template<class Key, bool Descending, unsigned int StartBit, unsigned int EndBit>
struct key_comparator<Key,
Descending,
StartBit,
EndBit,
typename std::enable_if<std::is_same<Key, __int128_t>::value>::type>
{
static constexpr Key radix_mask_upper
= EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1;
static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1;
static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom;

bool operator()(const Key& lhs, const Key& rhs) const
{
Key l = lhs & radix_mask;
Key r = rhs & radix_mask;
return Descending ? (r < l) : (l < r);
}
};

template<class Key, bool Descending, unsigned int StartBit, unsigned int EndBit>
struct key_comparator<Key,
Descending,
StartBit,
EndBit,
typename std::enable_if<std::is_same<Key, __uint128_t>::value>::type>
{
static constexpr Key radix_mask_upper
= EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1;
static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1;
static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom;

bool operator()(const Key& lhs, const Key& rhs) const
{
Key l = lhs & radix_mask;
Key r = rhs & radix_mask;
return Descending ? (r < l) : (l < r);
}
};

template<class Key, bool Descending, unsigned int StartBit, unsigned int EndBit>
struct key_comparator<Key,
Descending,
Expand Down
Loading

0 comments on commit 38ba018

Please sign in to comment.