From 6b125b39ebe9184e008906e8a1b6ea03ce60c330 Mon Sep 17 00:00:00 2001 From: Allison Vacanti Date: Wed, 7 Oct 2020 11:10:14 -0400 Subject: [PATCH] WIP Integrate CUB's scan into thrust. --- cub/device/dispatch/dispatch_scan.cuh | 52 +++++++++++++++++++++------ 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/cub/device/dispatch/dispatch_scan.cuh b/cub/device/dispatch/dispatch_scan.cuh index 833de674a2..1b3ff9370f 100644 --- a/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/device/dispatch/dispatch_scan.cuh @@ -140,15 +140,45 @@ template < struct DeviceScanPolicy { + // TODO + // Thrust tests that scan will work with a 256-byte value type. The old + // custom scan implementation in thrust accomplished this by always using + // timesliced loads/stores. Now that thrust is using CUB's scans directly, + // I've added a conditional check that will switch to timeslicing when the + // value type exceeds some threshold. + // + // All SMs currently use 128 bytes as this threshold, but this will probably + // need to be retuned. + // + // More worrying, it looks like we instantiate the scan agent for all of the + // declared policies, regardless of whether or not we're targeting them with + // the current build. This is likely adding very significant compilation + // overhead, and it's requiring me to update all policies to fit into my sm75 + // shared memory. This seems like a bad situation for both perf and compile + // time. + + // For large values, use timesliced loads/stores to fit shared memory. + template + struct LoadStoreAlgo + { + static constexpr bool LargeValues = sizeof(OutputT) > TimesliceLimit; + static constexpr BlockLoadAlgorithm Load = + LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED + : BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr BlockStoreAlgorithm Store = + LargeValues ? 
BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED + : BLOCK_STORE_WARP_TRANSPOSE; + }; + /// SM10 struct Policy100 : ChainedPolicy<100, Policy100, Policy100> { typedef AgentScanPolicy< 64, 9, ///< Threads per block, items per thread OutputT, - BLOCK_LOAD_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Load, LOAD_DEFAULT, - BLOCK_STORE_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_WARP_SCANS> ScanPolicyT; }; @@ -159,9 +189,9 @@ struct DeviceScanPolicy typedef AgentScanPolicy< 96, 21, ///< Threads per block, items per thread OutputT, - BLOCK_LOAD_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Load, LOAD_DEFAULT, - BLOCK_STORE_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicyT; }; @@ -173,9 +203,9 @@ struct DeviceScanPolicy typedef AgentScanPolicy< 128, 12, ///< Threads per block, items per thread OutputT, - BLOCK_LOAD_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Load, LOAD_DEFAULT, - BLOCK_STORE_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_WARP_SCANS> ScanPolicyT; }; @@ -186,9 +216,9 @@ struct DeviceScanPolicy typedef AgentScanPolicy< 256, 9, ///< Threads per block, items per thread OutputT, - BLOCK_LOAD_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Load, LOAD_DEFAULT, - BLOCK_STORE_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_WARP_SCANS> ScanPolicyT; }; @@ -216,7 +246,7 @@ struct DeviceScanPolicy OutputT, BLOCK_LOAD_DIRECT, LOAD_LDG, - BLOCK_STORE_WARP_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_WARP_SCANS> ScanPolicyT; }; @@ -227,9 +257,9 @@ struct DeviceScanPolicy typedef AgentScanPolicy< 128, 15, ///< Threads per block, items per thread OutputT, - BLOCK_LOAD_TRANSPOSE, + LoadStoreAlgo<128>::Load, LOAD_DEFAULT, - BLOCK_STORE_TRANSPOSE, + LoadStoreAlgo<128>::Store, BLOCK_SCAN_WARP_SCANS> ScanPolicyT; };