From f990610776724ac4ccef3ed1afe90ddcdebc3269 Mon Sep 17 00:00:00 2001
From: Digipom
Date: Wed, 6 Sep 2023 11:32:30 -0400
Subject: [PATCH] whisper.android : address ARM's big.LITTLE arch by checking cpu info (#1254)

Addresses https://github.com/ggerganov/whisper.cpp/issues/1248
---
Reviewer notes (text between "---" and the first "diff --git" is not part of
the commit):

 * WhisperCpuConfig.preferredThreadCount is CpuInfo.getHighPerfCpuCount()
   coerced to at least 2. transcribeData() passes it to
   WhisperLib.fullTranscribe() as the new numThreads argument, and jni.c now
   assigns that value to params.n_threads instead of computing
   max(1, min(8, get_nprocs() - 2)).
 * CpuInfo works on the lines of /proc/cpuinfo. It first maps every
   "processor" entry to the value read from
   /sys/devices/system/cpu/cpuN/cpufreq/cpuinfo_max_freq and counts the CPUs
   whose value is strictly greater than the minimum. If that throws, it falls
   back to the "CPU variant" entries (parsed as hex after "0x") and counts the
   CPUs whose value equals the minimum.
 * If reading /proc/cpuinfo or both counting strategies throw, the result is
   availableProcessors() - 4, floored at 0 (and then raised to 2 by
   preferredThreadCount).
 * When every CPU reports the same max frequency the frequency-based count is
   0, so preferredThreadCount evaluates to 2.
 * NOTE(review): the author e-mail after "From: Digipom" appears to be missing
   from this copy -- confirm and restore it before running `git am`.

 .../com/whispercppdemo/whisper/LibWhisper.kt  |  6 +-
 .../whisper/WhisperCpuConfig.kt               | 73 +++++++++++++++++++
 .../app/src/main/jni/whisper/jni.c            |  8 +-
 3 files changed, 79 insertions(+), 8 deletions(-)
 create mode 100644 examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt

diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
index a2b651c..b0d6703 100644
--- a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
+++ b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
@@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {
 
     suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
         require(ptr != 0L)
-        WhisperLib.fullTranscribe(ptr, data)
+        val numThreads = WhisperCpuConfig.preferredThreadCount
+        Log.d(LOG_TAG, "Selecting $numThreads threads")
+        WhisperLib.fullTranscribe(ptr, numThreads, data)
         val textCount = WhisperLib.getTextSegmentCount(ptr)
         return@withContext buildString {
             for (i in 0 until textCount) {
@@ -126,7 +128,7 @@ private class WhisperLib {
         external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
         external fun initContext(modelPath: String): Long
         external fun freeContext(contextPtr: Long)
-        external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
+        external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
         external fun getTextSegmentCount(contextPtr: Long): Int
         external fun getTextSegment(contextPtr: Long, index: Int): String
         external fun getSystemInfo(): String
diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt
new file mode 100644
index 0000000..5fa9a4e
--- /dev/null
+++ b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt
@@ -0,0 +1,73 @@
+package com.whispercppdemo.whisper
+
+import android.util.Log
+import java.io.BufferedReader
+import java.io.FileReader
+
+object WhisperCpuConfig {
+    val preferredThreadCount: Int
+        // Always use at least 2 threads:
+        get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
+}
+
+private class CpuInfo(private val lines: List<String>) {
+    private fun getHighPerfCpuCount(): Int = try {
+        getHighPerfCpuCountByFrequencies()
+    } catch (e: Exception) {
+        Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
+        getHighPerfCpuCountByVariant()
+    }
+
+    private fun getHighPerfCpuCountByFrequencies(): Int =
+        getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
+            .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
+            .countDroppingMin()
+
+    private fun getHighPerfCpuCountByVariant(): Int =
+        getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
+            .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
+            .countKeepingMin()
+
+    private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
+
+    private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
+        .asSequence()
+        .filter { it.startsWith(property) }
+        .map { mapper(it.substringAfter(':').trim()) }
+        .sorted()
+        .toList()
+
+
+    private fun List<Int>.countDroppingMin(): Int {
+        val min = min()
+        return count { it > min }
+    }
+
+    private fun List<Int>.countKeepingMin(): Int {
+        val min = min()
+        return count { it == min }
+    }
+
+    companion object {
+        private const val LOG_TAG = "WhisperCpuConfig"
+
+        fun getHighPerfCpuCount(): Int = try {
+            readCpuInfo().getHighPerfCpuCount()
+        } catch (e: Exception) {
+            Log.d(LOG_TAG, "Couldn't read CPU info", e)
+            // Our best guess -- just return the # of CPUs minus 4.
+            (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
+        }
+
+        private fun readCpuInfo() = CpuInfo(
+            BufferedReader(FileReader("/proc/cpuinfo"))
+                .useLines { it.toList() }
+        )
+
+        private fun getMaxCpuFrequency(cpuIndex: Int): Int {
+            val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
+            val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
+            return maxFreq.toInt()
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/whisper.android/app/src/main/jni/whisper/jni.c b/examples/whisper.android/app/src/main/jni/whisper/jni.c
index 82dfd77..c437d09 100644
--- a/examples/whisper.android/app/src/main/jni/whisper/jni.c
+++ b/examples/whisper.android/app/src/main/jni/whisper/jni.c
@@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
 
 JNIEXPORT void JNICALL
 Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
-        JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
+        JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
     UNUSED(thiz);
     struct whisper_context *context = (struct whisper_context *) context_ptr;
     jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
     const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
 
-    // Leave 2 processors free (i.e. the high-efficiency cores).
-    int max_threads = max(1, min(8, get_nprocs() - 2));
-    LOGI("Selecting %d threads", max_threads);
-
     // The below adapted from the Objective-C iOS sample
     struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
     params.print_realtime = true;
@@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
     params.print_special = false;
     params.translate = false;
     params.language = "en";
-    params.n_threads = max_threads;
+    params.n_threads = num_threads;
     params.offset_ms = 0;
     params.no_context = true;
     params.single_segment = false;