From 77d929f60388f6d6e9c8c6439443505592704ed1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 2 Oct 2022 17:46:21 +0300 Subject: [PATCH] Fix bug in FFT The FFT routine does not work for odd N Solution is to add DFT and use it when N is odd --- main.cpp | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/main.cpp b/main.cpp index fb758e37ddf..b39f36016c9 100644 --- a/main.cpp +++ b/main.cpp @@ -1909,8 +1909,31 @@ whisper_vocab::id whisper_sample_timestamp( return probs_id[0].second; } +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +void dft(const std::vector & in, std::vector & out) { + int N = in.size(); + + out.resize(N*2); + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + float angle = 2*M_PI*k*n/N; + re += in[n]*cos(angle); + im -= in[n]*sin(angle); + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + // Cooley-Tukey FFT -// poor man's implmentation - use something better +// poor man's implementation - use something better // input is real-valued // output is complex-valued void fft(const std::vector & in, std::vector & out) { @@ -1924,6 +1947,11 @@ void fft(const std::vector & in, std::vector & out) { return; } + if (N%2 == 1) { + dft(in, out); + return; + } + std::vector even; std::vector odd; @@ -2014,9 +2042,20 @@ bool log_mel_spectrogram( // FFT -> mag^2 fft(fft_in, fft_out); - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < fft_size; j++) { fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); } + for (int j = 1; j < fft_size/2; j++) { + //if (i == 0) { + // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); + //} + fft_out[j] += fft_out[fft_size - j]; + } + if (i == 0) { + //for (int j = 0; j < fft_size; j++) { + // printf("%d: %e\n", j, fft_out[j]); + //} + } // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { @@ -2048,6 +2087,7 @@ bool log_mel_spectrogram( mmax = mel.data[i]; } } + //printf("%s: max = %f\n", __func__, mmax); mmax -= 8.0;