Skip to content

Commit

Permalink
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
slaren authored Sep 20, 2023
1 parent 2f3a46f commit 1be2b8c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 12 deletions.
50 changes: 49 additions & 1 deletion ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
}

// make a view of the destination
struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);
struct ggml_tensor * result = ggml_view_tensor(ctx, b);
if (strlen(b->name) > 0) {
ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
} else {
Expand Down Expand Up @@ -6406,6 +6406,54 @@ struct ggml_tensor * ggml_cont_inplace(
return ggml_cont_impl(ctx, a, true);
}


// make contiguous, with new shape
GGML_API struct ggml_tensor * ggml_cont_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0) {
return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
}

GGML_API struct ggml_tensor * ggml_cont_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1) {
return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
}

GGML_API struct ggml_tensor * ggml_cont_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
}

struct ggml_tensor * ggml_cont_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));

bool is_node = false;

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
ggml_format_name(result, "%s (cont)", a->name);

result->op = GGML_OP_CONT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}


// ggml_reshape

struct ggml_tensor * ggml_reshape(
Expand Down
28 changes: 27 additions & 1 deletion ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,6 @@ extern "C" {
size_t nb1,
size_t offset);


// a -> b, return view(b)
GGML_API struct ggml_tensor * ggml_cpy(
struct ggml_context * ctx,
Expand All @@ -1072,6 +1071,33 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// make contiguous, with new shape
GGML_API struct ggml_tensor * ggml_cont_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0);

GGML_API struct ggml_tensor * ggml_cont_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1);

GGML_API struct ggml_tensor * ggml_cont_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2);

GGML_API struct ggml_tensor * ggml_cont_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3);

// return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view
GGML_API struct ggml_tensor * ggml_reshape(
Expand Down
14 changes: 4 additions & 10 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2893,9 +2893,7 @@ static struct ggml_cgraph * llm_build_llama(
ggml_set_name(KQV_merged, "KQV_merged");

// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");

Expand Down Expand Up @@ -3302,9 +3300,7 @@ static struct ggml_cgraph * llm_build_baichaun(
ggml_set_name(KQV_merged, "KQV_merged");

// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");

Expand Down Expand Up @@ -3710,7 +3706,7 @@ static struct ggml_cgraph * llm_build_falcon(
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");

cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");

Expand Down Expand Up @@ -3964,9 +3960,7 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_set_name(KQV_merged, "KQV_merged");

// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
ggml_set_name(cur, "KQV_merged_contiguous");
}

Expand Down

0 comments on commit 1be2b8c

Please sign in to comment.