whisper : slightly faster Log Mel computation + n-1 FFT threads (#568)
This commit is contained in:
parent
355da83690
commit
3dead611bb
52
whisper.cpp
52
whisper.cpp
@ -2306,10 +2306,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|||||||
std::vector<float> fft_in(fft_size, 0.0);
|
std::vector<float> fft_in(fft_size, 0.0);
|
||||||
std::vector<float> fft_out(2 * fft_size);
|
std::vector<float> fft_out(2 * fft_size);
|
||||||
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
|
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);
|
||||||
|
|
||||||
for (int i = ith; i < mel.n_len; i += n_threads) {
|
for (int i = ith; i < mel.n_len; i += n_threads) {
|
||||||
const int offset = i * fft_step;
|
const int offset = i * fft_step;
|
||||||
|
|
||||||
// apply Hanning window
|
// apply Hanning window
|
||||||
for (int j = 0; j < fft_size; j++) {
|
for (int j = 0; j < fft_size; j++) {
|
||||||
if (offset + j < n_samples) {
|
if (offset + j < n_samples) {
|
||||||
@ -2318,37 +2318,49 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|||||||
fft_in[j] = 0.0;
|
fft_in[j] = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FFT -> mag^2
|
// FFT -> mag^2
|
||||||
fft(fft_in, fft_out);
|
fft(fft_in, fft_out);
|
||||||
|
|
||||||
for (int j = 0; j < fft_size; j++) {
|
for (int j = 0; j < fft_size; j++) {
|
||||||
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||||
}
|
}
|
||||||
for (int j = 1; j < fft_size / 2; j++) {
|
for (int j = 1; j < fft_size / 2; j++) {
|
||||||
fft_out[j] += fft_out[fft_size - j];
|
fft_out[j] += fft_out[fft_size - j];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (speed_up) {
|
if (speed_up) {
|
||||||
// scale down in the frequency domain results in a speed up in the time domain
|
// scale down in the frequency domain results in a speed up in the time domain
|
||||||
for (int j = 0; j < n_fft; j++) {
|
for (int j = 0; j < n_fft; j++) {
|
||||||
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
|
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mel spectrogram
|
// mel spectrogram
|
||||||
for (int j = 0; j < mel.n_mel; j++) {
|
for (int j = 0; j < mel.n_mel; j++) {
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
|
|
||||||
for (int k = 0; k < n_fft; k++) {
|
// unroll loop (suggested by GH user @lunixbochs)
|
||||||
|
int k = 0;
|
||||||
|
for (k = 0; k < n_fft - 3; k += 4) {
|
||||||
|
sum +=
|
||||||
|
fft_out[k + 0] * filters.data[j*n_fft + k + 0] +
|
||||||
|
fft_out[k + 1] * filters.data[j*n_fft + k + 1] +
|
||||||
|
fft_out[k + 2] * filters.data[j*n_fft + k + 2] +
|
||||||
|
fft_out[k + 3] * filters.data[j*n_fft + k + 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle n_fft remainder
|
||||||
|
for (; k < n_fft; k++) {
|
||||||
sum += fft_out[k] * filters.data[j * n_fft + k];
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sum < 1e-10) {
|
if (sum < 1e-10) {
|
||||||
sum = 1e-10;
|
sum = 1e-10;
|
||||||
}
|
}
|
||||||
|
|
||||||
sum = log10(sum);
|
sum = log10(sum);
|
||||||
|
|
||||||
mel.data[j * mel.n_len + i] = sum;
|
mel.data[j * mel.n_len + i] = sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2383,17 +2395,19 @@ static bool log_mel_spectrogram(
|
|||||||
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
|
//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
|
||||||
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
|
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
|
||||||
|
|
||||||
if (n_threads == 1) {
|
{
|
||||||
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
|
std::vector<std::thread> workers(n_threads - 1);
|
||||||
} else {
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||||
std::vector<std::thread> workers(n_threads);
|
workers[iw] = std::thread(
|
||||||
for (int iw = 0; iw < n_threads; ++iw) {
|
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples,
|
||||||
workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
|
n_samples, fft_size, fft_step, n_threads,
|
||||||
n_samples, fft_size, fft_step, n_threads,
|
std::cref(filters), speed_up, std::ref(mel));
|
||||||
std::cref(filters), speed_up, std::ref(mel));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int iw = 0; iw < n_threads; ++iw) {
|
// main thread
|
||||||
|
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
|
||||||
|
|
||||||
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||||
workers[iw].join();
|
workers[iw].join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user