Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory usage during Whisper inference #431

Merged
merged 15 commits into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 109 additions & 98 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bindings/javascript/whisper.js

Large diffs are not rendered by default.

52 changes: 31 additions & 21 deletions examples/main/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,35 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
usage: ./main [options] file0.wav file1.wav ...

options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-owts, --output-words [false ] output script for generating karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
```
33 changes: 19 additions & 14 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,23 @@ void replace_all(std::string & s, const std::string & search, const std::string
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;

float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;

bool speed_up = false;
bool translate = false;
bool diarize = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
Expand Down Expand Up @@ -117,6 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
Expand Down Expand Up @@ -162,6 +164,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
Expand Down Expand Up @@ -514,7 +517,7 @@ int main(int argc, char ** argv) {

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];

std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
Expand Down Expand Up @@ -647,17 +650,19 @@ int main(int argc, char ** argv) {

wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;

wparams.speed_up = params.speed_up;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();

wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;

whisper_print_user_data user_data = { &params, &pcmf32s };

Expand Down
123 changes: 88 additions & 35 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
//

struct ggml_object {
size_t offset;
size_t offs;
size_t size;

struct ggml_object * next;
Expand All @@ -1284,6 +1284,9 @@ struct ggml_context {

struct ggml_object * objects_begin;
struct ggml_object * objects_end;

struct ggml_scratch scratch;
struct ggml_scratch scratch_save;
};

struct ggml_context_container {
Expand Down Expand Up @@ -1346,7 +1349,7 @@ inline static void ggml_critical_section_end(void) {

void ggml_print_object(const struct ggml_object * obj) {
GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
obj->offset, obj->size, (const void *) obj->next);
obj->offs, obj->size, (const void *) obj->next);
}

void ggml_print_objects(const struct ggml_context * ctx) {
Expand Down Expand Up @@ -1542,12 +1545,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
}

*ctx = (struct ggml_context) {
.mem_size = params.mem_size,
.mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
.mem_buffer_owned = params.mem_buffer ? false : true,
.n_objects = 0,
.objects_begin = NULL,
.objects_end = NULL,
/*.mem_size =*/ params.mem_size,
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
/*.n_objects =*/ 0,
/*.objects_begin =*/ NULL,
/*.objects_end =*/ NULL,
/*.scratch =*/ { 0, 0, NULL, },
/*.scratch_save =*/ { 0, 0, NULL, },
};

ggml_assert_aligned(ctx->mem_buffer);
Expand All @@ -1570,7 +1575,7 @@ void ggml_free(struct ggml_context * ctx) {
g_state.contexts[i].used = false;

GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
__func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);

if (ctx->mem_buffer_owned) {
free(ctx->mem_buffer);
Expand All @@ -1589,7 +1594,15 @@ void ggml_free(struct ggml_context * ctx) {
}

size_t ggml_used_mem(const struct ggml_context * ctx) {
return ctx->objects_end->offset + ctx->objects_end->size;
return ctx->objects_end->offs + ctx->objects_end->size;
}

size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;

ctx->scratch = scratch;

return result;
}

////////////////////////////////////////////////////////////////////////////////
Expand All @@ -1603,9 +1616,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
// always insert objects at the end of the context's memory pool
struct ggml_object * obj_cur = ctx->objects_end;

const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
const size_t cur_end = cur_offset + cur_size;
const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
const size_t cur_end = cur_offs + cur_size;

size_t size_needed = 0;

Expand All @@ -1616,25 +1629,52 @@ struct ggml_tensor * ggml_new_tensor_impl(
}
// align to GGML_MEM_ALIGN
size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;

}
size_needed += sizeof(struct ggml_tensor);

if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
assert(false);
return NULL;
}

char * const mem_buffer = ctx->mem_buffer;

struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);

*obj_new = (struct ggml_object) {
.offset = cur_end + GGML_OBJECT_SIZE,
.size = size_needed,
.next = NULL,
};
if (ctx->scratch.data == NULL || data != NULL) {
size_needed += sizeof(struct ggml_tensor);

if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
assert(false);
return NULL;
}

*obj_new = (struct ggml_object) {
.offs = cur_end + GGML_OBJECT_SIZE,
.size = size_needed,
.next = NULL,
};
} else {
if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
assert(false);
return NULL;
}

if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
__func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
assert(false);
return NULL;
}

data = (char * const) ctx->scratch.data + ctx->scratch.offs;

*obj_new = (struct ggml_object) {
.offs = cur_end + GGML_OBJECT_SIZE,
.size = sizeof(struct ggml_tensor),
.next = NULL,
};

//printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);

ctx->scratch.offs += size_needed;
}

if (obj_cur != NULL) {
obj_cur->next = obj_new;
Expand All @@ -1645,9 +1685,9 @@ struct ggml_tensor * ggml_new_tensor_impl(

ctx->objects_end = obj_new;

//GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
//printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);

struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);

ggml_assert_aligned(result);

Expand Down Expand Up @@ -1690,7 +1730,7 @@ struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx,
enum ggml_type type,
int n_dims,
const int* ne) {
const int * ne) {
return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
}

Expand Down Expand Up @@ -1732,16 +1772,26 @@ struct ggml_tensor * ggml_new_tensor_4d(
}

struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
ctx->scratch_save = ctx->scratch;
ctx->scratch.data = NULL;

struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);

ctx->scratch = ctx->scratch_save;

ggml_set_i32(result, value);

return result;
}

struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
ctx->scratch_save = ctx->scratch;
ctx->scratch.data = NULL;

struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);

ctx->scratch = ctx->scratch_save;

ggml_set_f32(result, value);

return result;
Expand Down Expand Up @@ -2350,7 +2400,7 @@ struct ggml_tensor * ggml_repeat(
result->op = GGML_OP_REPEAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
result->src1 = b;

return result;
}
Expand Down Expand Up @@ -2966,9 +3016,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);

struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
((int32_t *) b->data)[0] = n_past;
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);

result->op = GGML_OP_DIAG_MASK_INF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
Expand Down Expand Up @@ -4300,7 +4348,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
const int ne1 = dst->ne[1];

// TODO: find the optimal values for these
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)
)) {
//printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
return true;
}
Expand Down Expand Up @@ -7289,6 +7339,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
//printf("cur = %zu\n", cur);
} else {
cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
}
Expand Down
Loading