updates for ROCm 6.0 support (#2088)
Summary:

ROCm 6.0 introduces backwards-incompatible changes, such as removing the long-deprecated `__HIP_PLATFORM_HCC__` symbol. It is better to guard ROCm-specific code with the `USE_ROCM` macro, which is already defined and indicates a ROCm build. This PR also defines `__HIP_PLATFORM_AMD__`, the replacement symbol; it is still required when compiling against the HIP headers with a compiler other than hip-clang.
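
For illustration, the guard pattern rewritten throughout this PR looks like the following. This is a minimal sketch based on the `kWarpSize` definition in `fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh`; the retired form is shown in comments.

```cpp
#include <cstdint>

// Before this PR, ROCm-specific paths were gated on the HIP platform macro
// that ROCm 6.0 removes:
//
//   #ifdef __HIP_PLATFORM_HCC__
//   static constexpr int32_t kWarpSize = 64;  // AMD wavefront size
//   #else
//   static constexpr int32_t kWarpSize = 32;  // NVIDIA warp size
//   #endif

// After this PR, the same paths are gated on the build-level macro that the
// FBGEMM build already defines for every ROCm build (see cmake/Hip.cmake):
#ifdef USE_ROCM
static constexpr int32_t kWarpSize = 64; // AMD wavefront size
#else
static constexpr int32_t kWarpSize = 32; // NVIDIA warp size
#endif
```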


Reviewed By: sryap

Differential Revision: D50580075

Pulled By: sryap
jeffdaily authored and facebook-github-bot committed Oct 24, 2023
1 parent f94254d commit 94e5034
Showing 21 changed files with 57 additions and 59 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/cmake/Hip.cmake
@@ -103,6 +103,7 @@ set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
# Disable Asserts In Code (Can't use asserts on HIP stack.)
ADD_DEFINITIONS(-DNDEBUG)
ADD_DEFINITIONS(-DUSE_ROCM)
ADD_DEFINITIONS(-D__HIP_PLATFORM_AMD__)

IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH})
SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a)
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -166,7 +166,7 @@ class SplitLookupFunction_Dense_Op

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
@@ -398,7 +398,7 @@ class {{ autograd_func }} :

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
8 changes: 4 additions & 4 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -459,7 +459,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e

// V100: 96 KB; A100: 160 KB; H100: 228 KB.
int max_shared_bytes = 0;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device());
#else
// MI100 has 64 KB local memory (shared memory) per workgroup
@@ -468,7 +468,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
C10_CUDA_KERNEL_LAUNCH_CHECK();
int shared_kb = max_shared_bytes >> 10;
// V100: 64 KB; A100: 96 KB; H100: 144 KB
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK_GT(used_shared_kb, 0);
@@ -740,7 +740,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
kMaxVecsPerThread,
kThreadGroupSize>;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaFuncSetAttribute(
backward_cta_per_row_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -851,7 +851,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
if (std::is_same<emb_t, uint8_t>::value) {
shmem_bytes = BT_block_size * sizeof(
at::acc_type<cache_t, true>) * 4 * kWarpSize * kMaxVecsPerThread;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaFuncSetAttribute(
backward_warp_per_row_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -53,7 +53,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru

const uint32_t subwarp_id = threadIdx.x / 4;
const uint32_t subwarp_tid = threadIdx.x % 4;
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
#else
const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
4 changes: 2 additions & 2 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -78,7 +78,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
{%- endif %}

{% if not dense %}
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Support only the split-pooled TBE case
template <
typename emb_t,
@@ -647,7 +647,7 @@ batch_index_select_dim0_codegen_forward_cuda(
// if (!is_experimental)
} else {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CHECK(false, "is_experimental=True is not supported in ROCm");
#else
// Allocate num warps per table based on max_D
49 changes: 23 additions & 26 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -13,7 +13,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>

// clang-format off
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPCUB_ARCH 1
#include <hipcub/backend/rocprim/block/block_scan.hpp>
#else
@@ -35,8 +35,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 9000
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
#define FBGEMM_USE_SUBWARP_SHUFFLE
#endif

@@ -58,14 +57,14 @@ namespace fbgemm_gpu {

enum class PrimitiveType : uint8_t { FP = 0, INT = 1, BF = 2 };

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
namespace cub = hipcub;
#endif

#define DEVICE_INLINE __device__ inline __attribute__((always_inline))

// Warp size
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
static constexpr int32_t kWarpSize = 64;
#else
static constexpr int32_t kWarpSize = 32;
@@ -93,7 +92,7 @@ struct Half4 {
half2 b;

__device__ inline void store(at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
p[0] = __low2half(a);
p[1] = __high2half(a);
p[2] = __low2half(b);
@@ -157,7 +156,7 @@ struct Vec4T<float> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -311,7 +310,7 @@ struct Vec4T<at::Half> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -409,7 +408,7 @@ struct Vec4T<at::Half> {
}

DEVICE_INLINE static void copy(const at::Half* src, at::Half* dst) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
dst[0] = src[0];
dst[1] = src[1];
dst[2] = src[2];
@@ -525,7 +524,7 @@ struct Vec4T<at::BFloat16> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -705,7 +704,7 @@ struct Vec4T<double> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -854,7 +853,7 @@ DEVICE_INLINE T shfl_xor(
int laneMask,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width);
@@ -867,7 +866,7 @@ DEVICE_INLINE T shfl_sync(
int srcLane = 0,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl(val, srcLane, width);
#else
return __shfl_sync(shfl_sync_mask, val, srcLane, width);
@@ -880,21 +879,21 @@ DEVICE_INLINE T shfl_down_sync(
unsigned delta,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl_down(val, delta, width);
#else
return __shfl_down_sync(shfl_sync_mask, val, delta, width);
#endif
}

#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
DEVICE_INLINE uint64_t ballot_sync(
#else
DEVICE_INLINE uint32_t ballot_sync(
#endif
int predicate,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __ballot(predicate);
#else
return __ballot_sync(shfl_sync_mask, predicate);
@@ -913,7 +912,7 @@ warpReduceAllSum(T val, unsigned shfl_sync_mask = kFullWarpMask) {
}

DEVICE_INLINE void syncwarp() {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
// Performance - replace a block level __syncthreads with per CU
// __threadfence_block. It is a fine replacement for __syncwarp on AMD GPUs,
// it is because a. memory fencing: __threadfence_block ops. at CU level,
@@ -1002,7 +1001,7 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSort {
static inline __device__ void sort(K k[1], V v[1]) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size");
#else
static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size");
@@ -1607,7 +1606,7 @@ struct __align__(32) half16 {
half2 vals[8];
};

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
using __nv_bfloat16 = hip_bfloat16;

typedef struct __align__(4) {
@@ -1689,7 +1688,7 @@ DEVICE_INLINE half16 to_half16(float_16 v) {

// Override __bfloat162float to accept at::BFloat16
static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
return float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#else
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&input));
@@ -1709,7 +1708,7 @@ static DEVICE_INLINE float to_float(const at::BFloat16 input) {
return __bfloat162float(input);
}

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) {
@@ -1829,8 +1828,7 @@ DEVICE_INLINE float_16 make_zero_float_16() {

__forceinline__ __device__ __half2
hfma2(const __half2 a, const __half2 b, const __half2 c) {
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
defined(__HIP_PLATFORM_HCC__)
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
return __hfma2(a, b, c);
#else
float2 fa, fb, fc;
@@ -1844,8 +1842,7 @@ hfma2(const __half2 a, const __half2 b, const __half2 c) {
}

__forceinline__ __device__ half hmul(half a, half b) {
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
defined(__HIP_PLATFORM_HCC__)
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
return __hmul(a, b);
#else
return __float2half(__half2float(a) * __half2float(b));
@@ -3603,7 +3600,7 @@ DEVICE_INLINE float float16_min(float_16 val) {
// ROCm does not natively support __any_sync(). Using __ballot()
// (https://rocmdocs.amd.com/en/latest/Programming_Guides/Kernel_language.html)
// to implement __any_sync(). Note: the "warp-size" of AMD GPU is 64.
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
__device__ int __any_sync(uint64_t mask, int predicate) {
uint64_t predicate_bit_pattern = __ballot(predicate);
return (predicate_bit_pattern & mask) > 0;
2 changes: 1 addition & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
@@ -8,7 +8,7 @@

#pragma once

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPCUB_ARCH 1
#endif

10 changes: 5 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
@@ -661,7 +661,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
matches &= (y_0_reshaped.size(1) < INT_MAX);

int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -671,7 +671,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
@@ -779,7 +779,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -789,7 +789,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
@@ -798,7 +798,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
int used_shared_kb = shared_kb;
#endif
int used_shared_bytes = used_shared_kb << 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
@@ -97,7 +97,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -107,7 +107,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK_GT(used_shared_kb, 0);
@@ -116,7 +116,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
int used_shared_kb = shared_kb;
#endif
int used_shared_bytes = used_shared_kb << 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/quantize_ops/common.cuh
@@ -9,7 +9,7 @@
#include <ATen/TensorIterator.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
#include <math_constants.h>
#endif

2 changes: 1 addition & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -64,7 +64,7 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
const int output_columns = ncols_aligned + 2 * sizeof(float);

// starting values for future reductions
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPRT_INF_F __int_as_float(0x7f800000)
float minimum_element = HIPRT_INF_F;
float maximum_element = -HIPRT_INF_F;