Merging mainline - WIP
AVX2 and CUDA appear to work.
CUDA performance seems slightly (~1-2%) lower, as is so often the case
with llama.cpp/ggml after some "improvements" have been made.
Kawrakow committed Jul 26, 2024
1 parent 6b2b52d commit a0849e4
Showing 10 changed files with 32 additions and 32 deletions.
6 changes: 1 addition & 5 deletions ggml/CMakeLists.txt
@@ -158,11 +158,7 @@ option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_C_STANDARD_REQUIRED true)
 
-if (GGML_SYCL)
-    set(CMAKE_CXX_STANDARD 17)
-else()
-    set(CMAKE_CXX_STANDARD 11)
-endif()
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 
 set(THREADS_PREFER_PTHREAD_FLAG ON)
1 change: 1 addition & 0 deletions ggml/src/CMakeLists.txt
@@ -1262,6 +1262,7 @@ add_library(ggml
     ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
     ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
     ${GGML_SOURCES_IQK_MM} ${GGML_HEADERS_IQK_MM}
+    ${GGML_SOURCES_IQK}
     ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
     ggml-aarch64.c ggml-aarch64.h
     )
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -2753,6 +2753,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ1_BN:
+        case GGML_TYPE_IQ2_BN:
             return true;
         default:
             return false;
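These two cases only advertise the new BitNet quant types to the backend's capability query; the matching CUDA kernels live elsewhere in the backend. A minimal sketch of the pattern — the GGML_TYPE_* values exist in this fork's ggml.h, the helper name is hypothetical:

```cpp
// Sketch of the capability-query pattern the hunk above extends: the CUDA
// backend reports per-type support, and the scheduler only offloads mul_mat
// ops whose quant type it accepts.
#include "ggml.h"

static bool cuda_supports_quant_type(ggml_type type) {
    switch (type) {
        case GGML_TYPE_IQ1_BN: // BitNet 1.58b ternary types,
        case GGML_TYPE_IQ2_BN: // newly reported as supported
            return true;
        default:
            return false;
    }
}
```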
12 changes: 6 additions & 6 deletions ggml/src/ggml-quants.c
@@ -3812,7 +3812,7 @@ static inline __m128i get_scale_shuffle(int i) {
 
 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
@@ -4296,7 +4296,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
 void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
@@ -4585,7 +4585,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
 void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
@@ -4942,7 +4942,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
@@ -5318,7 +5318,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
@@ -11692,7 +11692,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
 
 void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
17 changes: 10 additions & 7 deletions ggml/src/ggml.c
@@ -12286,13 +12286,19 @@ static void ggml_compute_forward_mul_mat(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
+#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+#endif
+
 #if GGML_USE_IQK_MULMAT
-    if (dst->type == GGML_TYPE_F32 && params->type == GGML_TASK_TYPE_COMPUTE && (ne12*ne13)%nth == 0) {
+    if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
         int counter = 0;
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
                 if (counter++ % nth == ith) {
-                    if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+                    if (!iqk_mul_mat(ne01, ne11, ne00,
                         src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                         src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
                         (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
@@ -12305,7 +12311,7 @@ static void ggml_compute_forward_mul_mat(
     if (dst->type == GGML_TYPE_F32) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+                if (!iqk_mul_mat(ne01, ne11, ne00,
                     src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                     src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
                     (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
@@ -12316,9 +12322,6 @@ IQK_MulMat_Not_Available1:;
 #endif
 
 #if GGML_USE_LLAMAFILE
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
@@ -12386,7 +12389,7 @@ UseGgmlGemm1:;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
         for (int64_t i13 = 0; i13 < ne13; i13++)
            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+                if (!iqk_mul_mat(ne01, ne11, ne00,
                     src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                     vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
                     (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
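The hunks above hoist the r2/r3 broadcast factors so both the IQK and llamafile paths can share them, and drop the params->type checks now that iqk_mul_mat no longer handles task types. Below is a self-contained sketch (names and sizes are illustrative) of the round-robin thread split the counter loop implements:

```cpp
// Minimal model of the counter++ % nth == ith partitioning above: each of
// nth threads claims every nth cell of the flattened (i13, i12) broadcast
// grid, which is why the path requires (ne12*ne13) % nth == 0.
#include <cstdio>

static void process_grid(int ne12, int ne13, int ith, int nth) {
    int counter = 0;
    for (int i13 = 0; i13 < ne13; i13++) {
        for (int i12 = 0; i12 < ne12; i12++) {
            if (counter++ % nth == ith) {
                // in ggml.c this is one iqk_mul_mat call on slice (i12, i13)
                std::printf("thread %d handles (i12=%d, i13=%d)\n", ith, i12, i13);
            }
        }
    }
}

int main() {
    const int ne12 = 4, ne13 = 2, nth = 4; // illustrative; 8 cells, 2 per thread
    for (int ith = 0; ith < nth; ith++) {
        process_grid(ne12, ne13, ith, nth);
    }
    return 0;
}
```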
4 changes: 1 addition & 3 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -130,7 +130,7 @@ struct MulMat {
 
 }
 
-bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00,
+bool iqk_mul_mat(long Nx, long Ny, long ne00,
                  int typeA, const void * A, long strideA,
                  int typeB, const void * B, long strideB,
                  float * C, long stride_C, int ith, int nth) {
@@ -140,8 +140,6 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
         return false;
     }
 
-    if (ggml_task_type(task_type) != GGML_TASK_TYPE_COMPUTE) return ggml_task_type(task_type) == GGML_TASK_TYPE_INIT;
-
    auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
    auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
 
2 changes: 1 addition & 1 deletion ggml/src/iqk/iqk_mul_mat.h
@@ -11,7 +11,7 @@
 extern "C" {
 #endif
 
-bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00,
+bool iqk_mul_mat(long Nx, long Ny, long ne00,
                  int typeA, const void * A, long strideA,
                  int typeB, const void * B, long strideB,
                  float * C, long stride_C, int ith, int nth);
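With task_type removed, the header now matches the updated call sites. A hedged sketch of a caller after the change, mirroring the wrappers in ggml-quants.c — the ggml type constants are real, while the wrapper name, include paths, and the elided fallback are illustrative:

```cpp
// Sketch of a vec-dot wrapper forwarding to iqk_mul_mat post-change.
#include "ggml.h"
#include "iqk/iqk_mul_mat.h"

static void vec_dot_sketch(int n, float * s, size_t bs,
                           const void * vx, size_t bx,
                           const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
    // Nx = Ny = nrc rows, ne00 = n shared columns; ith = 0, nth = 1 because
    // ggml's compute loop already invokes these wrappers once per thread.
    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx,
                    GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
        return;
    }
#endif
    // ... scalar fallback elided ...
    (void)n; (void)s; (void)bs; (void)vx; (void)bx; (void)vy; (void)by; (void)nrc;
}
```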
12 changes: 6 additions & 6 deletions ggml/src/iqk/iqk_quantize.cpp
@@ -106,7 +106,7 @@ size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_p
     return sizeof(block_iq1_bn)*nblock*nrows;
 }
 
-void quantize_row_iq1_bn_reference(const float * x, block_iq1_bn * y, int64_t k) {
+void quantize_row_iq1_bn_ref(const float * x, block_iq1_bn * y, int64_t k) {
     quantize_iq1_bn(x, y, 1, k, nullptr);
 }
 
@@ -148,7 +148,7 @@ size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_p
     return sizeof(block_iq2_bn)*nblock*nrows;
 }
 
-void quantize_row_iq2_bn_reference(const float * x, block_iq2_bn * y, int64_t k) {
+void quantize_row_iq2_bn_ref(const float * x, block_iq2_bn * y, int64_t k) {
     quantize_iq2_bn(x, y, 1, k, nullptr);
 }
 
@@ -236,7 +236,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
     static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");
 
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, 1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
         return;
     }
 #endif
@@ -286,7 +286,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
 
     static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64");
 
-    if (iqk_mul_mat(GGML_TASK_TYPE_COMPUTE, 1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
         return;
     }
 
@@ -322,7 +322,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
 
 }
 
-void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) {
+void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
 
     float * dptr = (float *)y;
     auto qs = (int8_t *)(dptr + 4);
@@ -409,6 +409,6 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
 }
 
 void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
-    quantize_row_q8_K64_reference(x, (block_q8_K64 *)y, k);
+    quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
 }
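The _reference → _ref renames track mainline's shortened naming for the scalar reference quantizers, which are registered next to the fast ones in ggml's type traits (see the from_float_ref use in the test change below). A sketch of the assumed registration shape — the field names follow mainline ggml, the rest is illustrative:

```cpp
// Assumed shape of the type-traits struct the renamed functions feed into.
#include <cstdint>

typedef void (*from_float_t)(const float * x, void * y, int64_t k);

struct type_traits_sketch {
    from_float_t from_float;     // fast (SIMD) quantizer, e.g. quantize_row_q8_K64
    from_float_t from_float_ref; // scalar reference path, formerly *_reference
};
```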
6 changes: 3 additions & 3 deletions src/llama.cpp
@@ -7320,7 +7320,7 @@ static bool llm_load_tensors(
             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
         }
 
-        const uint32_t n_ff = hparams.n_ff;
+        const uint32_t n_ff = hparams.n_ff();
         model.layers.resize(n_layer);
         for (int i = 0; i < n_layer; ++i) {
             ggml_context * ctx_layer = ctx_for_layer(i);
@@ -13075,8 +13075,8 @@ struct llm_build_context {
         llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
         const int64_t n_ctx = cparams.n_ctx;
-        const int64_t n_head = hparams.n_head;
-        const int64_t n_head_kv = hparams.n_head_kv;
+        const int64_t n_head = hparams.n_head();
+        const int64_t n_head_kv = hparams.n_head_kv();
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
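The n_ff, n_head, and n_head_kv reads change from fields to accessors because, in mainline llama.cpp, these hyperparameters became per-layer arrays behind functions taking an optional layer index. A sketch of that assumed shape — names follow mainline, the array bound is illustrative:

```cpp
// Assumed shape of the mainline change these call sites adapt to.
#include <array>
#include <cstdint>

struct llama_hparams_sketch {
    std::array<uint32_t, 512> n_head_arr{};    // attention heads per layer
    std::array<uint32_t, 512> n_head_kv_arr{}; // KV heads per layer (GQA)
    std::array<uint32_t, 512> n_ff_arr{};      // feed-forward width per layer

    uint32_t n_head   (uint32_t il = 0) const { return n_head_arr[il];    }
    uint32_t n_head_kv(uint32_t il = 0) const { return n_head_kv_arr[il]; }
    uint32_t n_ff     (uint32_t il = 0) const { return n_ff_arr[il];      }
};
```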
2 changes: 1 addition & 1 deletion tests/test-quantize-fns.cpp
@@ -92,7 +92,7 @@ static float dot_product_error(
 
     auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
 
-    qfns.from_float_reference(test_data1, tmp_q1.data(), test_size);
+    qfns.from_float_ref(test_data1, tmp_q1.data(), test_size);
     vdot.from_float(test_data2, tmp_q2.data(), test_size);
 
     float result = INFINITY;
