Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[Graph] Falcon-7B optimization (#1199)
Browse files: browse the repository at this point in the history
  • Loading branch information
zhentaoyu authored Aug 1, 2023
1 parent 173aa17 commit 2a82ee0
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ bool falcon_model_load(const std::string& fname, falcon_model& model, gpt_vocab&
const int n_vocab = hparams.n_vocab;
const int head_dim = hparams.n_embd / hparams.n_head;

ctx_size += n_embd * n_vocab * ne_type_sizef(NE_TYPE_F32); // tok_embeddings
ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // tok_embeddings

ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // output_norm
ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // output_norm_b
Expand Down Expand Up @@ -218,7 +218,7 @@ bool falcon_model_load(const std::string& fname, falcon_model& model, gpt_vocab&

model.layers.resize(n_layer);

model.tok_embeddings = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_embd, n_vocab, NE_SIZE_CALC);
model.tok_embeddings = ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab, NE_SIZE_CALC);
model.output_norm = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd, NE_SIZE_CALC);
model.output_norm_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd, NE_SIZE_CALC);
model.lm_head = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_embd, n_vocab, NE_SIZE_CALC);
Expand Down Expand Up @@ -411,7 +411,6 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas

// wte
struct ne_tensor* inpL = ne_get_rows(ctx0, model.tok_embeddings, embd);
struct ne_tensor* repeat_dummy = ne_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head, NE_SIZE_CALC);

for (int il = 0; il < n_layer; ++il) {
struct ne_tensor* cur;
Expand Down Expand Up @@ -453,28 +452,34 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas

// store key and value to memory
{
struct ne_tensor* k = ne_view_1d(ctx0, model.memory_k, N * head_dim,
(ne_element_size(model.memory_k) * head_dim) * (il * n_ctx + n_past));
struct ne_tensor* v = ne_view_1d(ctx0, model.memory_v, N * head_dim,
(ne_element_size(model.memory_v) * head_dim) * (il * n_ctx + n_past));

ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v));
// head_dim, 1 (head_num), N --> head_dim, N, 1 (head_num)
struct ne_tensor* Kcur_permuted = ne_permute(ctx0, Kcur, 0, 2, 1, 3);
// head_dim, 1 (head_num), N --> N, head_dim, 1 (head_num)
struct ne_tensor* Vcur_permuted = ne_permute(ctx0, Vcur, 1, 2, 0, 3);

struct ne_tensor* k =
ne_view_3d(ctx0, model.memory_k, head_dim, N, 1, ne_element_size(model.memory_k) * head_dim,
ne_element_size(model.memory_k) * head_dim * n_ctx,
il * n_ctx * ne_element_size(model.memory_k) * head_dim +
n_past * ne_element_size(model.memory_k) * head_dim);
struct ne_tensor* v = ne_view_3d(
ctx0, model.memory_v, N, head_dim, 1, n_ctx * ne_element_size(model.memory_v),
n_ctx * ne_element_size(model.memory_v) * head_dim,
il * n_ctx * ne_element_size(model.memory_v) * head_dim + n_past * ne_element_size(model.memory_v));

ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_permuted, k));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_permuted, v));
}

// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);

struct ne_tensor* K =
ne_permute(ctx0,
ne_reshape_3d(ctx0,
ne_view_1d(ctx0, model.memory_k, (n_past + N) * head_dim,
il * n_ctx * ne_element_size(model.memory_k) * head_dim),
head_dim, 1, n_past + N),
0, 2, 1, 3);
ne_view_3d(ctx0, model.memory_k, head_dim, N + n_past, 1, ne_element_size(model.memory_k) * head_dim,
ne_element_size(model.memory_k) * head_dim * n_ctx,
il * n_ctx * ne_element_size(model.memory_k) * head_dim * 1);

// K * Q
K = ne_cont(ctx0, ne_repeat(ctx0, K, repeat_dummy));
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
Expand All @@ -488,14 +493,9 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_permute(ctx0,
ne_reshape_3d(ctx0,
ne_view_1d(ctx0, model.memory_v, (n_past + N) * head_dim,
il * n_ctx * ne_element_size(model.memory_v) * head_dim),
head_dim, 1, n_past + N),
0, 2, 1, 3);

V = ne_cont(ctx0, ne_transpose(ctx0, ne_repeat(ctx0, V, repeat_dummy)));
ne_view_3d(ctx0, model.memory_v, N + n_past, head_dim, 1, ne_element_size(model.memory_v) * n_ctx,
ne_element_size(model.memory_v) * n_ctx * head_dim,
il * n_ctx * ne_element_size(model.memory_v) * head_dim * 1);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ bool falcon_model_quantize(const std::string& fname_inp, const std::string& fnam
".*weight",
};

if (!ne_common_quantize_0(finp, fout, params, to_quant, {"transformer.word_embeddings.weight", "lm_head.weight"})) {
if (!ne_common_quantize_0(finp, fout, params, to_quant, {"lm_head.weight"})) {
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
return false;
}
Expand Down
Loading

0 comments on commit 2a82ee0

Please sign in to comment.