From 2d550231438b3ada8b65c2aa618387a3ead307a3 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Mon, 26 Jun 2023 10:27:38 +0200 Subject: [PATCH 1/2] dequantize + matrix multiplication CUDA kernels --- CMakeLists.txt | 4 + Makefile | 3 + ggml-cuda.cu | 541 ++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 519 insertions(+), 29 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ffda74a700bef..b4b45dcf63e23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,7 @@ option(LLAMA_CUBLAS "llama: use cuBLAS" set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF) +option(LLAMA_CUDA_DMM "llama: use dequantize mul mat CUDA kernels" OFF) set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) @@ -251,6 +252,9 @@ if (LLAMA_CUBLAS) if (LLAMA_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_DMMV_F16) endif() + if (LLAMA_CUDA_DMM) + add_compile_definitions(GGML_CUDA_DMM) + endif() add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) if (LLAMA_STATIC) diff --git a/Makefile b/Makefile index 03f38bdba04ec..60221892f6580 100644 --- a/Makefile +++ b/Makefile @@ -179,6 +179,9 @@ endif # LLAMA_CUDA_DMMV_Y ifdef LLAMA_CUDA_DMMV_F16 NVCCFLAGS += -DGGML_CUDA_DMMV_F16 endif # LLAMA_CUDA_DMMV_F16 +ifdef LLAMA_CUDA_DMM + NVCCFLAGS += -DGGML_CUDA_DMM +endif # LLAMA_CUDA_DMM ifdef LLAMA_CUDA_KQUANTS_ITER NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) else diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4e0d3dbdea4d4..8ba0830d0e5a5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -58,7 +58,8 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_DMMV_F16 -typedef void 
(*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_2_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef float (*dequantize_1_kernel_t)(const void * vx, const int i); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -280,7 +281,257 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ float dequantize_1_f32(const void * vx, const int i){ + const float * x = (const float *) vx; + + return x[i]; +} + +static __device__ __forceinline__ float dequantize_1_f16(const void * vx, const int i){ + const half * x = (const half *) vx; + + return __half2float(x[i]); +} + +static __device__ __forceinline__ float dequantize_1_q4_0(const void * vx, const int i){ + const block_q4_0 * x = (const block_q4_0 *) vx; + const int ib = i / QK4_0; + + const float d = x[ib].d; + + const int iqs0 = i % QK4_0; + const int shift = iqs0 / (QK4_0/QR4_0); + const int iqs = iqs0 - shift * (QK4_0/QR4_0); + + int vi = x[ib].qs[iqs]; + + vi >>= 4 * shift; + vi &= 0xF; + + return (vi - 8) * d; +} + +static __device__ __forceinline__ float dequantize_1_q4_1(const void * vx, const int i){ + const block_q4_1 * x = (const block_q4_1 *) vx; + const int ib = i / QK4_1; + + const float d = x[ib].d; + const float m = x[ib].m; + + const int iqs0 = i % QK4_1; + const int shift = iqs0 / (QK4_1/QR4_1); + const int iqs = iqs0 - shift * (QK4_1/QR4_1); + + int vi = x[ib].qs[iqs]; + + vi >>= 4 * shift; + vi &= 0xF; + + return vi * d + m; +} + +static __device__ __forceinline__ float dequantize_1_q5_0(const void * vx, const int i){ + const 
block_q5_0 * x = (const block_q5_0 *) vx; + const int ib = i / QK5_0; /* index of the block that holds element i */ + + const float d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int iqs0 = i % QK5_0; + const int shift = iqs0 / (QK5_0/QR5_0); + const int not_shift = shift ^ 1; + const int iqs = iqs0 - shift * (QK5_0/QR5_0); + + int vi = x[ib].qs[iqs]; + vi >>= 4 * shift; + vi &= 0xF; + + const int xh = ((qh >> (iqs + 12*shift)) << not_shift*4) & 0x10; + vi |= xh; + + return (vi - 16) * d; +} + +static __device__ __forceinline__ float dequantize_1_q5_1(const void * vx, const int i){ + const block_q5_1 * x = (const block_q5_1 *) vx; + const int ib = i / QK5_1; /* index of the block that holds element i */ + + const float d = x[ib].d; + const float m = x[ib].m; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int iqs0 = i % QK5_1; /* position of element i inside its block */ + const int shift = iqs0 / (QK5_1/QR5_1); /* 0: value sits in the low nibble, 1: in the high nibble */ + const int not_shift = shift ^ 1; + const int iqs = iqs0 - shift * (QK5_1/QR5_1); + + int vi = x[ib].qs[iqs]; + vi >>= 4 * shift; + vi &= 0xF; + + const int xh = ((qh >> (iqs + 12*shift)) << not_shift*4) & 0x10; + vi |= xh; + + return vi * d + m; +} + +static __device__ __forceinline__ float dequantize_1_q8_0(const void * vx, const int i){ + const block_q8_0 * x = (const block_q8_0 *) vx; + const int ib = i / QK8_0; + + const float d = x[ib].d; + + const int iqs = i % QK8_0; + + const float v = x[ib].qs[iqs]; + + return v * d; +} + +static __device__ __forceinline__ float dequantize_1_q2_K(const void * vx, const int i){ + const block_q2_K * x = (const block_q2_K *) vx; + const int ib = i / QK_K; + + const int iy = i % QK_K; + const int n = iy / (QK_K/2); + + const float d = x[ib].d; + const float dmin = x[ib].dmin; + + const int iqs = iy % (QK_K/8) + n * (QK_K/8); + const int qs_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8)); + const int qs = (x[ib].qs[iqs] >> qs_shift) & 3; + + const int isc = iy / (QK_K/16); + const int sc = x[ib].scales[isc]; + + const float dl = d * (sc & 0xF); + const float ml = dmin * (sc >> 4); + + return dl * qs - ml; +} +
+static __device__ __forceinline__ float dequantize_1_q3_K(const void * vx, const int i){ + const block_q3_K * x = (const block_q3_K *) vx; + const int ib = i / QK_K; + + const int iy = i % QK_K; + const int n = iy / (QK_K/2); + + const float d = x[ib].d; + + const int iqs = iy % (QK_K/8) + n * (QK_K/8); + const int qs_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8)); + const int qs = (x[ib].qs[iqs] >> qs_shift) & 3; + + const int ih = iy % (QK_K/8); + const int ih_shift = iy / (QK_K/8); + const int h = x[ib].hmask[ih] & (1 << ih_shift) ? 0 : 4; + + const int q = qs - h; + + const int isc = iy / (QK_K/16); + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (x[ib].scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((x[ib].scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + return d * sc * q; +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +static __device__ __forceinline__ float dequantize_1_q4_K(const void * vx, const int i){ + const block_q4_K * x = (const block_q4_K *) vx; + const int ib = i / QK_K; + + const int iy = i % QK_K; + const int j = iy / (QK_K/4); + + const float d = x[ib].d; + const float dmin = x[ib].dmin; + + const int iqs = iy % (QK_K/8) + j * (QK_K/8); + const int qs_shift = 4 * ((iy % (QK_K/4)) / (QK_K/8)); + const int qs = (x[ib].qs[iqs] >> qs_shift) & 0xF; + + const int isc = iy / (QK_K/8); + uint8_t sc, m; + get_scale_min_k4(isc, x[ib].scales, sc, m); + + return d * sc * qs - dmin * m; +} + +static __device__ __forceinline__ float dequantize_1_q5_K(const void * vx, const int i){ + const block_q5_K * x = (const block_q5_K 
*) vx; + const int ib = i / QK_K; + + const int iy = i % QK_K; + const int j = iy / (QK_K/4); + + const float d = x[ib].d; + const float dmin = x[ib].dmin; + + const int iqs = iy % (QK_K/8) + j * (QK_K/8); + const int qs_shift = 4 * ((iy % (QK_K/4)) / (QK_K/8)); + const int qs = (x[ib].qs[iqs] >> qs_shift) & 0xF; + + const int isc = iy / (QK_K/8); + uint8_t sc, m; + get_scale_min_k4(isc, x[ib].scales, sc, m); + + const int iqh = iy % (QK_K/8); + const int qh_shift = iy / (QK_K/8); + const int qh = 16 * ((x[ib].qh[iqh] >> qh_shift) & 1); + + const int q = qs + qh; + + return d * sc * q - dmin * m; +} + +static __device__ __forceinline__ float dequantize_1_q6_K(const void * vx, const int i){ + const block_q6_K * x = (const block_q6_K *) vx; + const int ib = i / QK_K; + + const int iy = i % QK_K; + const int n = iy / (QK_K/2); + + const float d = x[ib].d; + + const int iql = iy % (QK_K/4) + n * (QK_K/4); + const int ql_shift = 4 * ((iy % (QK_K/2)) / (QK_K/4)); + const int ql = (x[ib].ql[iql] >> ql_shift) & 0xF; + + const int iqh = iy % (QK_K/8) + n * (QK_K/8); + const int qh_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8)); + const int qh = (((x[ib].qh[iqh] >> qh_shift) & 3) << 4); + + const int q = (ql | qh) - 32; + + const int isc = iy / (QK_K/16); + const int sc = x[ib].scales[isc]; + + return d * sc * q; +} + +static __device__ __forceinline__ void dequantize_2_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; const dfloat d = x[ib].d; @@ -299,7 +550,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in #endif // GGML_CUDA_DMMV_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_2_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; const dfloat d = x[ib].d; @@ -319,7 +570,7 @@ static __device__ 
__forceinline__ void dequantize_q4_1(const void * vx, const in #endif // GGML_CUDA_DMMV_F16 } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_2_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; const dfloat d = x[ib].d; @@ -342,7 +593,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in #endif // GGML_CUDA_DMMV_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_2_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; const dfloat d = x[ib].d; @@ -366,7 +617,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in #endif // GGML_CUDA_DMMV_F16 } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_2_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; const dfloat d = x[ib].d; @@ -470,17 +721,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } -#if QK_K == 256 -static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { - if (j < 4) { - d = q[j] & 63; m = q[j + 4] & 63; - } else { - d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} -#endif - static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const block_q4_K * x = (const block_q4_K *) vx; @@ -1153,7 +1393,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -template <int qk, int qr, dequantize_kernel_t dequantize_kernel> +template <int qk, int qr, dequantize_2_kernel_t dequantize_kernel> static __global__ void dequantize_block(const void * vx,
float * y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; @@ -1174,7 +1414,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) y[iybs + iqs + y_offset] = v.y; } -template <int qk, int qr, dequantize_kernel_t dequantize_kernel> +template <int qk, int qr, dequantize_2_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block @@ -1243,6 +1483,84 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } } +template <dequantize_1_kernel_t dequantize_kernel> +static __global__ void dequantize_mul_mat( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst) { + + const int nrows_y = ncols_x; + const int ncols_dst = ncols_y; + + /* each block computes a WARP_SIZE x WARP_SIZE tile of dst; thread tid accumulates dst row row_dst_0 + tid in sum[] */ + /* expects blockDim.x == WARP_SIZE: tid selects the tile column while loading and the tile row while multiplying */ + + const int tid = threadIdx.x; + + const int row_dst_0 = blockIdx.x*WARP_SIZE; + const int row_x_0 = row_dst_0; + const int row_dst = row_dst_0 + tid; + + const int col_dst_0 = blockIdx.y*WARP_SIZE; + const int col_y_0 = col_dst_0; + + __shared__ float tile_x[WARP_SIZE][WARP_SIZE + 1]; + __shared__ float tile_y[WARP_SIZE][WARP_SIZE]; + float sum[WARP_SIZE] = {0.0f}; + + for (int col_x_0 = 0; col_x_0 < ncols_x; col_x_0 += WARP_SIZE) { + const int row_y_0 = col_x_0; + + const int col_x_tile = min(col_x_0 + tid, ncols_x-1); + +#pragma unroll + for (int j = 0; j < WARP_SIZE; ++j) { + const int row_x_tile = min(row_x_0 + j, nrows_x-1); + tile_x[j][tid] = dequantize_kernel(vx, row_x_tile*ncols_x + col_x_tile); + } + + const int row_y_tile = min(row_y_0 + tid, nrows_y-1); + +#pragma unroll + for (int i = 0; i < WARP_SIZE; ++i) { + const int col_y_tile = min(col_y_0 + i, ncols_y-1); + /* zero-fill past the last row of y so that a partial last tile adds nothing to the sums */ + tile_y[i][tid] = row_y_0 + tid < nrows_y ? y[col_y_tile*nrows_y + row_y_tile] : 0.0f; + } + + /* every tile entry must be written before a thread reads entries stored by other threads */ + __syncthreads(); + +#pragma unroll + for (int i = 0; i < WARP_SIZE; ++i) { + const float xi = tile_x[tid][i]; + +#pragma unroll + for (int j = 0; j < WARP_SIZE; ++j) { + const float yi = tile_y[j][i]; + sum[j] += xi*yi; + } + } + + /* the next iteration overwrites both tiles, wait until all threads are done reading them */ + __syncthreads(); + } + + /* rows at or past nrows_x were only loaded as clamped duplicates and must not be stored */ + if
(row_dst >= nrows_x || row_dst >= nrows_dst) { + return; + } + +#pragma unroll + for (int j = 0; j < WARP_SIZE; ++j) { + const int col_dst_j = col_dst_0 + j; + + if (col_dst_j >= ncols_dst) { + break; + } + + dst[col_dst_j*nrows_dst + row_dst] = sum[j]; + } +} + static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { const half * x = (const half *) vx; @@ -1491,27 +1798,27 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); + dequantize_block<QK4_0, QR4_0, dequantize_2_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); + dequantize_block<QK4_1, QR4_1, dequantize_2_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); + dequantize_block<QK5_0, QR5_0, dequantize_2_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); + dequantize_block<QK5_1, QR5_1, dequantize_2_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); + dequantize_block<QK8_0, QR8_0, dequantize_2_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); } static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -1560,7 +1867,7 @@ static void
dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); - dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0> + dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_2_q4_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -1569,7 +1876,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); - dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1> + dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_2_q4_1> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -1578,7 +1885,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); - dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0> + dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_2_q5_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -1587,7 +1894,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); - dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1> + dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_2_q5_1> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -1596,7 +1903,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); - dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0> + dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_2_q8_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -1685,6 +1992,102 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static void ggml_dequantize_mul_mat_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y,
const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_f32><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_f16_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_f16><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q4_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q4_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int
ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q5_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q5_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q8_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const
int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_dequantize_mul_mat_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, 1, 1); + dequantize_mul_mat<dequantize_1_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const
int nchannels_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); @@ -1848,9 +2251,11 @@ void ggml_init_cublas() { // create main stream CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking)); +#ifndef GGML_CUDA_DMM // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); +#endif // GGML_CUDA_DMM } // configure logging to stdout @@ -2140,6 +2545,80 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( (void) i1; } +inline void ggml_cuda_op_dequantize_mul_mat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddq_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne11 = src1->ne[1]; + + const int64_t ne0 = dst->ne[0]; + + const int64_t i01_diff = i01_high - i01_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? 
ne0 : i01_diff; + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_dequantize_mul_mat_f32_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_F16: + ggml_dequantize_mul_mat_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q4_0: + ggml_dequantize_mul_mat_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q4_1: + ggml_dequantize_mul_mat_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_0: + ggml_dequantize_mul_mat_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_1: + ggml_dequantize_mul_mat_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q8_0: + ggml_dequantize_mul_mat_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q2_K: + ggml_dequantize_mul_mat_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q3_K: + ggml_dequantize_mul_mat_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q4_K: + ggml_dequantize_mul_mat_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_K: + ggml_dequantize_mul_mat_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q6_K: + ggml_dequantize_mul_mat_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + 
(void) src0_ddf_i; + (void) i02; + (void) i1; +} + inline void ggml_cuda_op_mul_mat_cublas( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ -2682,7 +3161,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false); } else { +#ifdef GGML_CUDA_DMM + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat, false, false); +#else ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); +#endif // GGML_CUDA_DMM } } else { GGML_ASSERT(false); From b90c80bdbf17e298533b16d71296327d46de1390 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 8 Jul 2023 22:53:43 +0200 Subject: [PATCH 2/2] Add __restrict__ to dequantize_mul_mat kernels --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8ba0830d0e5a5..7227646ee7471 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1485,7 +1485,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, template <dequantize_1_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst) { + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst) { const int nrows_y = ncols_x; const int ncols_dst = ncols_y;