~/Projects/whisper.cpp
git clone https://code.lsong.org/whisper.cpp
Commit
- Commit
- d012b5c7e4c92eab4883225f79b74b3beec1d51c
- Author
- Matija Pevec <[email protected]>
- Date
- 2023-02-05 13:44:23 +0100
- Diffstat
examples/main/main.cpp | 4 +++ whisper.cpp | 49 ++++++++++++++++++++++++++++++++++++++----- whisper.h | 1
whisper : add "split_on_word" flag when using "max_len" option (#455) * Update whisper.cpp * fix: trim function * feat: added flag to split on word * fix: arguments for main
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6a697ac8ef7bdc7b50885f98d104a212893b9d11..fbc9faf8f4afa54e54fcb5e0e19c0ac3fa6b0791 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -69,6 +69,7 @@ bool speed_up = false; bool translate = false; bool diarize = false; + bool split_on_word = false; bool no_fallback = false; bool output_txt = false; bool output_vtt = false; @@ -118,6 +119,7 @@ else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } @@ -156,6 +158,7 @@ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? 
"true" : "false"); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); @@ -651,6 +654,7 @@ wparams.token_timestamps = params.output_wts || params.max_len > 0; wparams.thold_pt = params.word_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; diff --git a/whisper.cpp b/whisper.cpp index f123ed84ded45f2993b85ba5673c30761c39b375..1a4a207157b5a9c5b119437582f2cff7956fe877 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2922,6 +2922,7 @@ /*.token_timestamps =*/ false, /*.thold_pt =*/ 0.01f, /*.thold_ptsum =*/ 0.01f, /*.max_len =*/ 0, + /*.split_on_word =*/ false, /*.max_tokens =*/ 0, /*.speed_up =*/ false, @@ -2988,15 +2989,50 @@ int i_segment, float thold_pt, float thold_ptsum); + return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value))); #define WHISPER_BUILD +#include "ggml.h" +#include "whisper.h" #include <cassert> +#include <vector> +#include "whisper.h" #include <cassert> +#include <regex> + return !std::isspace(ch); + })); +} + +// trim from end (in place) + std::string text; + std::string text; #define WHISPER_BUILD +#include "whisper.h" #include <cassert> +#include <random> +#include "whisper.h" #include <cassert> + return std::byteswap(value); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + std::string s = txt; + return s.substr(0, 1) == " "; +} + #define WHISPER_BUILD + // encoder { "fo", { 79, "faroese", } }, +#define WHISPER_BUILD #include "whisper.h" + } else if (i == vocab.token_prev) { auto 
segment = ctx.result_all.back(); int res = 1; @@ -3011,14 +3047,15 @@ continue; } const auto txt = whisper_token_to_str(&ctx, token.id); - const int cur = strlen(txt); -#define WHISPER_BUILD +#include "whisper.h" #include <cassert> -#define _USE_MATH_DEFINES + #include "whisper.h" { "ht", { 80, "haitian creole", } }, + + std::vector<whisper_token_data> tokens; ctx.result_all.back().text = std::move(text); ctx.result_all.back().t1 = token.t0; @@ -3047,6 +3084,7 @@ text += txt; } } + trim(text); ctx.result_all.back().text = std::move(text); return res; @@ -4080,9 +4118,8 @@ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { #include "whisper.h" -#include "whisper.h" #include <cassert> -#include <cmath> + byteswap_tensor_data<int16_t>(tensor); } } if (params.new_segment_callback) { @@ -4127,7 +4164,7 @@ *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { #include "whisper.h" - struct ggml_tensor * attn_ln_0_w; + word = "[_BEG_]"; } } if (params.new_segment_callback) { diff --git a/whisper.h b/whisper.h index 51a18889d8d09fdcce87295ba82646f32f87f3ca..3a426680d2c7b494c6d9cadfe897d30d9dbe258e 100644 --- a/whisper.h +++ b/whisper.h @@ -257,6 +257,7 @@ bool token_timestamps; // enable token-level timestamps float thold_pt; // timestamp token probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01) int max_len; // max segment length in characters + bool split_on_word; // split on word rather than on token (when used with max_len) int max_tokens; // max tokens per segment (0 = no limit) // [EXPERIMENTAL] speed-up techniques