updates for ROCm 6.0 support (#2086) #2088

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions fbgemm_gpu/cmake/Hip.cmake
@@ -103,6 +103,7 @@ set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
# Disable Asserts In Code (Can't use asserts on HIP stack.)
ADD_DEFINITIONS(-DNDEBUG)
ADD_DEFINITIONS(-DUSE_ROCM)
ADD_DEFINITIONS(-D__HIP_PLATFORM_AMD__)

IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH})
SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a)
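Note: ROCm 6.0 drops the legacy `__HIP_PLATFORM_HCC__` macro, so the build now passes `-D__HIP_PLATFORM_AMD__` explicitly and the sources gate on the build-defined `USE_ROCM` instead. A minimal sketch of how a translation unit sees these definitions (illustrative only, not code from this PR; it assumes the `-DUSE_ROCM` and `-D__HIP_PLATFORM_AMD__` flags added above):

```cuda
// Illustrative only: compiled with the definitions from Hip.cmake above.
#include <cstdio>

#if defined(USE_ROCM)
// __HIP_PLATFORM_AMD__ tells the HIP headers to select the AMD back end.
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

int main() {
#if defined(USE_ROCM)
  std::printf("ROCm build: wavefront size 64\n");
#else
  std::printf("CUDA build: warp size 32\n");
#endif
  return 0;
}
```

Gating on `USE_ROCM` keeps the sources independent of which platform macro a particular ROCm release happens to provide.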
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -166,7 +166,7 @@ class SplitLookupFunction_Dense_Op

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
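Note: the reason for the `#ifdef` change is that ROCm 6.0 no longer injects `__HIP_PLATFORM_HCC__`, so the old check would silently fall through to the CUDA-sized constants. A small guard of the kind one could add to catch that (purely illustrative, not part of this PR):

```cuda
// Illustrative guard: fail loudly if hipcc is compiling (__HIP_PLATFORM_AMD__ set)
// but USE_ROCM was not defined by the build, which would otherwise select the
// CUDA-sized block constants.
#if defined(__HIP_PLATFORM_AMD__) && !defined(USE_ROCM)
#error "ROCm compile detected but USE_ROCM is not defined; check fbgemm_gpu/cmake/Hip.cmake"
#endif
```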
@@ -398,7 +398,7 @@ class {{ autograd_func }} :

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
8 changes: 4 additions & 4 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -459,7 +459,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e

// V100: 96 KB; A100: 160 KB; H100: 228 KB.
int max_shared_bytes = 0;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device());
#else
// MI100 has 64 KB local memory (shared memory) per workgroup
@@ -468,7 +468,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
C10_CUDA_KERNEL_LAUNCH_CHECK();
int shared_kb = max_shared_bytes >> 10;
// V100: 64 KB; A100: 96 KB; H100: 144 KB
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK_GT(used_shared_kb, 0);
@@ -740,7 +740,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
kMaxVecsPerThread,
kThreadGroupSize>;

#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaFuncSetAttribute(
backward_cta_per_row_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -851,7 +851,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
if (std::is_same<emb_t, uint8_t>::value) {
shmem_bytes = BT_block_size * sizeof(
at::acc_type<cache_t, true>) * 4 * kWarpSize * kMaxVecsPerThread;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
cudaFuncSetAttribute(
backward_warp_per_row_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
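Note: on CUDA these kernels opt into more than the default 48 KB of dynamic shared memory by querying `cudaDevAttrMaxSharedMemoryPerBlockOptin` and raising `cudaFuncAttributeMaxDynamicSharedMemorySize`; on ROCm the code assumes the fixed 64 KB of LDS per workgroup and skips the attribute call. A self-contained sketch of that pattern (`my_kernel` is a placeholder, not a kernel from this file; like the FBGEMM sources, the ROCm path relies on hipify mapping the cuda* calls):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void my_kernel() {
  extern __shared__ char smem[];  // dynamic shared memory
  smem[threadIdx.x] = 0;
}

int main() {
  int max_shared_bytes = 0;
#ifndef USE_ROCM
  // Opt-in limit per block (V100: 96 KB, A100: 160 KB, H100: 228 KB, per the
  // comment in the diff); without raising the attribute, launches above 48 KB fail.
  cudaDeviceGetAttribute(
      &max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device=*/0);
  cudaFuncSetAttribute(
      my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_bytes);
#else
  // ROCm path in the diff assumes 64 KB of LDS per workgroup (e.g. MI100).
  max_shared_bytes = 64 << 10;
#endif
  my_kernel<<<1, 64, max_shared_bytes>>>();
  std::printf("launch with %d bytes of dynamic smem: %s\n", max_shared_bytes,
              cudaGetErrorString(cudaGetLastError()));
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```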
@@ -53,7 +53,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru

const uint32_t subwarp_id = threadIdx.x / 4;
const uint32_t subwarp_tid = threadIdx.x % 4;
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
#else
const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
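Note: the mask widens to 64 bits here because an AMD wavefront has 64 lanes, so sixteen 4-thread subwarps fit in one wavefront; on CUDA a 32-thread warp holds eight subwarps and a 32-bit mask suffices. A reduced sketch of the mask construction (illustrative, standalone rather than the kernel in this file):

```cuda
#include <cstdint>

// Per-subwarp mask for 4-thread subwarps; each subwarp owns 4 contiguous bits.
#ifdef USE_ROCM
// 64-lane wavefront -> subwarp_id in [0, 15], so the mask must be 64 bits wide.
__device__ inline uint64_t subwarp_mask(uint32_t lane) {
  const uint32_t subwarp_id = lane / 4;
  return static_cast<uint64_t>(0xF) << (4 * subwarp_id);
}
#else
// 32-lane warp -> subwarp_id in [0, 7], so a 32-bit mask covers the warp.
__device__ inline uint32_t subwarp_mask(uint32_t lane) {
  const uint32_t subwarp_id = lane / 4;
  return static_cast<uint32_t>(0xF) << (4 * subwarp_id);
}
#endif
```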
4 changes: 2 additions & 2 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -78,7 +78,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
{%- endif %}

{% if not dense %}
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Support only the split-pooled TBE case
template <
typename emb_t,
@@ -647,7 +647,7 @@ batch_index_select_dim0_codegen_forward_cuda(
// if (!is_experimental)
} else {

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CHECK(false, "is_experimental=True is not supported in ROCm");
#else
// Allocate num warps per table based on max_D
49 changes: 23 additions & 26 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -13,7 +13,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>

// clang-format off
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPCUB_ARCH 1
#include <hipcub/backend/rocprim/block/block_scan.hpp>
#else
@@ -35,8 +35,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 9000
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
#define FBGEMM_USE_SUBWARP_SHUFFLE
#endif

@@ -58,14 +57,14 @@ namespace fbgemm_gpu {

enum class PrimitiveType : uint8_t { FP = 0, INT = 1, BF = 2 };

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
namespace cub = hipcub;
#endif

#define DEVICE_INLINE __device__ inline __attribute__((always_inline))

// Warp size
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
static constexpr int32_t kWarpSize = 64;
#else
static constexpr int32_t kWarpSize = 32;
@@ -93,7 +92,7 @@ struct Half4 {
half2 b;

__device__ inline void store(at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
p[0] = __low2half(a);
p[1] = __high2half(a);
p[2] = __low2half(b);
@@ -157,7 +156,7 @@ struct Vec4T<float> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -311,7 +310,7 @@ struct Vec4T<at::Half> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -409,7 +408,7 @@ struct Vec4T<at::Half> {
}

DEVICE_INLINE static void copy(const at::Half* src, at::Half* dst) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
dst[0] = src[0];
dst[1] = src[1];
dst[2] = src[2];
@@ -525,7 +524,7 @@ struct Vec4T<at::BFloat16> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -705,7 +704,7 @@ struct Vec4T<double> {
}

DEVICE_INLINE void load(const at::Half* p) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
union U {
half2 h[2];
uint2 ui;
@@ -854,7 +853,7 @@ DEVICE_INLINE T shfl_xor(
int laneMask,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width);
@@ -867,7 +866,7 @@ DEVICE_INLINE T shfl_sync(
int srcLane = 0,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl(val, srcLane, width);
#else
return __shfl_sync(shfl_sync_mask, val, srcLane, width);
@@ -880,21 +879,21 @@ DEVICE_INLINE T shfl_down_sync(
unsigned delta,
int width = kWarpSize,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __shfl_down(val, delta, width);
#else
return __shfl_down_sync(shfl_sync_mask, val, delta, width);
#endif
}

#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
DEVICE_INLINE uint64_t ballot_sync(
#else
DEVICE_INLINE uint32_t ballot_sync(
#endif
int predicate,
unsigned shfl_sync_mask = kFullWarpMask) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
#if defined(USE_ROCM) || CUDA_VERSION < 9000
return __ballot(predicate);
#else
return __ballot_sync(shfl_sync_mask, predicate);
@@ -913,7 +912,7 @@ warpReduceAllSum(T val, unsigned shfl_sync_mask = kFullWarpMask) {
}

DEVICE_INLINE void syncwarp() {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
// Performance - replace a block level __syncthreads with per CU
// __threadfence_block. It is a fine replacement for __syncwarp on AMD GPUs,
// it is because a. memory fencing: __threadfence_block ops. at CU level,
@@ -1002,7 +1001,7 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSort {
static inline __device__ void sort(K k[1], V v[1]) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size");
#else
static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size");
@@ -1607,7 +1606,7 @@ struct __align__(32) half16 {
half2 vals[8];
};

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
using __nv_bfloat16 = hip_bfloat16;

typedef struct __align__(4) {
@@ -1689,7 +1688,7 @@ DEVICE_INLINE half16 to_half16(float_16 v) {

// Override __bfloat162float to accept at::BFloat16
static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
return float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#else
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&input));
@@ -1709,7 +1708,7 @@ static DEVICE_INLINE float to_float(const at::BFloat16 input) {
return __bfloat162float(input);
}

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) {
@@ -1829,8 +1828,7 @@ DEVICE_INLINE float_16 make_zero_float_16() {

__forceinline__ __device__ __half2
hfma2(const __half2 a, const __half2 b, const __half2 c) {
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
defined(__HIP_PLATFORM_HCC__)
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
return __hfma2(a, b, c);
#else
float2 fa, fb, fc;
@@ -1844,8 +1842,7 @@ hfma2(const __half2 a, const __half2 b, const __half2 c) {
}

__forceinline__ __device__ half hmul(half a, half b) {
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
defined(__HIP_PLATFORM_HCC__)
#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
return __hmul(a, b);
#else
return __float2half(__half2float(a) * __half2float(b));
@@ -3603,7 +3600,7 @@ DEVICE_INLINE float float16_min(float_16 val) {
// ROCm does not natively support __any_sync(). Using __ballot()
// (https://rocmdocs.amd.com/en/latest/Programming_Guides/Kernel_language.html)
// to implement __any_sync(). Note: the "warp-size" of AMD GPU is 64.
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
__device__ int __any_sync(uint64_t mask, int predicate) {
uint64_t predicate_bit_pattern = __ballot(predicate);
return (predicate_bit_pattern & mask) > 0;
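Note: most of the edits in this header land in wrappers that already hide the CUDA/ROCm differences: the masked `*_sync` shuffle/ballot intrinsics exist only on CUDA 9+ (the header also keeps a legacy fallback for older CUDA), ROCm uses the unmasked `__shfl*`/`__ballot` forms with a 64-wide wavefront, and `__any_sync` has to be emulated with `__ballot`. A reduced, self-contained sketch of that wrapper pattern (constants repeated here so it compiles on its own; not a drop-in replacement for the header):

```cuda
#include <cstdint>

#ifdef USE_ROCM
static constexpr int kWarpSize = 64;   // AMD wavefront
#else
static constexpr int kWarpSize = 32;   // NVIDIA warp
#endif
static constexpr uint32_t kFullWarpMask = 0xffffffffu;

// Butterfly shuffle: masked *_sync form on CUDA, unmasked legacy form on ROCm.
template <typename T>
__device__ inline T shfl_xor(
    T val, int lane_mask, int width = kWarpSize, uint32_t mask = kFullWarpMask) {
#if defined(USE_ROCM)
  return __shfl_xor(val, lane_mask, width);            // no *_sync variants on ROCm
#else
  return __shfl_xor_sync(mask, val, lane_mask, width); // CUDA 9+ masked form
#endif
}

// Warp-wide all-reduce sum built on the wrapper above.
template <typename T>
__device__ inline T warp_reduce_all_sum(T val, uint32_t mask = kFullWarpMask) {
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    val += shfl_xor(val, offset, kWarpSize, mask);
  }
  return val;
}

#ifdef USE_ROCM
// ROCm has no __any_sync(); emulate it with a 64-bit ballot over the wavefront.
__device__ inline int any_sync_emulated(uint64_t mask, int predicate) {
  return (__ballot(predicate) & mask) != 0;
}
#endif
```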
2 changes: 1 addition & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
@@ -8,7 +8,7 @@

#pragma once

#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPCUB_ARCH 1
#endif

10 changes: 5 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
@@ -661,7 +661,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
matches &= (y_0_reshaped.size(1) < INT_MAX);

int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -671,7 +671,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
@@ -779,7 +779,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -789,7 +789,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
@@ -798,7 +798,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
int used_shared_kb = shared_kb;
#endif
int used_shared_bytes = used_shared_kb << 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
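Note: this is the same shared-memory heuristic as in embedding_backward_split_template.cu above: take the per-block opt-in limit, keep two thirds of it rounded down to a 16 KB multiple for the kernel, and leave the rest to L1, while the ROCm path uses the fixed 64 KB as-is. A small host-only sketch of the arithmetic, using the per-GPU figures from the comments in that file; it reproduces the "V100: 64 KB; A100: 96 KB; H100: 144 KB" numbers quoted there:

```cuda
#include <cstdio>

// round_down as used in the diff: largest multiple of `multiple` that is <= x.
static int round_down(int x, int multiple) { return x / multiple * multiple; }

int main() {
  // Opt-in shared memory per block, in KB, per the comments in the diff.
  const int shared_kb[] = {96 /*V100*/, 160 /*A100*/, 228 /*H100*/};
  for (const int kb : shared_kb) {
    const int used_kb = round_down(kb * 2 / 3, 16);
    std::printf("opt-in %3d KB -> kernel uses %3d KB\n", kb, used_kb);
  }
  std::printf("ROCm: kernel uses the fixed 64 KB of LDS\n");
  return 0;
}
```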
@@ -97,7 +97,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -107,7 +107,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
max_shared_bytes = 64 << 10;
#endif
int shared_kb = max_shared_bytes >> 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK_GT(used_shared_kb, 0);
@@ -116,7 +116,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
int used_shared_kb = shared_kb;
#endif
int used_shared_bytes = used_shared_kb << 10;
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/quantize_ops/common.cuh
@@ -9,7 +9,7 @@
#include <ATen/TensorIterator.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
#include <math_constants.h>
#endif

2 changes: 1 addition & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -64,7 +64,7 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
const int output_columns = ncols_aligned + 2 * sizeof(float);

// starting values for future reductions
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define HIPRT_INF_F __int_as_float(0x7f800000)
float minimum_element = HIPRT_INF_F;
float maximum_element = -HIPRT_INF_F;
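Note: `math_constants.h` (the header that provides `CUDART_INF_F`) is CUDA-only, which is why quantize_ops/common.cuh above stops including it under ROCm and this kernel builds `HIPRT_INF_F` from the IEEE-754 bit pattern instead. A reduced sketch of seeding a min/max reduction that compiles on either platform (illustrative; the block-level reduction is omitted):

```cuda
#ifdef USE_ROCM
// No math_constants.h under ROCm; construct +inf from its IEEE-754 bit pattern.
#define INF_F __int_as_float(0x7f800000)
#else
#include <math_constants.h>
#define INF_F CUDART_INF_F
#endif

// Each thread keeps a running min/max seeded with +/-inf, as in the kernel above.
__global__ void minmax_partial_kernel(
    const float* in, int n, float* out_min, float* out_max) {
  float minimum_element = INF_F;
  float maximum_element = -INF_F;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    minimum_element = fminf(minimum_element, in[i]);
    maximum_element = fmaxf(maximum_element, in[i]);
  }
  // A real kernel would reduce across the block next; here only thread 0's
  // partial result is written, to keep the sketch short.
  if (threadIdx.x == 0) {
    *out_min = minimum_element;
    *out_max = maximum_element;
  }
}
```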