Skip to content

Commit

Permalink
Enhance AMD support
Browse files Browse the repository at this point in the history
Summary: Support AMD GPU build.

Differential Revision: D56686760
  • Loading branch information
jianyuh authored and facebook-github-bot committed Apr 29, 2024
1 parent 561bd4e commit 37133ca
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 40 deletions.
20 changes: 6 additions & 14 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,9 @@
os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_gen_ai_py.so")
)
else:
if torch.version.hip:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:attention_ops_hip"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops_hip"
)
else:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:attention_ops_cuda"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops_cuda"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:attention_ops"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops"
)
42 changes: 42 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,48 @@ void set_gpu_max_dynamic_shared_memory(
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

#ifdef __HIP_PLATFORM_AMD__
// Compatibility shims that map the CUDA __nv_bfloat16 / __nv_bfloat162 types
// and conversion intrinsics onto HIP's hip_bfloat16 so the kernels below
// compile unchanged on AMD GPUs.
// NOTE(review): an identical shim block exists in quantize.cu — both should
// move to a shared header (see the TODO following this block).
using __nv_bfloat16 = hip_bfloat16;

// Raw bit-pattern view of a packed bfloat16 pair: two 16-bit halves,
// 4-byte aligned (mirrors CUDA's __nv_bfloat162_raw layout).
typedef struct __align__(4) {
  uint16_t x;
  uint16_t y;
}
__nv_bfloat162_raw;

// Two packed bfloat16 values (mirrors CUDA's __nv_bfloat162).
struct __align__(4) __nv_bfloat162 {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
};

// The descriptions of __float2bfloat16 and __float2bfloat16_rn are identical:
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
// hip_bfloat16::round_to_bfloat16 is a static member function (see
// https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html),
// so call it through the type instead of through an uninitialized dummy
// object as before.
static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) {
  return __nv_bfloat16::round_to_bfloat16(f);
}

static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) {
  return __nv_bfloat16::round_to_bfloat16(f);
}

// Widen a bfloat16 to float via hip_bfloat16's float conversion operator:
// https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html
static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) {
  return float(f);
}

// Round two floats (round-to-nearest) and pack them into a __nv_bfloat162
// in (x, y) order, matching CUDA's __floats2bfloat162_rn.
static __host__ __device__ __nv_bfloat162
__floats2bfloat162_rn(float x, float y) {
  __nv_bfloat162 output;
  output.x = __float2bfloat16_rn(x);
  output.y = __float2bfloat16_rn(y);
  return output;
}

#endif

// TODO: Include the following code from fbgemm_gpu header
struct __align__(16) bfx8 {
__nv_bfloat162 vals[4];
Expand Down
26 changes: 0 additions & 26 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,31 +90,6 @@ constexpr int32_t MAX_T = 16384;
constexpr int SMEM_ADJUST_THRESHOLD = 48 * 1024;

#ifdef __HIP_PLATFORM_AMD__
using __nv_bfloat16 = hip_bfloat16;

typedef struct __align__(4) {
uint16_t x;
uint16_t y;
}
__nv_bfloat162_raw;

struct __align__(4) __nv_bfloat162 {
__nv_bfloat16 x;
__nv_bfloat16 y;
};

// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) {
__nv_bfloat16 output;
return output.round_to_bfloat16(f);
}

static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) {
__nv_bfloat16 output;
return output.round_to_bfloat16(f);
}

static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) {
// float output;
// https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html
Expand All @@ -128,7 +103,6 @@ __floats2bfloat162_rn(float x, float y) {
output.y = __float2bfloat16_rn(y);
return output;
}

#endif

struct __align__(16) bf16x8 {
Expand Down

0 comments on commit 37133ca

Please sign in to comment.