diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6a3db73..bb19318 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -44,6 +44,8 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; bool speed_up = false; bool debug_mode = false; @@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } @@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); @@ -1107,7 +1113,9 @@ int main(int argc, char ** argv) { wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; + wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold;