From 1558ec5a16cb2b2a0bf54815df1d41f83dc3815b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 Mar 2024 14:48:19 +0200 Subject: [PATCH] whisper : improve handling of prompts (#1981) * whisper : improve handling of prompts * whisper : add whisper_token_count helper --- examples/main/main.cpp | 2 +- whisper.cpp | 13 +++++++++++-- whisper.h | 8 +++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index caa800b..415c3b3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); diff --git a/whisper.cpp b/whisper.cpp index 0a820c4..02800a3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to if (n_max_tokens < (int) res.size()) { WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); - return -1; + return -(int) res.size(); } for (int i = 0; i < (int) res.size(); i++) { @@ -3731,6 +3731,10 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to return res.size(); } +int whisper_token_count(struct whisper_context * ctx, const char * text) { + return -whisper_tokenize(ctx, text, NULL, 0); +} + int whisper_lang_max_id() { auto max_id = 0; for (const auto & kv : g_lang) { @@ -5313,7 +5317,12 @@ int whisper_full_with_state( // initial prompt if (!params.prompt_tokens && params.initial_prompt) { prompt_tokens.resize(1024); - prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); + int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); + if (n_needed < 0) { + prompt_tokens.resize(-n_needed); + n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); + } + prompt_tokens.resize(n_needed); params.prompt_tokens = prompt_tokens.data(); params.prompt_n_tokens = prompt_tokens.size(); } diff --git a/whisper.h b/whisper.h index 2754337..bd8d8df 100644 --- a/whisper.h +++ b/whisper.h @@ -337,7 +337,7 @@ extern "C" { // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens - // Returns -1 on failure + // Returns a negative number on failure - the number of tokens that would have been returned // TODO: not sure if correct WHISPER_API int whisper_tokenize( struct whisper_context * ctx, @@ -345,6 +345,10 @@ extern "C" { whisper_token * tokens, int n_max_tokens); + // Return the number of tokens in the provided text + // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0) + int whisper_token_count(struct whisper_context * ctx, const char * text); + // Largest language id (i.e. number of available languages - 1) WHISPER_API int whisper_lang_max_id(); @@ -503,6 +507,8 @@ extern "C" { // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call + // use whisper_tokenize() to convert text to tokens + // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224) const char * initial_prompt; const whisper_token * prompt_tokens; int prompt_n_tokens;