llama: use sliding window for phi3 #8627

Merged
merged 3 commits into from
Jul 25, 2024
Changes from 2 commits
38 changes: 29 additions & 9 deletions src/llama.cpp
@@ -4974,6 +4974,8 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_PHI3:
{
hparams.n_swa = 2048;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
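
The two added lines in this hunk set a default window of 2048 and then let the model file override it, because the key is read with `required` set to `false`. A minimal sketch of that pattern follows. The `kv_table` type, the free-standing `get_key` helper, and the key string are hypothetical stand-ins for illustration, not the loader's actual API.

```cpp
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the model file's metadata table.
using kv_table = std::map<std::string, uint32_t>;

// Copies the stored value into `out` when the key exists. When the key is
// missing, `out` is left untouched if `required` is false; otherwise throws.
bool get_key(const kv_table & kv, const std::string & key, uint32_t & out, bool required = true) {
    const auto it = kv.find(key);
    if (it == kv.end()) {
        if (required) {
            throw std::runtime_error("key not found in model: " + key);
        }
        return false;
    }
    out = it->second;
    return true;
}

// Default first, optional override second, as in the hunk above.
uint32_t load_n_swa(const kv_table & kv) {
    uint32_t n_swa = 2048;
    // Illustrative key name (assumption), read as optional.
    get_key(kv, "phi3.attention.sliding_window", n_swa, /*required=*/false);
    return n_swa;
}
```

With this ordering, a model file that lacks the key still gets a usable window size, while a file that carries the key wins over the default.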
@@ -10843,7 +10845,7 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();

for (int il = 0; il < n_layer; ++il) {
auto residual = inpL;
@@ -10901,7 +10903,7 @@ struct llm_build_context {

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}

if (il == n_layer - 1) {
@@ -14108,18 +14110,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
"causal attention is not supported by this model"
);

if (lctx.inp_KQ_mask) {
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn && !lctx.is_encoding) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

float * data = (float *) lctx.inp_KQ_mask->data;
float * data = nullptr;
float * data_swa = nullptr;

if (lctx.inp_KQ_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
data = (float *) lctx.inp_KQ_mask->data;
}

if (lctx.inp_KQ_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
}

@@ -14142,7 +14149,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
f = 0.0f;
}
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

if (data) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}

// may need to cut off old tokens for sliding window
if (data_swa) {
@@ -14154,9 +14164,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}

if (data_swa) {
Collaborator:

data_swa is also modified by the code above (line 14158). It is used by gemma2.

Overwriting it here may break gemma2.

Contributor Author:

Sorry, I'm not familiar with gemma2, so I haven't tested this PR on gemma2. I have only tested it with Phi3 on CPU.

I do not understand when the mask should be padded to GGML_KQ_MASK_PAD. I padded data_swa to GGML_KQ_MASK_PAD because it looks like the original code simply forgot to do so. In practice, padding data_swa or not does not affect the correctness of Phi3 on CPU.

I'm confused that the original code explicitly pads data to GGML_KQ_MASK_PAD but does not pad data_swa. Is this the intended behavior? If so, I'm happy to revert the change (padding data_swa to GGML_KQ_MASK_PAD), but I would still like someone to explain what GGML_KQ_MASK_PAD actually means.

Collaborator:

@ngxson I agree with @FanShupei here, I think data_swa should also be padded. I don't see why not, since the ranges of data written here and above do not overlap.

Not sure why this worked before though. Padding data_swa seems saner than leaving the values uninitialized.

Owner:

Both should be padded. The padding is necessary so that GPU kernels (such as the Metal Flash-Attention) do not have to perform extra checks for out-of-bounds access when working on chunks of data.
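
For reference, the row count that the padding loops iterate up to can be sketched as below. `pad_to` mirrors the power-of-two rounding performed by ggml's `GGML_PAD` macro. The value 32 for the padding multiple is an assumption made for illustration, so check `GGML_KQ_MASK_PAD` in `ggml.h` for the real constant. `n_pad_rows` is a hypothetical helper.

```cpp
#include <cstdint>

// Round x up to a multiple of n, where n must be a power of two.
constexpr int64_t pad_to(int64_t x, int64_t n) {
    return (x + n - 1) & ~(n - 1);
}

// Assumed value for illustration only; see GGML_KQ_MASK_PAD in ggml.h.
constexpr int64_t KQ_MASK_PAD = 32;

// Number of mask rows that exist only as padding for a batch of n_tokens.
constexpr int64_t n_pad_rows(int64_t n_tokens) {
    return pad_to(n_tokens, KQ_MASK_PAD) - n_tokens;
}

static_assert(pad_to(1, 32)  == 32, "rounds up to the next multiple");
static_assert(pad_to(32, 32) == 32, "exact multiples are unchanged");
static_assert(pad_to(33, 32) == 64, "one past a multiple jumps to the next");
```

Under that assumed multiple, a batch of 5 tokens has 27 extra rows. Filling those rows with `-INFINITY` means a kernel that reads the mask in fixed-size chunks never sees uninitialized values.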

for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
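
Putting the pieces of this hunk together, the sketch below builds a plain causal mask and a sliding-window mask side by side for a single head and a single sequence. It is an illustration under stated assumptions, not the code in `llama_set_inputs`: the `masks` struct, `build_masks`, and the flat `kv_pos` / `tok_pos` vectors are hypothetical, and the window test `tok_pos[j] - kv_pos[i] >= n_swa` is assumed to match the cutoff used for `data_swa`. Unlike the PR, which fills the padding rows in a separate loop, this version pre-fills both buffers with `-INFINITY` so the padding rows are covered from the start.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

struct masks {
    int64_t n_kv;
    int64_t n_rows;           // n_tokens rounded up to the padding multiple
    std::vector<float> full;  // plain causal mask
    std::vector<float> swa;   // causal mask with the sliding-window cutoff
};

// kv_pos[i]  : position stored in KV cell i (hypothetical single-sequence cache)
// tok_pos[j] : position of batch token j
// pad        : row padding multiple, must be a power of two
masks build_masks(const std::vector<int32_t> & kv_pos,
                  const std::vector<int32_t> & tok_pos,
                  int32_t n_swa, int64_t pad) {
    const int64_t n_kv     = (int64_t) kv_pos.size();
    const int64_t n_tokens = (int64_t) tok_pos.size();
    const int64_t n_rows   = (n_tokens + pad - 1) & ~(pad - 1);

    // Start fully masked so the padding rows need no second pass.
    masks m{n_kv, n_rows,
            std::vector<float>(n_rows*n_kv, -INFINITY),
            std::vector<float>(n_rows*n_kv, -INFINITY)};

    for (int64_t j = 0; j < n_tokens; ++j) {
        for (int64_t i = 0; i < n_kv; ++i) {
            if (kv_pos[i] > tok_pos[j]) {
                continue; // future position: stays masked in both
            }
            m.full[j*n_kv + i] = 0.0f;
            // Positions that fell out of the window stay masked in swa only.
            if (tok_pos[j] - kv_pos[i] < n_swa) {
                m.swa[j*n_kv + i] = 0.0f;
            }
        }
    }
    return m;
}
```

Both buffers have the same shape, so rows `n_tokens` and above are masked in each of them, which is the behavior the review thread settles on.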