diff --git a/features/feature_case/cub/cub_block.cu b/features/feature_case/cub/cub_block.cu
index aedce6a0d..84686b9a4 100644
--- a/features/feature_case/cub/cub_block.cu
+++ b/features/feature_case/cub/cub_block.cu
@@ -113,6 +113,42 @@ __global__ void BlockReduceKernel(int* data) {
   data[threadid] = output;
 }
 
+// Exchanges 16 * 8 ints between the threads of one block: each thread loads
+// its items in a striped arrangement, cub::BlockExchange converts them to a
+// blocked arrangement, and the result is stored back blocked. A correct
+// exchange therefore leaves data unchanged.
+// Launch layout: one block of 16 threads; data must hold at least 128 ints.
+__global__ void BlockExchangeStripedToBlocked(int* data) {
+  const int kBlockThreads = 16;
+  const int kItemsPerThread = 8;
+  typedef cub::BlockExchange<int, kBlockThreads, kItemsPerThread> BlockExchange;
+
+  __shared__ typename BlockExchange::TempStorage temp_storage;
+
+  int threadid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y + blockIdx.x * blockDim.x * blockDim.y * blockDim.z;
+  int lane_id = threadid % kBlockThreads;
+
+  // Per-thread items taking part in the exchange.
+  int thread_data[kItemsPerThread];
+
+  // To-Do: Use LoadDirectStriped after header addition
+  // Manually load data in the striped arrangement: item i of a thread is
+  // the element at index lane_id + i * kBlockThreads.
+  for (int i = 0; i < kItemsPerThread; ++i) {
+    thread_data[i] = data[lane_id + i * kBlockThreads];
+  }
+
+  // Collectively exchange data into a blocked arrangement across threads
+  BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data);
+
+  // Store the blocked arrangement so the host can observe the exchange:
+  // item i of a thread goes to index lane_id * kItemsPerThread + i.
+  for (int i = 0; i < kItemsPerThread; ++i) {
+    data[lane_id * kItemsPerThread + i] = thread_data[i];
+  }
+}
+
+
 int main() {
   bool Result = true;
 
@@ -276,6 +312,31 @@ int main() {
     print_data(dev_data, DATA_NUM);
   }
 
+  GridSize = 1;
+  BlockSize = 16;
+  // NOTE(review): the kernel stores back what it loaded when the exchange is
+  // correct, so expect7 should match what init_data writes - confirm values.
+  int expect7[DATA_NUM] = {
+    8128, 496, 495, 493, 490, 486, 481, 475, 468, 460, 451, 441, 430, 418, 405, 391, 376, 360, 343, 325, 306, 286, 265, 243, 220, 196, 171, 145, 118, 90, 61, 31,
+    1520, 1488, 1455, 1421, 1386, 1350, 1313, 1275, 1236, 1196, 1155, 1113, 1070, 1026, 981, 935, 888, 840, 791, 741, 690, 638, 585, 531, 476, 420, 363, 305, 246, 186, 125, 63,
+    2544, 2480, 2415, 2349, 2282, 2214, 2145, 2075, 2004, 1932, 1859, 1785, 1710, 1634, 1557, 1479, 1400, 1320, 1239, 1157, 1074, 990, 905, 819, 732, 644, 555, 465, 374, 282, 189, 95,
+    3568, 3472, 3375, 3277, 3178, 3078, 2977, 2875, 2772, 2668, 2563, 2457, 2350, 2242, 2133, 2023, 1912, 1800, 1687, 1573, 1458, 1342, 1225, 1107, 988, 868, 747, 625, 502, 378, 253, 127
+  };
+  init_data(dev_data, DATA_NUM);
+
+  BlockExchangeStripedToBlocked<<<GridSize, BlockSize>>>(dev_data);
+
+  cudaDeviceSynchronize();
+  if(!verify_data(dev_data, expect7, DATA_NUM, 128)) {
+    std::cout << "BlockExchangeStripedToBlocked" << " verify failed" << std::endl;
+    Result = false;
+    std::cout << "expect:" << std::endl;
+    print_data(expect7, DATA_NUM);
+    std::cout << "current result:" << std::endl;
+    print_data(dev_data, DATA_NUM);
+  }
+
+
   if(Result) {
     std::cout << "passed" << std::endl;
     return 0;