Skip to content

Commit

Permalink
Move all HIP stuff to ggml-cuda.cu
Browse files: browse the repository at this point in the history
  • Loading branch information
SlyEcho committed May 4, 2023
1 parent d83cfba commit 04c0d48
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 54 deletions.
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -232,16 +232,16 @@ if (LLAMA_HIPBLAS)
 find_package(hipblas)

 if (${hipblas_FOUND} AND ${hip_FOUND})
-message(STATUS "hipBLAS found")
-add_compile_definitions(GGML_USE_HIPBLAS)
-add_library(ggml-hip OBJECT ggml-cuda.cu ggml-cuda.h)
+message(STATUS "HIP and hipBLAS found")
+add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
 set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
-target_link_libraries(ggml-hip PRIVATE hip::device)
+target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::hipblas)

 if (LLAMA_STATIC)
 message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
 endif()
-set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::host roc::hipblas ggml-hip)
+set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
 else()
 message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
 endif()
Expand Down
44 changes: 41 additions & 3 deletions ggml-cuda.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,9 +5,47 @@
 #include <atomic>

 #if defined(GGML_USE_HIPBLAS)
-#include "hip/hip_runtime.h"
-#include "hipblas/hipblas.h"
-#include "hip/hip_fp16.h"
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocPortable)
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent hipStreamWaitEvent
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
 #else
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
Expand Down
46 changes: 0 additions & 46 deletions ggml-cuda.h
Original file line number | Diff line number | Diff line change
@@ -1,49 +1,3 @@
-#if defined(GGML_USE_HIPBLAS)
-#include "hipblas/hipblas.h"
-#include "hip/hip_runtime.h"
-#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
-#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
-#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
-#define CUBLAS_OP_N HIPBLAS_OP_N
-#define CUBLAS_OP_T HIPBLAS_OP_T
-#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
-#define CUBLAS_TF32_TENSOR_OP_MATH 0
-#define CUDA_R_16F HIPBLAS_R_16F
-#define CUDA_R_32F HIPBLAS_R_32F
-#define cublasCreate hipblasCreate
-#define cublasGemmEx hipblasGemmEx
-#define cublasHandle_t hipblasHandle_t
-#define cublasSetMathMode(h, m) HIPBLAS_STATUS_SUCCESS
-#define cublasSetStream hipblasSetStream
-#define cublasSgemm hipblasSgemm
-#define cublasStatus_t hipblasStatus_t
-#define cudaDeviceSynchronize hipDeviceSynchronize
-#define cudaError_t hipError_t
-#define cudaEventCreateWithFlags hipEventCreateWithFlags
-#define cudaEventDisableTiming hipEventDisableTiming
-#define cudaEventRecord hipEventRecord
-#define cudaEvent_t hipEvent_t
-#define cudaFree hipFree
-#define cudaFreeHost hipFreeHost
-#define cudaGetErrorString hipGetErrorString
-#define cudaGetLastError hipGetLastError
-#define cudaMalloc hipMalloc
-#define cudaMallocHost hipMallocHost
-#define cudaMemcpy2DAsync hipMemcpy2DAsync
-#define cudaMemcpyAsync hipMemcpyAsync
-#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
-#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
-#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
-#define cudaStreamNonBlocking hipStreamNonBlocking
-#define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent hipStreamWaitEvent
-#define cudaStream_t hipStream_t
-#define cudaSuccess hipSuccess
-#define GGML_USE_CUBLAS
-#else
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
-#endif
 #include "ggml.h"

 #ifdef __cplusplus
Expand Down

0 comments on commit 04c0d48

Please sign in to comment.