Skip to content

Commit

Permalink
refactor: allow narrowing embeddings/rerank model ctx size
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <thxcode0824@gmail.com>
  • Loading branch information
thxCode committed Nov 28, 2024
1 parent 1916377 commit 7909630
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions llama-box/patches/llama.cpp/embedding.patch
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ index ab5e376e..658fd56a 100644
}
#endif
diff --git a/src/llama.cpp b/src/llama.cpp
index af5e686e..60721c87 100644
index af5e686e..7fc47b78 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19693,10 +19693,10 @@ struct llama_context * llama_new_context_with_model(
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

// this is necessary due to kv_self.n being padded later during inference
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
+ cparams.n_ctx = hparams.causal_attn ? GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams)) : hparams.n_ctx_train;
+ cparams.n_ctx = hparams.causal_attn ? GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams)) : std::min(cparams.n_ctx, hparams.n_ctx_train);

// with causal attention, the batch size is limited by the context size
- cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+ cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : hparams.n_ctx_train;
+ cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : cparams.n_ctx;

// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
Expand All @@ -33,7 +33,7 @@ index af5e686e..60721c87 100644
}

- cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+ cparams.n_ubatch = hparams.causal_attn ? std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch) : hparams.n_ctx_train;
+ cparams.n_ubatch = hparams.causal_attn ? std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch) : cparams.n_ctx;

cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
Expand Down

0 comments on commit 7909630

Please sign in to comment.