From 8a4bad50a8ed24ed1e9df003521468dcc37320e8 Mon Sep 17 00:00:00 2001
From: Fan Shupei
Date: Thu, 25 Jul 2024 15:21:09 +0800
Subject: [PATCH] llama: use sliding window for phi3 (#8627)

* use sliding window for phi3

* fix typo, "data_swa" -> "data"

* [convert_hf_to_gguf.py] add phi3 sliding window

---
 convert_hf_to_gguf.py |  1 +
 src/llama.cpp         | 37 ++++++++++++++++++++++++++++---------
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index dde4fa9c80ca3..4087187c19834 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2084,6 +2084,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_dimension_count(rope_dims)
         self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"]))
 
         # write rope scaling for long context (128k) model
         rope_scaling = self.find_hparam(['rope_scaling'], True)
diff --git a/src/llama.cpp b/src/llama.cpp
index 04eaf6730bc24..9e502018dfb76 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4889,6 +4889,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PHI3:
             {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
@@ -10748,7 +10749,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -10806,7 +10807,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -14013,18 +14014,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 "causal attention is not supported by this model"
             );
 
-    if (lctx.inp_KQ_mask) {
+    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = nullptr;
             float * data_swa = nullptr;
 
+            if (lctx.inp_KQ_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+
             if (lctx.inp_KQ_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
                 data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }
@@ -14047,7 +14053,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
 
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
@@ -14059,9 +14068,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }
 
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
                     }
                 }
             }
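
Note on the mechanism: the sliding-window mask this patch enables for Phi-3 hides KV-cache entries that have fallen outside the window of n_swa positions behind the token currently attending (n_swa is loaded above from the GGUF "sliding_window" key), on top of the usual causal rule that hides future positions. A minimal standalone C++ sketch of that per-cell decision, using hypothetical names (swa_mask_value, cell_pos, causal_val) and an illustrative boundary comparison rather than llama.cpp's actual buffers and loop structure, could look like this:

#include <cmath>   // INFINITY
#include <cstdint>

// Sketch only: combine the causal mask value with a sliding-window cutoff.
// pos        - position of the token doing the attending
// cell_pos   - position stored in the KV-cache cell being considered
// n_swa      - sliding-window size (hparams.n_swa)
// causal_val - value already chosen by the causal rule: 0.0f (visible) or -INFINITY (future)
static float swa_mask_value(int32_t pos, int32_t cell_pos, int32_t n_swa, float causal_val) {
    float f = causal_val;
    if (pos - cell_pos >= n_swa) {
        f = -INFINITY; // cell is outside the window, so it is never attended to
    }
    return f;
}

Because Phi-3 applies the windowed mask to every layer, the graph builder above swaps build_inp_KQ_mask() for build_inp_KQ_mask_swa() and passes KQ_mask_swa to llm_build_kv unconditionally.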