whisper : slightly faster Log Mel computation + n-1 FFT threads (#568)

This commit is contained in:
Georgi Gerganov 2023-04-15 14:18:46 +03:00
parent 355da83690
commit 3dead611bb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2340,9 +2340,21 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
for (int j = 0; j < mel.n_mel; j++) { for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0; double sum = 0.0;
for (int k = 0; k < n_fft; k++) { // unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k]; sum += fft_out[k] * filters.data[j * n_fft + k];
} }
if (sum < 1e-10) { if (sum < 1e-10) {
sum = 1e-10; sum = 1e-10;
} }
@ -2383,17 +2395,19 @@ static bool log_mel_spectrogram(
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
if (n_threads == 1) { {
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); std::vector<std::thread> workers(n_threads - 1);
} else { for (int iw = 0; iw < n_threads - 1; ++iw) {
std::vector<std::thread> workers(n_threads); workers[iw] = std::thread(
for (int iw = 0; iw < n_threads; ++iw) { log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
n_samples, fft_size, fft_step, n_threads, n_samples, fft_size, fft_step, n_threads,
std::cref(filters), speed_up, std::ref(mel)); std::cref(filters), speed_up, std::ref(mel));
} }
for (int iw = 0; iw < n_threads; ++iw) { // main thread
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join(); workers[iw].join();
} }
} }