diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 591a09f8f57..1b7fc9c78c4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -72,7 +72,6 @@ struct whisper_params { float no_speech_thold = 0.60f; bool speed_up = false; - bool debug_mode = false; bool translate = false; bool detect_language = false; bool diarize = false; @@ -92,6 +91,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool suppress_nst = true; // suppress non speech tokens + bool heuristic = true; bool log_score = false; bool use_gpu = true; @@ -143,8 +143,8 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-nst" || arg == "--nospeech-thold") { params.no_speech_thold = std::stof(argv[++i]); } // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-snst" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nsnst"|| arg == "--no-suppress-nst") { params.suppress_nst = false; } + else if (arg == "-nh" || arg == "--no-heuristic") { params.heuristic = false; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } @@ -202,8 +202,8 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -nst N, --nospeech-thold N [%-7.2f] no-speech threshold for decoder fail\n", params.no_speech_thold); // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -snst, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nsnst, --no-suppress-nst [%-7s] do not suppress non-speech tokens\n", params.suppress_nst ? "false" : "true"); + fprintf(stderr, " -nh, --no-heuristic [%-7s] do not use heuristic while decoding\n", params.heuristic ? "false" : "true"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); @@ -1014,7 +1014,7 @@ int run(int argc, const char ** argv) { wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.speed_up = params.speed_up; - wparams.debug_mode = params.debug_mode; + wparams.heuristic = params.heuristic; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b9c566a3a68..cf0157d4a58 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -68,7 +68,6 @@ struct whisper_params { float temperature_inc = 0.20f; bool speed_up = false; - bool debug_mode = false; bool translate = false; bool detect_language = false; bool diarize = false; @@ -141,7 +140,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); @@ -186,7 +184,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } @@ -443,10 +440,6 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content); } - if (req.has_file("debug_mode")) - { - params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content); - } if (req.has_file("translate")) { params.translate = parse_str_to_bool(req.get_file_value("translate").content); @@ -733,7 +726,6 @@ int main(int argc, char ** argv) { wparams.max_len = params.max_len == 0 ? 60 : params.max_len; wparams.speed_up = params.speed_up; - wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/whisper.cpp b/whisper.cpp index 79e0e5eb092..dbea6058cd9 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2845,7 +2845,6 @@ static bool log_mel_spectrogram( const int n_mel, const int n_threads, const whisper_filters & filters, - const bool debug, whisper_mel & mel) { const int64_t t_start_us = ggml_time_us(); @@ -2917,17 +2916,6 @@ static bool log_mel_spectrogram( wstate.t_mel_us += ggml_time_us() - t_start_us; - // Dump log_mel_spectrogram - if (debug) { - std::ofstream outFile("log_mel_spectrogram.json"); - outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { - outFile << mel.data[i] << ", "; - } - outFile << mel.data[mel.data.size() - 1] << "]"; - outFile.close(); - } - return true; } @@ -3598,7 +3586,7 @@ void whisper_free_params(struct whisper_full_params * params) { } int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) { WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3612,7 +3600,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) { WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -4530,7 +4518,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_tokens =*/ 0, /*.speed_up =*/ false, - /*.debug_mode =*/ false, /*.audio_ctx =*/ 0, /*.tdrz_enable =*/ false, @@ -4544,6 +4531,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ true, + /*.heuristic =*/ true, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -4785,7 +4773,7 @@ static void whisper_no_speech_probs( } static const std::vector non_speech_tokens = { - "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "\"", "#", "*", "+", "/", ":", ";", "<", "=", ">", "@", "\\", "^", "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" @@ -5435,6 +5423,7 @@ int whisper_full_with_state( } int seek = seek_start; + bool fast_forward = false; std::vector prompt; prompt.reserve(whisper_n_text_ctx(ctx)); @@ -5991,6 +5980,42 @@ int whisper_full_with_state( if (best_decoder.sequence.no_speech_probs > params.no_speech_thold) { if (best_decoder.sequence.avg_logprobs < params.logprob_thold) { // fast-forward to the next segment boundary + prompt_past.clear(); + fast_forward = true; + seek += std::min(3000, state->mel.n_len_org - seek); + continue; + } + } + } + + // repetition check + { + if (!params.no_timestamps && params.heuristic) { + const auto & best_decoder = state->decoders[best_decoder_id]; + const auto & tokens_cur = best_decoder.sequence.tokens; + + std::set table; + std::string text; + int timestamp_token_counter = 0; + int max_length = 0; + + for (auto & token : tokens_cur) { + if (token.id < whisper_token_beg(ctx)) { + text += ctx->vocab.id_to_token[token.id]; + } else { + timestamp_token_counter++; + } + if (timestamp_token_counter % 2 == 0) { + if (text.length() > max_length) {max_length = text.length();} + table.insert(text); + text.clear(); + } + } + + if ((static_cast(table.size()) / static_cast(timestamp_token_counter)) < 0.25 || max_length <= 4) { + // fast-forward to the next segment boundary + prompt_past.clear(); + fast_forward = true; seek += std::min(3000, state->mel.n_len_org - seek); continue; } @@ -6113,6 +6138,8 @@ int whisper_full_with_state( } } + fast_forward = false; + // update audio window // https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/transcribe.py#L353-L361 { diff --git a/whisper.h b/whisper.h index f9bdf36a0ce..137ecc8b9bf 100644 --- a/whisper.h +++ b/whisper.h @@ -457,7 +457,6 @@ extern "C" { // [EXPERIMENTAL] speed-up techniques // note: these can significantly reduce the quality of the output bool speed_up; // speed-up the audio by 2x using Phase Vocoder - bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) int audio_ctx; // overwrite the audio context size (0 = use default) // [EXPERIMENTAL] [TDRZ] tinydiarize @@ -476,6 +475,7 @@ extern "C" { // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + bool heuristic; float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97