Implement Quick GELU #254

Merged · 7 commits · Jun 16, 2023
9 changes: 9 additions & 0 deletions include/ggml/ggml.h
@@ -290,6 +290,7 @@ extern "C" {
GGML_OP_STEP,
GGML_OP_RELU,
GGML_OP_GELU,
GGML_OP_GELU_QUICK,
GGML_OP_SILU,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
@@ -687,6 +688,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_gelu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_silu(
struct ggml_context * ctx,
struct ggml_tensor * a);
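A minimal usage sketch of the two new entry points (not part of this PR's diff; it assumes the single-context graph API of this ggml revision, and the buffer size, include path and input values are arbitrary):

#include <stdio.h>
#include "ggml/ggml.h"   // header extended by this PR; adjust the path to your include setup

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // 16 MB scratch; plenty for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ((float *) x->data)[0] = -2.0f;
    ((float *) x->data)[1] = -0.5f;
    ((float *) x->data)[2] =  0.5f;
    ((float *) x->data)[3] =  2.0f;

    struct ggml_tensor * y = ggml_gelu_quick(ctx, x);   // new op added in this PR

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    for (int i = 0; i < 4; ++i) {
        printf("gelu_quick(% .1f) = % .6f\n",
               ((float *) x->data)[i], ((float *) y->data)[i]);
    }

    ggml_free(ctx);
    return 0;
}

ggml_gelu_quick_inplace takes the same arguments but writes into a's own buffer through a view instead of allocating a new tensor, as ggml_gelu_quick_impl further down shows.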
148 changes: 144 additions & 4 deletions src/ggml.c
@@ -98,6 +98,7 @@ typedef void* thread_ret_t;
/*#define GGML_PERF*/
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
#define GGML_SILU_FP16

#define GGML_SOFT_MAX_UNROLL 4
@@ -322,6 +323,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
// precomputed gelu table for f16 (128 KB)
static ggml_fp16_t table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
static ggml_fp16_t table_gelu_quick_f16[1 << 16];

// precomputed silu table for f16 (128 KB)
static ggml_fp16_t table_silu_f16[1 << 16];

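Side note on the size quoted in the comment above: the table is indexed by the raw 16-bit pattern of a ggml_fp16_t, so it holds 1 << 16 entries of 2 bytes each, i.e. 128 KB. A trivial standalone check (uint16_t stands in for ggml_fp16_t here):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const size_t entries    = (size_t) 1 << 16;   // one entry per possible fp16 bit pattern
    const size_t entry_size = sizeof(uint16_t);   // stand-in for ggml_fp16_t (2 bytes)
    printf("table size: %zu bytes = %zu KB\n",
           entries*entry_size, (entries*entry_size)/1024);   // 131072 bytes = 128 KB
    return 0;
}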
@@ -3288,6 +3292,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }

static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

inline static float ggml_gelu_f32(float x) {
@@ -3318,6 +3323,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif

inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}

inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = table_gelu_quick_f16[i16[i]];
}
}

#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
}
}
#else
inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_quick_f32(x[i]);
}
}
#endif

// Sigmoid Linear Unit (SiLU) function
inline static float ggml_silu_f32(float x) {
return x/(1.0f + expf(-x));
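For context, ggml_gelu_quick_f32 above implements the "quick" GELU approximation x * sigmoid(1.702 * x); GELU_QUICK_COEF stores -1.702 because the code writes the sigmoid as 1/(1 + exp(coef * x)). The exact GELU is x * Phi(x) with Phi the standard normal CDF. A standalone comparison sketch, not part of ggml, with arbitrarily chosen sample points:

#include <math.h>
#include <stdio.h>

// exact GELU: x * Phi(x), written via erff()
static float gelu_exact(float x) {
    return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f)));
}

// quick GELU as added in this PR: x * sigmoid(1.702 * x)
static float gelu_quick(float x) {
    return x*(1.0f/(1.0f + expf(-1.702f*x)));
}

int main(void) {
    const float xs[] = { -3.0f, -1.0f, 0.0f, 1.0f, 3.0f };
    for (int i = 0; i < 5; ++i) {
        printf("x = % .1f  exact = % .6f  quick = % .6f\n",
               xs[i], gelu_exact(xs[i]), gelu_quick(xs[i]));
    }
    return 0;   // build with -lm
}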
@@ -3519,6 +3552,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"STEP",
"RELU",
"GELU",
"GELU_QUICK",
"SILU",
"SILU_BACK",
"NORM",
@@ -3558,7 +3592,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"MAP_BINARY",
};

static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3583,6 +3617,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"step(x)",
"relu(x)",
"gelu(x)",
"gelu_quick(x)",
"silu(x)",
"silu_back(x)",
"norm(x)",
@@ -3622,7 +3657,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};

static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");

static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3899,7 +3934,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
// initialize time system (required on Windows)
ggml_time_init();

// initialize GELU, SILU and EXP F32 tables
// initialize GELU, Quick GELU, SILU and EXP F32 tables
{
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);

@@ -3909,13 +3944,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
memcpy(&ii, &ui, sizeof(ii));
const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
}

const uint64_t t_end = ggml_time_us(); UNUSED(t_end);

GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}

// initialize g_state
@@ -5271,6 +5307,40 @@ struct ggml_tensor * ggml_gelu_inplace(
return ggml_gelu_impl(ctx, a, true);
}

// ggml_gelu_quick

struct ggml_tensor * ggml_gelu_quick_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;

if (!inplace && (a->grad)) {
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

result->op = GGML_OP_GELU_QUICK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;

return result;
}

struct ggml_tensor * ggml_gelu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_gelu_quick_impl(ctx, a, false);
}

struct ggml_tensor * ggml_gelu_quick_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_gelu_quick_impl(ctx, a, true);
}

// ggml_silu

struct ggml_tensor * ggml_silu_impl(
@@ -9118,6 +9188,67 @@ static void ggml_compute_forward_gelu(
//printf("XXXXXXXX gelu\n");
}

// ggml_compute_forward_gelu_quick

static void ggml_compute_forward_gelu_quick_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

const int ith = params->ith;
const int nth = params->nth;

const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));

#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}

static void ggml_compute_forward_gelu_quick(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gelu_quick_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}

//printf("XXXXXXXX quick gelu\n");
}

// ggml_compute_forward_silu

static void ggml_compute_forward_silu_f32(
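A worked example of the row partitioning used in ggml_compute_forward_gelu_quick_f32 above, with hypothetical numbers (nr = 10 rows split across nth = 4 threads):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10;  // total rows (hypothetical)
    const int nth = 4;   // number of threads (hypothetical)
    const int dr  = (nr + nth - 1)/nth;    // rows per thread, rounded up: 3

    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;            // 0, 3, 6, 9
        const int ir1 = MIN(ir0 + dr, nr); // 3, 6, 9, 10 (last thread takes the remainder)
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}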
@@ -13360,6 +13491,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_gelu(params, tensor->src0, tensor);
} break;
case GGML_OP_GELU_QUICK:
{
ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
} break;
case GGML_OP_SILU:
{
ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13771,6 +13906,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_GELU_QUICK:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ALIBI:
{
GGML_ASSERT(false); // TODO: not implemented
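The backward case for GGML_OP_GELU_QUICK above is left as a TODO. Purely for reference (not part of this PR): with s = sigmoid(1.702 * x), the derivative of x * s is s + 1.702 * x * s * (1 - s). A standalone sketch with a finite-difference sanity check:

#include <math.h>
#include <stdio.h>

// quick GELU itself, for the finite-difference check
static float gelu_quick(float x) {
    return x/(1.0f + expf(-1.702f*x));
}

// analytic derivative of x * sigmoid(1.702 * x)
static float gelu_quick_grad(float x) {
    const float s = 1.0f/(1.0f + expf(-1.702f*x));
    return s + 1.702f*x*s*(1.0f - s);
}

int main(void) {
    const float h    = 1e-3f;
    const float xs[] = { -2.0f, 0.0f, 1.5f };
    for (int i = 0; i < 3; ++i) {
        const float x  = xs[i];
        const float fd = (gelu_quick(x + h) - gelu_quick(x - h))/(2.0f*h);
        printf("x = % .1f  analytic = % .6f  finite-diff = % .6f\n",
               x, gelu_quick_grad(x), fd);
    }
    return 0;   // build with -lm
}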
@@ -14578,6 +14717,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_MUL:
case GGML_OP_GELU:
case GGML_OP_GELU_QUICK:
case GGML_OP_SILU:
case GGML_OP_SILU_BACK:
case GGML_OP_NORM: