Skip to content

Commit

Permalink
Merge pull request ggerganov#8 from bobqianic/heuristic
Browse files Browse the repository at this point in the history
Heuristic
  • Loading branch information
bobqianic authored Feb 9, 2024
2 parents c0277e3 + de4f87f commit 476dff4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 31 deletions.
12 changes: 6 additions & 6 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ struct whisper_params {
float no_speech_thold = 0.60f;

bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
Expand All @@ -92,6 +91,7 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool suppress_nst = true; // suppress non speech tokens
bool heuristic = true;
bool log_score = false;
bool use_gpu = true;

Expand Down Expand Up @@ -143,8 +143,8 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params)
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-nst" || arg == "--nospeech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-snst" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nsnst"|| arg == "--no-suppress-nst") { params.suppress_nst = false; }
else if (arg == "-nh" || arg == "--no-heuristic") { params.heuristic = false; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
Expand Down Expand Up @@ -202,8 +202,8 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -nst N, --nospeech-thold N [%-7.2f] no-speech threshold for decoder fail\n", params.no_speech_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -snst, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nsnst, --no-suppress-nst [%-7s] do not suppress non-speech tokens\n", params.suppress_nst ? "false" : "true");
fprintf(stderr, " -nh, --no-heuristic [%-7s] do not use heuristic while decoding\n", params.heuristic ? "false" : "true");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
Expand Down Expand Up @@ -1014,7 +1014,7 @@ int run(int argc, const char ** argv) {
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
wparams.heuristic = params.heuristic;

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

Expand Down
8 changes: 0 additions & 8 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ struct whisper_params {
float temperature_inc = 0.20f;

bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
Expand Down Expand Up @@ -141,7 +140,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
Expand Down Expand Up @@ -186,7 +184,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
Expand Down Expand Up @@ -443,10 +440,6 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content);
}
if (req.has_file("debug_mode"))
{
params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content);
}
if (req.has_file("translate"))
{
params.translate = parse_str_to_bool(req.get_file_value("translate").content);
Expand Down Expand Up @@ -733,7 +726,6 @@ int main(int argc, char ** argv) {
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

Expand Down
59 changes: 43 additions & 16 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2845,7 +2845,6 @@ static bool log_mel_spectrogram(
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();

Expand Down Expand Up @@ -2917,17 +2916,6 @@ static bool log_mel_spectrogram(

wstate.t_mel_us += ggml_time_us() - t_start_us;

// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}

return true;
}

Expand Down Expand Up @@ -3598,7 +3586,7 @@ void whisper_free_params(struct whisper_full_params * params) {
}

int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
Expand All @@ -3612,7 +3600,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int

// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
Expand Down Expand Up @@ -4530,7 +4518,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_tokens =*/ 0,

/*.speed_up =*/ false,
/*.debug_mode =*/ false,
/*.audio_ctx =*/ 0,

/*.tdrz_enable =*/ false,
Expand All @@ -4544,6 +4531,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ true,
/*.heuristic =*/ true,

/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
Expand Down Expand Up @@ -4785,7 +4773,7 @@ static void whisper_no_speech_probs(
}

static const std::vector<std::string> non_speech_tokens = {
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"\"", "#", "*", "+", "/", ":", ";", "<", "=", ">", "@", "\\", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","", "", "", "", "", "", ""
Expand Down Expand Up @@ -5435,6 +5423,7 @@ int whisper_full_with_state(
}

int seek = seek_start;
bool fast_forward = false;

std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
Expand Down Expand Up @@ -5991,6 +5980,42 @@ int whisper_full_with_state(
if (best_decoder.sequence.no_speech_probs > params.no_speech_thold) {
if (best_decoder.sequence.avg_logprobs < params.logprob_thold) {
// fast-forward to the next segment boundary
prompt_past.clear();
fast_forward = true;
seek += std::min(3000, state->mel.n_len_org - seek);
continue;
}
}
}

// repetition check
{
if (!params.no_timestamps && params.heuristic) {
const auto & best_decoder = state->decoders[best_decoder_id];
const auto & tokens_cur = best_decoder.sequence.tokens;

std::set<std::string> table;
std::string text;
int timestamp_token_counter = 0;
int max_length = 0;

for (auto & token : tokens_cur) {
if (token.id < whisper_token_beg(ctx)) {
text += ctx->vocab.id_to_token[token.id];
} else {
timestamp_token_counter++;
}
if (timestamp_token_counter % 2 == 0) {
if (text.length() > max_length) {max_length = text.length();}
table.insert(text);
text.clear();
}
}

if ((static_cast<float>(table.size()) / static_cast<float>(timestamp_token_counter)) < 0.25 || max_length <= 4) {
// fast-forward to the next segment boundary
prompt_past.clear();
fast_forward = true;
seek += std::min(3000, state->mel.n_len_org - seek);
continue;
}
Expand Down Expand Up @@ -6113,6 +6138,8 @@ int whisper_full_with_state(
}
}

fast_forward = false;

// update audio window
// https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/transcribe.py#L353-L361
{
Expand Down
2 changes: 1 addition & 1 deletion whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ extern "C" {
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
int audio_ctx; // overwrite the audio context size (0 = use default)

// [EXPERIMENTAL] [TDRZ] tinydiarize
Expand All @@ -476,6 +475,7 @@ extern "C" {
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
bool heuristic;

float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
Expand Down

0 comments on commit 476dff4

Please sign in to comment.