Feature/java bindings2 (#944)
* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work. * added convenience methods to WhisperFullParams * Remove unused WhisperJavaParams
This commit is contained in:
parent
9b926844e3
commit
d7c936b44a
47
.github/workflows/build.yml
vendored
47
.github/workflows/build.yml
vendored
@ -125,8 +125,10 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- arch: Win32
|
- arch: Win32
|
||||||
s2arc: x86
|
s2arc: x86
|
||||||
|
jnaPath: win32-x86
|
||||||
- arch: x64
|
- arch: x64
|
||||||
s2arc: x64
|
s2arc: x64
|
||||||
|
jnaPath: win32-x86-64
|
||||||
- sdl2: ON
|
- sdl2: ON
|
||||||
s2ver: 2.26.0
|
s2ver: 2.26.0
|
||||||
|
|
||||||
@ -159,6 +161,12 @@ jobs:
|
|||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
|
- name: Upload dll
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: ${{ matrix.jnaPath }}_whisper.dll
|
||||||
|
path: build/bin/${{ matrix.build }}/whisper.dll
|
||||||
|
|
||||||
- name: Upload binaries
|
- name: Upload binaries
|
||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
uses: actions/upload-artifact@v1
|
uses: actions/upload-artifact@v1
|
||||||
@ -363,3 +371,42 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
cd examples/whisper.android
|
cd examples/whisper.android
|
||||||
./gradlew assembleRelease --no-daemon
|
./gradlew assembleRelease --no-daemon
|
||||||
|
|
||||||
|
java:
|
||||||
|
needs: [ 'windows' ]
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v1
|
||||||
|
|
||||||
|
- name: Install Java
|
||||||
|
uses: actions/setup-java@v1
|
||||||
|
with:
|
||||||
|
java-version: 17
|
||||||
|
|
||||||
|
- name: Download Windows lib
|
||||||
|
uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: win32-x86-64_whisper.dll
|
||||||
|
path: bindings/java/build/generated/resources/main/win32-x86-64
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: |
|
||||||
|
models\download-ggml-model.cmd tiny.en
|
||||||
|
cd bindings/java
|
||||||
|
chmod +x ./gradlew
|
||||||
|
./gradlew build
|
||||||
|
|
||||||
|
- name: Upload jar
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: whispercpp.jar
|
||||||
|
path: bindings/java/build/libs/whispercpp-*.jar
|
||||||
|
|
||||||
|
# - name: Publish package
|
||||||
|
# if: ${{ github.ref == 'refs/heads/master' }}
|
||||||
|
# uses: gradle/gradle-build-action@v2
|
||||||
|
# with:
|
||||||
|
# arguments: publish
|
||||||
|
# env:
|
||||||
|
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
|
||||||
|
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
|
||||||
|
@ -1,50 +0,0 @@
|
|||||||
cmake_minimum_required(VERSION 3.10)
|
|
||||||
|
|
||||||
project(whisper_java VERSION 1.4.2)
|
|
||||||
|
|
||||||
# Set the target name and source file/s
|
|
||||||
set(TARGET_NAME whisper_java)
|
|
||||||
set(SOURCES src/main/cpp/whisper_java.cpp)
|
|
||||||
|
|
||||||
# include <whisper.h>
|
|
||||||
include_directories(../../)
|
|
||||||
|
|
||||||
# Set the output directory for the DLL/shared library based on the platform as required by JNA
|
|
||||||
if(WIN32)
|
|
||||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)
|
|
||||||
elseif(UNIX AND NOT APPLE)
|
|
||||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)
|
|
||||||
elseif(APPLE)
|
|
||||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
|
||||||
|
|
||||||
# Create the whisper_java library
|
|
||||||
add_library(${TARGET_NAME} SHARED ${SOURCES})
|
|
||||||
|
|
||||||
# Link against ../../build/Release/whisper.dll (or so/dynlib)
|
|
||||||
target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})
|
|
||||||
target_link_libraries(${TARGET_NAME} PRIVATE whisper)
|
|
||||||
|
|
||||||
# Set the appropriate compiler flags for Windows, Linux, and macOS
|
|
||||||
if(WIN32)
|
|
||||||
target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)
|
|
||||||
elseif(UNIX AND NOT APPLE)
|
|
||||||
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
|
||||||
elseif(APPLE)
|
|
||||||
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)
|
|
||||||
# add_definitions(-DWHISPER_SHARED)
|
|
||||||
|
|
||||||
# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA
|
|
||||||
foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})
|
|
||||||
string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)
|
|
||||||
set_target_properties(${TARGET_NAME} PROPERTIES
|
|
||||||
RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
|
||||||
LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
|
||||||
ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})
|
|
||||||
endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)
|
|
@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
|
|||||||
* Ubuntu on x86_64
|
* Ubuntu on x86_64
|
||||||
* Windows on x86_64
|
* Windows on x86_64
|
||||||
|
|
||||||
The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`.
|
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
|
||||||
|
|
||||||
There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
|
|
||||||
|
|
||||||
The most simple usage is as follows:
|
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import io.github.ggerganov.whispercpp.WhisperCpp;
|
import io.github.ggerganov.whispercpp.WhisperCpp;
|
||||||
@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
|
|||||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||||
cd whisper.cpp/bindings/java
|
cd whisper.cpp/bindings/java
|
||||||
|
|
||||||
mkdir build
|
|
||||||
pushd build
|
|
||||||
cmake ..
|
|
||||||
cmake --build .
|
|
||||||
popd
|
|
||||||
|
|
||||||
./gradlew build
|
./gradlew build
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -22,6 +22,12 @@ sourceSets {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tasks.register('copyLibwhisperDynlib', Copy) {
|
||||||
|
from '../../build'
|
||||||
|
include 'libwhisper.dynlib'
|
||||||
|
into 'build/generated/resources/main/darwin'
|
||||||
|
}
|
||||||
|
|
||||||
tasks.register('copyLibwhisperSo', Copy) {
|
tasks.register('copyLibwhisperSo', Copy) {
|
||||||
from '../../build'
|
from '../../build'
|
||||||
include 'libwhisper.so'
|
include 'libwhisper.so'
|
||||||
@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
|
|||||||
into 'build/generated/resources/main/windows-x86-64'
|
into 'build/generated/resources/main/windows-x86-64'
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll
|
tasks.register('copyLibs') {
|
||||||
|
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
|
||||||
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
|
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
#include <stdio.h>
|
|
||||||
#include "whisper_java.h"
|
|
||||||
|
|
||||||
struct whisper_full_params default_params;
|
|
||||||
struct whisper_context * whisper_ctx = nullptr;
|
|
||||||
|
|
||||||
struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
|
|
||||||
default_params = whisper_full_default_params(strategy);
|
|
||||||
|
|
||||||
// struct whisper_java_params result = {};
|
|
||||||
// return result;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
void whisper_java_init_from_file(const char * path_model) {
|
|
||||||
whisper_ctx = whisper_init_from_file(path_model);
|
|
||||||
if (0 == default_params.n_threads) {
|
|
||||||
whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Delegates to whisper_full, but without having to pass `whisper_full_params` */
|
|
||||||
int whisper_java_full(
|
|
||||||
struct whisper_context * ctx,
|
|
||||||
// struct whisper_java_params params,
|
|
||||||
const float * samples,
|
|
||||||
int n_samples) {
|
|
||||||
return whisper_full(ctx, default_params, samples, n_samples);
|
|
||||||
}
|
|
||||||
|
|
||||||
void whisper_java_free() {
|
|
||||||
// free(default_params);
|
|
||||||
}
|
|
@ -1,24 +0,0 @@
|
|||||||
#define WHISPER_BUILD
|
|
||||||
#include <whisper.h>
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
struct whisper_java_params {
|
|
||||||
};
|
|
||||||
|
|
||||||
WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
|
|
||||||
|
|
||||||
WHISPER_API void whisper_java_init_from_file(const char * path_model);
|
|
||||||
|
|
||||||
WHISPER_API int whisper_java_full(
|
|
||||||
struct whisper_context * ctx,
|
|
||||||
// struct whisper_java_params params,
|
|
||||||
const float * samples,
|
|
||||||
int n_samples);
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
@ -1,7 +1,8 @@
|
|||||||
package io.github.ggerganov.whispercpp;
|
package io.github.ggerganov.whispercpp;
|
||||||
|
|
||||||
|
import com.sun.jna.Native;
|
||||||
import com.sun.jna.Pointer;
|
import com.sun.jna.Pointer;
|
||||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
@ -13,8 +14,9 @@ import java.io.IOException;
|
|||||||
*/
|
*/
|
||||||
public class WhisperCpp implements AutoCloseable {
|
public class WhisperCpp implements AutoCloseable {
|
||||||
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
||||||
private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
|
|
||||||
private Pointer ctx = null;
|
private Pointer ctx = null;
|
||||||
|
private Pointer greedyPointer = null;
|
||||||
|
private Pointer beamPointer = null;
|
||||||
|
|
||||||
public File modelDir() {
|
public File modelDir() {
|
||||||
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
||||||
@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
||||||
* @return a Pointer to the WhisperContext
|
|
||||||
*/
|
*/
|
||||||
void initContext(String modelPath) throws FileNotFoundException {
|
public void initContext(String modelPath) throws FileNotFoundException {
|
||||||
if (ctx != null) {
|
if (ctx != null) {
|
||||||
lib.whisper_free(ctx);
|
lib.whisper_free(ctx);
|
||||||
}
|
}
|
||||||
@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
|
|||||||
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
||||||
}
|
}
|
||||||
|
|
||||||
javaLib.whisper_java_init_from_file(modelPath);
|
|
||||||
ctx = lib.whisper_init_from_file(modelPath);
|
ctx = lib.whisper_init_from_file(modelPath);
|
||||||
|
|
||||||
if (ctx == null) {
|
if (ctx == null) {
|
||||||
@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.
|
* Provides default params which can be used with `whisper_full()` etc.
|
||||||
* `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.
|
* Because this function allocates memory for the params, the caller must call either:
|
||||||
|
* - call `whisper_free_params()`
|
||||||
|
* - `Native.free(Pointer.nativeValue(pointer));`
|
||||||
|
*
|
||||||
|
* @param strategy - GREEDY
|
||||||
*/
|
*/
|
||||||
public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {
|
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
|
||||||
javaLib.whisper_java_default_params(strategy.ordinal());
|
Pointer pointer;
|
||||||
// return lib.whisper_full_default_params(strategy.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params
|
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
|
||||||
// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {
|
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
|
||||||
// return lib.whisper_full_default_params(strategy.value)
|
if (greedyPointer == null) {
|
||||||
// }
|
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
||||||
|
}
|
||||||
|
pointer = greedyPointer;
|
||||||
|
} else {
|
||||||
|
if (beamPointer == null) {
|
||||||
|
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
||||||
|
}
|
||||||
|
pointer = beamPointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
WhisperFullParams params = new WhisperFullParams(pointer);
|
||||||
|
params.read();
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() {
|
public void close() {
|
||||||
freeContext();
|
freeContext();
|
||||||
|
freeParams();
|
||||||
System.out.println("Whisper closed");
|
System.out.println("Whisper closed");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void freeParams() {
|
||||||
|
if (greedyPointer != null) {
|
||||||
|
Native.free(Pointer.nativeValue(greedyPointer));
|
||||||
|
greedyPointer = null;
|
||||||
|
}
|
||||||
|
if (beamPointer != null) {
|
||||||
|
Native.free(Pointer.nativeValue(beamPointer));
|
||||||
|
beamPointer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||||
* Not thread safe for same context
|
* Not thread safe for same context
|
||||||
* Uses the specified decoding strategy to obtain the text.
|
* Uses the specified decoding strategy to obtain the text.
|
||||||
*/
|
*/
|
||||||
public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {
|
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||||
if (ctx == null) {
|
if (ctx == null) {
|
||||||
throw new IllegalStateException("Model not initialised");
|
throw new IllegalStateException("Model not initialised");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {
|
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||||
throw new IOException("Failed to process audio");
|
throw new IOException("Failed to process audio");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
|
|||||||
void whisper_print_timings(Pointer ctx);
|
void whisper_print_timings(Pointer ctx);
|
||||||
void whisper_reset_timings(Pointer ctx);
|
void whisper_reset_timings(Pointer ctx);
|
||||||
|
|
||||||
|
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
|
||||||
|
// when `whisper_full_default_params()` tries to return a struct.
|
||||||
|
// WhisperFullParams whisper_full_default_params(int strategy);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Provides default params which can be used with `whisper_full()` etc.
|
||||||
|
* Because this function allocates memory for the params, the caller must call either:
|
||||||
|
* - call `whisper_free_params()`
|
||||||
|
* - `Native.free(Pointer.nativeValue(pointer));`
|
||||||
|
*
|
||||||
* @param strategy - WhisperSamplingStrategy.value
|
* @param strategy - WhisperSamplingStrategy.value
|
||||||
*/
|
*/
|
||||||
WhisperFullParams whisper_full_default_params(int strategy);
|
Pointer whisper_full_default_params_by_ref(int strategy);
|
||||||
|
|
||||||
|
void whisper_free_params(Pointer params);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
package io.github.ggerganov.whispercpp;
|
|
||||||
|
|
||||||
import com.sun.jna.Library;
|
|
||||||
import com.sun.jna.Native;
|
|
||||||
import com.sun.jna.Pointer;
|
|
||||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
|
||||||
|
|
||||||
interface WhisperJavaJnaLibrary extends Library {
|
|
||||||
WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
|
|
||||||
|
|
||||||
void whisper_java_default_params(int strategy);
|
|
||||||
|
|
||||||
void whisper_java_free();
|
|
||||||
|
|
||||||
void whisper_java_init_from_file(String modelPath);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
|
||||||
* Not thread safe for same context
|
|
||||||
* Uses the specified decoding strategy to obtain the text.
|
|
||||||
*/
|
|
||||||
int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
|
|
||||||
}
|
|
@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
|
|||||||
* @param user_data User data.
|
* @param user_data User data.
|
||||||
* @return True if the computation should proceed, false otherwise.
|
* @return True if the computation should proceed, false otherwise.
|
||||||
*/
|
*/
|
||||||
boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);
|
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
package io.github.ggerganov.whispercpp.callbacks;
|
package io.github.ggerganov.whispercpp.callbacks;
|
||||||
|
|
||||||
|
import com.sun.jna.Callback;
|
||||||
import com.sun.jna.Pointer;
|
import com.sun.jna.Pointer;
|
||||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
|
||||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
|
||||||
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
||||||
|
|
||||||
import javax.security.auth.callback.Callback;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback to filter logits.
|
* Callback to filter logits.
|
||||||
* Can be used to modify the logits before sampling.
|
* Can be used to modify the logits before sampling.
|
||||||
@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
|
|||||||
* @param logits The array of logits.
|
* @param logits The array of logits.
|
||||||
* @param user_data User data.
|
* @param user_data User data.
|
||||||
*/
|
*/
|
||||||
void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
||||||
}
|
}
|
||||||
|
@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
|
|||||||
* @param n_new The number of newly generated text segments.
|
* @param n_new The number of newly generated text segments.
|
||||||
* @param user_data User data.
|
* @param user_data User data.
|
||||||
*/
|
*/
|
||||||
void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);
|
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
package io.github.ggerganov.whispercpp.callbacks;
|
package io.github.ggerganov.whispercpp.callbacks;
|
||||||
|
|
||||||
|
import com.sun.jna.Callback;
|
||||||
import com.sun.jna.Pointer;
|
import com.sun.jna.Pointer;
|
||||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||||
|
|
||||||
import javax.security.auth.callback.Callback;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback for progress updates.
|
* Callback for progress updates.
|
||||||
*/
|
*/
|
||||||
@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
|
|||||||
* @param progress The progress value.
|
* @param progress The progress value.
|
||||||
* @param user_data User data.
|
* @param user_data User data.
|
||||||
*/
|
*/
|
||||||
void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);
|
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
package io.github.ggerganov.whispercpp.params;
|
||||||
|
|
||||||
|
import com.sun.jna.Structure;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class BeamSearchParams extends Structure {
|
||||||
|
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
|
||||||
|
public int beam_size;
|
||||||
|
|
||||||
|
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
|
||||||
|
public float patience;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected List<String> getFieldOrder() {
|
||||||
|
return Arrays.asList("beam_size", "patience");
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,30 @@
|
|||||||
|
package io.github.ggerganov.whispercpp.params;
|
||||||
|
|
||||||
|
import com.sun.jna.IntegerType;
|
||||||
|
|
||||||
|
import java.util.function.BooleanSupplier;
|
||||||
|
|
||||||
|
public class CBool extends IntegerType implements BooleanSupplier {
|
||||||
|
public static final int SIZE = 1;
|
||||||
|
public static final CBool FALSE = new CBool(0);
|
||||||
|
public static final CBool TRUE = new CBool(1);
|
||||||
|
|
||||||
|
|
||||||
|
public CBool() {
|
||||||
|
this(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CBool(long value) {
|
||||||
|
super(SIZE, value, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean getAsBoolean() {
|
||||||
|
return intValue() == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return intValue() == 1 ? "true" : "false";
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
package io.github.ggerganov.whispercpp.params;
|
||||||
|
|
||||||
|
import com.sun.jna.Structure;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class GreedyParams extends Structure {
|
||||||
|
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
|
||||||
|
public int best_of;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected List<String> getFieldOrder() {
|
||||||
|
return Collections.singletonList("best_of");
|
||||||
|
}
|
||||||
|
}
|
@ -1,13 +1,14 @@
|
|||||||
package io.github.ggerganov.whispercpp.params;
|
package io.github.ggerganov.whispercpp.params;
|
||||||
|
|
||||||
import com.sun.jna.Callback;
|
import com.sun.jna.*;
|
||||||
import com.sun.jna.Pointer;
|
|
||||||
import com.sun.jna.Structure;
|
|
||||||
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
||||||
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
||||||
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
||||||
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameters for the whisper_full() function.
|
* Parameters for the whisper_full() function.
|
||||||
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
||||||
@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
|||||||
*/
|
*/
|
||||||
public class WhisperFullParams extends Structure {
|
public class WhisperFullParams extends Structure {
|
||||||
|
|
||||||
|
public WhisperFullParams(Pointer p) {
|
||||||
|
super(p);
|
||||||
|
// super(p, ALIGN_MSVC);
|
||||||
|
// super(p, ALIGN_GNUC);
|
||||||
|
}
|
||||||
|
|
||||||
/** Sampling strategy for whisper_full() function. */
|
/** Sampling strategy for whisper_full() function. */
|
||||||
public int strategy;
|
public int strategy;
|
||||||
|
|
||||||
/** Number of threads. */
|
/** Number of threads. (default = 4) */
|
||||||
public int n_threads;
|
public int n_threads;
|
||||||
|
|
||||||
/** Maximum tokens to use from past text as a prompt for the decoder. */
|
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
|
||||||
public int n_max_text_ctx;
|
public int n_max_text_ctx;
|
||||||
|
|
||||||
/** Start offset in milliseconds. */
|
/** Start offset in milliseconds. (default = 0) */
|
||||||
public int offset_ms;
|
public int offset_ms;
|
||||||
|
|
||||||
/** Audio duration to process in milliseconds. */
|
/** Audio duration to process in milliseconds. (default = 0) */
|
||||||
public int duration_ms;
|
public int duration_ms;
|
||||||
|
|
||||||
/** Translate flag. */
|
/** Translate flag. (default = false) */
|
||||||
public boolean translate;
|
public CBool translate;
|
||||||
|
|
||||||
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */
|
/** The compliment of translateMode() */
|
||||||
public boolean no_context;
|
public void transcribeMode() {
|
||||||
|
translate = CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Flag to force single segment output (useful for streaming). */
|
/** The compliment of transcribeMode() */
|
||||||
public boolean single_segment;
|
public void translateMode() {
|
||||||
|
translate = CBool.TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */
|
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
||||||
public boolean print_special;
|
public CBool no_context;
|
||||||
|
|
||||||
/** Flag to print progress information. */
|
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
||||||
public boolean print_progress;
|
public void enableContext(boolean enable) {
|
||||||
|
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). */
|
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||||
public boolean print_realtime;
|
public CBool single_segment;
|
||||||
|
|
||||||
/** Flag to print timestamps for each text segment when printing realtime. */
|
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||||
public boolean print_timestamps;
|
public void singleSegment(boolean single) {
|
||||||
|
single_segment = single ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
/** [EXPERIMENTAL] Flag to enable token-level timestamps. */
|
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
||||||
public boolean token_timestamps;
|
public CBool print_special;
|
||||||
|
|
||||||
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */
|
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
||||||
|
public void printSpecial(boolean enable) {
|
||||||
|
print_special = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Flag to print progress information. (default = true) */
|
||||||
|
public CBool print_progress;
|
||||||
|
|
||||||
|
/** Flag to print progress information. (default = true) */
|
||||||
|
public void printProgress(boolean enable) {
|
||||||
|
print_progress = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
||||||
|
public CBool print_realtime;
|
||||||
|
|
||||||
|
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
||||||
|
public void printRealtime(boolean enable) {
|
||||||
|
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
||||||
|
public CBool print_timestamps;
|
||||||
|
|
||||||
|
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
||||||
|
public void printTimestamps(boolean enable) {
|
||||||
|
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
||||||
|
public CBool token_timestamps;
|
||||||
|
|
||||||
|
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
||||||
|
public void tokenTimestamps(boolean enable) {
|
||||||
|
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
|
||||||
public float thold_pt;
|
public float thold_pt;
|
||||||
|
|
||||||
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
||||||
public float thold_ptsum;
|
public float thold_ptsum;
|
||||||
|
|
||||||
/** Maximum segment length in characters. */
|
/** Maximum segment length in characters. (default = 0) */
|
||||||
public int max_len;
|
public int max_len;
|
||||||
|
|
||||||
/** Flag to split on word rather than on token (when used with max_len). */
|
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
||||||
public boolean split_on_word;
|
public CBool split_on_word;
|
||||||
|
|
||||||
/** Maximum tokens per segment (0 = no limit). */
|
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
||||||
|
public void splitOnWord(boolean enable) {
|
||||||
|
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Maximum tokens per segment (0, default = no limit) */
|
||||||
public int max_tokens;
|
public int max_tokens;
|
||||||
|
|
||||||
/** Flag to speed up the audio by 2x using Phase Vocoder. */
|
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||||
public boolean speed_up;
|
public CBool speed_up;
|
||||||
|
|
||||||
|
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||||
|
public void speedUp(boolean enable) {
|
||||||
|
speed_up = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Overwrite the audio context size (0 = use default). */
|
/** Overwrite the audio context size (0 = use default). */
|
||||||
public int audio_ctx;
|
public int audio_ctx;
|
||||||
@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
|
|||||||
* These are prepended to any existing text context from a previous call. */
|
* These are prepended to any existing text context from a previous call. */
|
||||||
public String initial_prompt;
|
public String initial_prompt;
|
||||||
|
|
||||||
/** Prompt tokens. */
|
/** Prompt tokens. (int*) */
|
||||||
public Pointer prompt_tokens;
|
public Pointer prompt_tokens;
|
||||||
|
|
||||||
|
public void setPromptTokens(int[] tokens) {
|
||||||
|
Memory mem = new Memory(tokens.length * 4L);
|
||||||
|
mem.write(0, tokens, 0, tokens.length);
|
||||||
|
prompt_tokens = mem;
|
||||||
|
}
|
||||||
|
|
||||||
/** Number of prompt tokens. */
|
/** Number of prompt tokens. */
|
||||||
public int prompt_n_tokens;
|
public int prompt_n_tokens;
|
||||||
|
|
||||||
@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
|
|||||||
public String language;
|
public String language;
|
||||||
|
|
||||||
/** Flag to indicate whether to detect language automatically. */
|
/** Flag to indicate whether to detect language automatically. */
|
||||||
public boolean detect_language;
|
public CBool detect_language;
|
||||||
|
|
||||||
/** Common decoding parameters. */
|
/** Flag to indicate whether to detect language automatically. */
|
||||||
|
public void detectLanguage(boolean enable) {
|
||||||
|
detect_language = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common decoding parameters.
|
||||||
|
|
||||||
/** Flag to suppress blank tokens. */
|
/** Flag to suppress blank tokens. */
|
||||||
public boolean suppress_blank;
|
public CBool suppress_blank;
|
||||||
|
|
||||||
|
public void suppressBlanks(boolean enable) {
|
||||||
|
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Flag to suppress non-speech tokens. */
|
/** Flag to suppress non-speech tokens. */
|
||||||
public boolean suppress_non_speech_tokens;
|
public CBool suppress_non_speech_tokens;
|
||||||
|
|
||||||
|
/** Flag to suppress non-speech tokens. */
|
||||||
|
public void suppressNonSpeechTokens(boolean enable) {
|
||||||
|
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
/** Initial decoding temperature. */
|
/** Initial decoding temperature. */
|
||||||
public float temperature;
|
public float temperature;
|
||||||
@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
|
|||||||
/** Length penalty. */
|
/** Length penalty. */
|
||||||
public float length_penalty;
|
public float length_penalty;
|
||||||
|
|
||||||
/** Fallback parameters. */
|
// Fallback parameters.
|
||||||
|
|
||||||
/** Temperature increment. */
|
/** Temperature increment. */
|
||||||
public float temperature_inc;
|
public float temperature_inc;
|
||||||
@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
|
|||||||
/** No speech threshold. */
|
/** No speech threshold. */
|
||||||
public float no_speech_thold;
|
public float no_speech_thold;
|
||||||
|
|
||||||
class GreedyParams extends Structure {
|
|
||||||
/** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
|
|
||||||
public int best_of;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Greedy decoding parameters. */
|
/** Greedy decoding parameters. */
|
||||||
public GreedyParams greedy;
|
public GreedyParams greedy;
|
||||||
|
|
||||||
class BeamSearchParams extends Structure {
|
|
||||||
/** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
|
|
||||||
int beam_size;
|
|
||||||
|
|
||||||
/** ref: https://arxiv.org/pdf/2204.05424.pdf */
|
|
||||||
float patience;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Beam search decoding parameters.
|
* Beam search decoding parameters.
|
||||||
*/
|
*/
|
||||||
public BeamSearchParams beam_search;
|
public BeamSearchParams beam_search;
|
||||||
|
|
||||||
|
public void setBestOf(int bestOf) {
|
||||||
|
if (greedy == null) {
|
||||||
|
greedy = new GreedyParams();
|
||||||
|
}
|
||||||
|
greedy.best_of = bestOf;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBeamSize(int beamSize) {
|
||||||
|
if (beam_search == null) {
|
||||||
|
beam_search = new BeamSearchParams();
|
||||||
|
}
|
||||||
|
beam_search.beam_size = beamSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBeamSizeAndPatience(int beamSize, float patience) {
|
||||||
|
if (beam_search == null) {
|
||||||
|
beam_search = new BeamSearchParams();
|
||||||
|
}
|
||||||
|
beam_search.beam_size = beamSize;
|
||||||
|
beam_search.patience = patience;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback for every newly generated text segment.
|
* Callback for every newly generated text segment.
|
||||||
|
* WhisperNewSegmentCallback
|
||||||
*/
|
*/
|
||||||
public WhisperNewSegmentCallback new_segment_callback;
|
public Pointer new_segment_callback;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User data for the new_segment_callback.
|
* User data for the new_segment_callback.
|
||||||
@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback on each progress update.
|
* Callback on each progress update.
|
||||||
|
* WhisperProgressCallback
|
||||||
*/
|
*/
|
||||||
public WhisperProgressCallback progress_callback;
|
public Pointer progress_callback;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User data for the progress_callback.
|
* User data for the progress_callback.
|
||||||
@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback each time before the encoder starts.
|
* Callback each time before the encoder starts.
|
||||||
|
* WhisperEncoderBeginCallback
|
||||||
*/
|
*/
|
||||||
public WhisperEncoderBeginCallback encoder_begin_callback;
|
public Pointer encoder_begin_callback;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User data for the encoder_begin_callback.
|
* User data for the encoder_begin_callback.
|
||||||
@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Callback by each decoder to filter obtained logits.
|
* Callback by each decoder to filter obtained logits.
|
||||||
|
* WhisperLogitsFilterCallback
|
||||||
*/
|
*/
|
||||||
public WhisperLogitsFilterCallback logits_filter_callback;
|
public Pointer logits_filter_callback;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User data for the logits_filter_callback.
|
* User data for the logits_filter_callback.
|
||||||
*/
|
*/
|
||||||
public Pointer logits_filter_callback_user_data;
|
public Pointer logits_filter_callback_user_data;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
|
||||||
|
new_segment_callback = CallbackReference.getFunctionPointer(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setProgressCallback(WhisperProgressCallback callback) {
|
||||||
|
progress_callback = CallbackReference.getFunctionPointer(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
|
||||||
|
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
|
||||||
|
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected List<String> getFieldOrder() {
|
||||||
|
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||||
|
"no_context", "single_segment",
|
||||||
|
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||||
|
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||||
|
"initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||||
|
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
||||||
|
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
||||||
|
"new_segment_callback", "new_segment_callback_user_data",
|
||||||
|
"progress_callback", "progress_callback_user_data",
|
||||||
|
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||||
|
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
package io.github.ggerganov.whispercpp.params;
|
|
||||||
|
|
||||||
import com.sun.jna.Structure;
|
|
||||||
|
|
||||||
public class WhisperJavaParams extends Structure {
|
|
||||||
|
|
||||||
}
|
|
@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
import io.github.ggerganov.whispercpp.params.CBool;
|
||||||
|
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
@ -19,11 +20,11 @@ class WhisperCppTest {
|
|||||||
static void init() throws FileNotFoundException {
|
static void init() throws FileNotFoundException {
|
||||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||||
// or you can provide the absolute path to the model file.
|
// or you can provide the absolute path to the model file.
|
||||||
String modelName = "base.en";
|
String modelName = "../../models/ggml-tiny.en.bin";
|
||||||
try {
|
try {
|
||||||
whisper.initContext(modelName);
|
whisper.initContext(modelName);
|
||||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||||
// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||||
modelInitialised = true;
|
modelInitialised = true;
|
||||||
} catch (FileNotFoundException ex) {
|
} catch (FileNotFoundException ex) {
|
||||||
System.out.println("Model " + modelName + " not found");
|
System.out.println("Model " + modelName + " not found");
|
||||||
@ -31,11 +32,30 @@ class WhisperCppTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGetDefaultJavaParams() {
|
void testGetDefaultFullParams_BeamSearch() {
|
||||||
// When
|
// When
|
||||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||||
|
|
||||||
// Then if it doesn't throw we've connected to whisper.cpp
|
// Then
|
||||||
|
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
|
||||||
|
assertNotEquals(0, params.n_threads);
|
||||||
|
assertEquals(16384, params.n_max_text_ctx);
|
||||||
|
assertFalse(params.translate);
|
||||||
|
assertEquals(0.01f, params.thold_pt);
|
||||||
|
assertEquals(2, params.beam_search.beam_size);
|
||||||
|
assertEquals(-1.0f, params.beam_search.patience);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testGetDefaultFullParams_Greedy() {
|
||||||
|
// When
|
||||||
|
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
|
// Then
|
||||||
|
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
||||||
|
assertNotEquals(0, params.n_threads);
|
||||||
|
assertEquals(16384, params.n_max_text_ctx);
|
||||||
|
assertEquals(2, params.greedy.best_of);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -52,6 +72,13 @@ class WhisperCppTest {
|
|||||||
byte[] b = new byte[audioInputStream.available()];
|
byte[] b = new byte[audioInputStream.available()];
|
||||||
float[] floats = new float[b.length / 2];
|
float[] floats = new float[b.length / 2];
|
||||||
|
|
||||||
|
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||||
|
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||||
|
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||||
|
params.print_progress = CBool.FALSE;
|
||||||
|
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||||
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
audioInputStream.read(b);
|
audioInputStream.read(b);
|
||||||
|
|
||||||
@ -61,13 +88,13 @@ class WhisperCppTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// When
|
// When
|
||||||
String result = whisper.fullTranscribe(/*params,*/ floats);
|
String result = whisper.fullTranscribe(params, floats);
|
||||||
|
|
||||||
// Then
|
// Then
|
||||||
System.out.println(result);
|
System.err.println(result);
|
||||||
assertEquals("And so my fellow Americans, ask not what your country can do for you, " +
|
assertEquals("And so my fellow Americans ask not what your country can do for you " +
|
||||||
"ask what you can do for your country.",
|
"ask what you can do for your country.",
|
||||||
result);
|
result.replace(",", ""));
|
||||||
} finally {
|
} finally {
|
||||||
audioInputStream.close();
|
audioInputStream.close();
|
||||||
}
|
}
|
||||||
|
14
whisper.cpp
14
whisper.cpp
@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void whisper_free_params(struct whisper_full_params * params) {
|
||||||
|
if (params) {
|
||||||
|
delete params;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
||||||
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
||||||
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
||||||
@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
|
|||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
|
||||||
|
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||||
|
|
||||||
|
struct whisper_full_params* result = new whisper_full_params();
|
||||||
|
*result = params;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
||||||
struct whisper_full_params result = {
|
struct whisper_full_params result = {
|
||||||
/*.strategy =*/ strategy,
|
/*.strategy =*/ strategy,
|
||||||
|
@ -113,6 +113,7 @@ extern "C" {
|
|||||||
// Frees all allocated memory
|
// Frees all allocated memory
|
||||||
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
||||||
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
||||||
|
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
|
||||||
|
|
||||||
// Convert RAW PCM audio to log mel spectrogram.
|
// Convert RAW PCM audio to log mel spectrogram.
|
||||||
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||||
@ -409,6 +410,8 @@ extern "C" {
|
|||||||
void * logits_filter_callback_user_data;
|
void * logits_filter_callback_user_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
||||||
|
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
|
||||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||||
|
|
||||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
Loading…
Reference in New Issue
Block a user