From dbd02958fa4f46898f68ca29c27ddcdc58a06f98 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 29 Dec 2023 03:32:31 -0500 Subject: [PATCH] ggml : fix some mul mat cases + add tests for src1 F16 (#669) * fixed mul-mat error for old GPUs * style fixes * add mul mat src1 f16 test cases, fix more cases ggml-ci --------- Co-authored-by: bssrdf Co-authored-by: slaren --- src/ggml-backend.c | 8 +++- src/ggml-cuda.cu | 89 +++++++++++++++++++------------------- src/ggml.c | 2 +- tests/test-backend-ops.cpp | 14 +++--- 4 files changed, 60 insertions(+), 53 deletions(-) diff --git a/src/ggml-backend.c b/src/ggml-backend.c index 526ce732b..2c3752067 100644 --- a/src/ggml-backend.c +++ b/src/ggml-backend.c @@ -614,10 +614,14 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c } static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return true; + switch (op->op) { + case GGML_OP_MUL_MAT: + return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + default: + return true; + } GGML_UNUSED(backend); - GGML_UNUSED(op); } static struct ggml_backend_i cpu_backend_i = { diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index abad9cc39..9a9effcf5 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -7485,6 +7485,8 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; + GGML_ASSERT(src1->type == GGML_TYPE_F32); + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_CUDA_F16 cuda_pool_alloc src1_dfloat_a; @@ -7577,6 +7579,7 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = g_device_caps[id].cc; if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + //printf("this branch\n"); // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 cuda_pool_alloc src0_as_f16; if (src0->type != GGML_TYPE_F16) { @@ -7614,9 +7617,9 @@ static void ggml_cuda_op_mul_mat_cublas( const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); - } - else { + } else { cuda_pool_alloc src0_ddq_as_f32; + cuda_pool_alloc src1_ddq_as_f32; if (src0->type != GGML_TYPE_F32) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -7624,7 +7627,15 @@ static void ggml_cuda_op_mul_mat_cublas( src0_ddq_as_f32.alloc(row_diff*ne00); to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); } + if (src1->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src1_ddq_as_f32.alloc(src1_ncols*ne10); + to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); + } + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); + const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); const float alpha = 1.0f; const float beta = 0.0f; @@ -7633,9 +7644,9 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha, src0_ddf_i, ne00, - src1_ddf_i, ne10, - &beta, dst_dd_i, ldc)); + &alpha, src0_ddf_i, ne00, + src1_ddf1_i, ne10, + &beta, dst_dd_i, ldc)); } (void) dst; @@ -8035,6 +8046,7 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); @@ -8481,9 +8493,9 @@ static __global__ void k_compute_batched_ptrs( int64_t i03 = i13 / r3; int64_t i02 = i12 / r2; - ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; - ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; - ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -8492,28 +8504,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); - const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int64_t nb11 = src1->nb[1]; - const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); - const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + GGML_TENSOR_BINARY_OP_LOCALS - const int64_t ne1 = ggml_nelements(src1); - const int64_t ne = ggml_nelements(dst); + const int64_t ne_dst = ggml_nelements(dst); ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; @@ -8522,7 +8516,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; - half * src0_as_f16 = (half *) src0_ddq; + half * src0_f16 = (half *) src0_ddq; ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; @@ -8531,11 +8525,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; // convert src1 to fp16 - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - - cuda_pool_alloc src1_as_f16(ne1); - to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream); + cuda_pool_alloc src1_f16_alloc; + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + const int64_t ne_src1 = ggml_nelements(src1); + src1_f16_alloc.alloc(ne_src1); + GGML_ASSERT(to_fp16_cuda != nullptr); + to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream); + } + half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get(); cuda_pool_alloc dst_f16; char * dst_t; @@ -8557,7 +8555,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const const void * beta = &beta_f16; if (dst->op_params[0] == GGML_PREC_DEFAULT) { - dst_t = (char *) dst_f16.alloc(ne); + dst_t = (char *) dst_f16.alloc(ne_dst); nbd2 /= sizeof(float) / sizeof(half); nbd3 /= sizeof(float) / sizeof(half); @@ -8604,9 +8602,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA - (const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB - beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC + alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA + (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB + beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -8619,12 +8617,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( - src0_as_f16, src1_as_f16.get(), dst_t, + src0_f16, src1_f16, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, ne23, nb02, nb03, - nb12, nb13, + src1->type == GGML_TYPE_F16 ? nb12 : nb12/2, + src1->type == GGML_TYPE_F16 ? nb13 : nb13/2, nbd2, nbd3, r2, r3); CUDA_CHECK(cudaGetLastError()); @@ -8632,8 +8631,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half), - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float), + alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, + (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10, beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01, ne23, cu_compute_type, @@ -8643,7 +8642,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const if (dst->op_params[0] == GGML_PREC_DEFAULT) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream); + to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream); } } @@ -8682,13 +8681,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) { #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else diff --git a/src/ggml.c b/src/ggml.c index ed56e60a8..a9e1ea9b4 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -9687,7 +9687,7 @@ static void ggml_compute_forward_mul_mat( const size_t row_size = ggml_row_size(vec_dot_type, ne10); assert(params->wsize >= ne11*ne12*ne13*row_size); - assert(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f3df8a8c6..b115299c0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -350,13 +350,18 @@ struct test_case { fflush(stdout); // check if backends support op + bool supported = true; for (ggml_backend_t backend : {backend1, backend2}) { if (!ggml_backend_supports_op(backend, out)) { - printf("not supported\n"); - ggml_free(ctx); - return true; + printf("not supported [%s] ", ggml_backend_name(backend)); + supported = false; } } + if (!supported) { + printf("\n"); + ggml_free(ctx); + return true; + } // post-graph sentinel add_sentinel(ctx); @@ -1505,8 +1510,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } for (ggml_type type_a : all_types) { - for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { - // FIXME: CPU crashes on f16xf16 + for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));