Skip to content

Commit

Permalink
Workaround for CUDA 12.6
Browse files Browse the repository at this point in the history
It appears that in CUDA 12.6, cub explicitly includes nvtx3. This causes a
conflict with the nvtx header included by AMReX_TinyProfiler.cpp and
AMReX_GpuDevice.cpp. This commit is a workaround for the issue.

I think this is a cub bug. It's incompatible with the default `#include
<nvToolsExt.h>` included by nvcc.
  • Loading branch information
WeiqunZhang committed Aug 6, 2024
1 parent c09da99 commit 16c3c4d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 13 deletions.
4 changes: 4 additions & 0 deletions Src/Base/AMReX_GpuControl.H
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#define AMREX_CUDA_GE_11_2 1
#endif

#if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
#define AMREX_USE_CUB 1
#endif

#if defined(AMREX_USE_HIP) || defined(AMREX_CUDA_GE_11_2)
#define AMREX_GPU_STREAM_ALLOC_SUPPORT 1
#endif
Expand Down
9 changes: 8 additions & 1 deletion Src/Base/AMReX_GpuDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
#if defined(AMREX_USE_CUDA)
#include <cuda_profiler_api.h>
#if defined(AMREX_PROFILING) || defined (AMREX_TINY_PROFILING)
#include <nvToolsExt.h>
# if defined(AMREX_USE_CUB)
// Since CUDA 12.6, cub might include nvtx3. Include cub first so that,
// if NVTX_VERSION is already defined afterwards, we skip nvToolsExt.h
// below and avoid the conflict.
# include <cub/cub.cuh>
# endif
# if !defined(NVTX_VERSION)
# include <nvToolsExt.h>
# endif
#endif
#endif

Expand Down
4 changes: 0 additions & 4 deletions Src/Base/AMReX_GpuReduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
#include <AMReX_Functional.H>
#include <AMReX_TypeTraits.H>

#if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
#define AMREX_USE_CUB 1
#endif

#if defined(AMREX_USE_CUB)
#include <cub/cub.cuh>
#elif defined(AMREX_USE_HIP)
Expand Down
8 changes: 0 additions & 8 deletions Src/Base/AMReX_TinyProfiler.H
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
#include <AMReX_INT.H>
#include <AMReX_REAL.H>

#ifdef AMREX_USE_CUDA
#include <nvToolsExt.h>
#endif

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_ROCTX)
#include <roctracer/roctx.h>
#endif

#include <array>
#include <deque>
#include <iosfwd>
Expand Down
15 changes: 15 additions & 0 deletions Src/Base/AMReX_TinyProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
#include <omp.h>
#endif

#ifdef AMREX_USE_CUDA
# if defined(AMREX_USE_CUB)
// Since CUDA 12.6, cub might include nvtx3. Include cub first so that,
// if NVTX_VERSION is already defined afterwards, we skip nvToolsExt.h
// below and avoid the conflict.
# include <cub/cub.cuh>
# endif
# if !defined(NVTX_VERSION)
# include <nvToolsExt.h>
# endif
#endif

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_ROCTX)
#include <roctracer/roctx.h>
#endif

#include <algorithm>
#include <cmath>
#include <iostream>
Expand Down

0 comments on commit 16c3c4d

Please sign in to comment.