diff --git a/Src/Base/AMReX_GpuControl.H b/Src/Base/AMReX_GpuControl.H index 4cc8abffdc..bb23942eb8 100644 --- a/Src/Base/AMReX_GpuControl.H +++ b/Src/Base/AMReX_GpuControl.H @@ -11,6 +11,10 @@ #define AMREX_CUDA_GE_11_2 1 #endif +#if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11) +#define AMREX_USE_CUB 1 +#endif + #if defined(AMREX_USE_HIP) || defined(AMREX_CUDA_GE_11_2) #define AMREX_GPU_STREAM_ALLOC_SUPPORT 1 #endif diff --git a/Src/Base/AMReX_GpuDevice.cpp b/Src/Base/AMReX_GpuDevice.cpp index 5e63646121..b001aca33a 100644 --- a/Src/Base/AMReX_GpuDevice.cpp +++ b/Src/Base/AMReX_GpuDevice.cpp @@ -19,7 +19,14 @@ #if defined(AMREX_USE_CUDA) #include #if defined(AMREX_PROFILING) || defined (AMREX_TINY_PROFILING) -#include +# if defined(AMREX_USE_CUB) + // Since 12.6 cub might include nvtx3. We include cub here so that we + // can avoid conflict. +# include +# endif +# if !defined(NVTX_VERSION) +# include +# endif #endif #endif diff --git a/Src/Base/AMReX_GpuReduce.H b/Src/Base/AMReX_GpuReduce.H index b10d22bbcf..7c7cf63bf9 100644 --- a/Src/Base/AMReX_GpuReduce.H +++ b/Src/Base/AMReX_GpuReduce.H @@ -10,10 +10,6 @@ #include #include -#if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11) -#define AMREX_USE_CUB 1 -#endif - #if defined(AMREX_USE_CUB) #include #elif defined(AMREX_USE_HIP) diff --git a/Src/Base/AMReX_TinyProfiler.H b/Src/Base/AMReX_TinyProfiler.H index 4ffc5bdef2..f995eb4a2c 100644 --- a/Src/Base/AMReX_TinyProfiler.H +++ b/Src/Base/AMReX_TinyProfiler.H @@ -5,14 +5,6 @@ #include #include -#ifdef AMREX_USE_CUDA -#include -#endif - -#if defined(AMREX_USE_HIP) && defined(AMREX_USE_ROCTX) -#include -#endif - #include #include #include diff --git a/Src/Base/AMReX_TinyProfiler.cpp b/Src/Base/AMReX_TinyProfiler.cpp index 7e84b457e7..c3c577936d 100644 --- a/Src/Base/AMReX_TinyProfiler.cpp +++ b/Src/Base/AMReX_TinyProfiler.cpp @@ -15,6 +15,21 @@ #include #endif +#ifdef AMREX_USE_CUDA +# if defined(AMREX_USE_CUB) + // Since 12.6 cub might include nvtx3. We include cub here so that we + // can avoid conflict. +# include +# endif +# if !defined(NVTX_VERSION) +# include +# endif +#endif + +#if defined(AMREX_USE_HIP) && defined(AMREX_USE_ROCTX) +#include +#endif + #include #include #include