main : add "--prompt" command line argument (#90)
This allows to provide an initial prompt to be used at the start of the processing.
This commit is contained in:
parent
4312995974
commit
b8065d90f5
@ -73,8 +73,9 @@ struct whisper_params {
|
|||||||
bool print_colors = false;
|
bool print_colors = false;
|
||||||
bool no_timestamps = false;
|
bool no_timestamps = false;
|
||||||
|
|
||||||
std::string language = "en";
|
std::string language = "en";
|
||||||
std::string model = "models/ggml-base.en.bin";
|
std::string prompt = "";
|
||||||
|
std::string model = "models/ggml-base.en.bin";
|
||||||
|
|
||||||
std::vector<std::string> fname_inp = {};
|
std::vector<std::string> fname_inp = {};
|
||||||
};
|
};
|
||||||
@ -113,6 +114,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||||
|
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||||
else {
|
else {
|
||||||
@ -150,6 +152,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|||||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||||
|
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
@ -462,6 +465,22 @@ int main(int argc, char ** argv) {
|
|||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initial prompt
|
||||||
|
std::vector<whisper_token> prompt_tokens;
|
||||||
|
|
||||||
|
if (params.prompt.size() > 0) {
|
||||||
|
prompt_tokens.resize(1024);
|
||||||
|
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||||
|
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||||
|
fprintf(stderr, "initial tokens: [ ");
|
||||||
|
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||||
|
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "]\n");
|
||||||
|
}
|
||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
|
|
||||||
@ -577,7 +596,6 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// run the inference
|
// run the inference
|
||||||
{
|
{
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
@ -599,6 +617,9 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
wparams.speed_up = params.speed_up;
|
wparams.speed_up = params.speed_up;
|
||||||
|
|
||||||
|
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
|
||||||
|
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
|
||||||
|
|
||||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||||
|
|
||||||
// this callback is called on each new segment
|
// this callback is called on each new segment
|
||||||
|
Loading…
Reference in New Issue
Block a user