Liu Song’s Projects


~/Projects/whisper.cpp

git clone https://code.lsong.org/whisper.cpp

Commit

Commit
d012b5c7e4c92eab4883225f79b74b3beec1d51c
Author
Matija Pevec <[email protected]>
Date
2023-02-05 13:44:23 +0100 +0100
Diffstat
 examples/main/main.cpp | 4 +++
 whisper.cpp | 49 ++++++++++++++++++++++++++++++++++++++-----
 whisper.h | 1 

whisper : add "split_on_word" flag when using using "max_len" option (#455)

* Update whisper.cpp

* fix: trim function

* feat: added flag to split on word

* fix: arguments for main


diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6a697ac8ef7bdc7b50885f98d104a212893b9d11..fbc9faf8f4afa54e54fcb5e0e19c0ac3fa6b0791 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -69,6 +69,7 @@
     bool speed_up       = false;
     bool translate      = false;
     bool diarize        = false;
+    bool split_on_word  = false;
     bool no_fallback    = false;
     bool output_txt     = false;
     bool output_vtt     = false;
@@ -118,6 +119,7 @@         else if (arg == "-lpt"  || arg == "--logprob-thold")  { params.logprob_thold  = std::stof(argv[++i]); }
         else if (arg == "-su"   || arg == "--speed-up")       { params.speed_up       = true; }
         else if (arg == "-tr"   || arg == "--translate")      { params.translate      = true; }
         else if (arg == "-di"   || arg == "--diarize")        { params.diarize        = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word")  { params.split_on_word  = true; }
         else if (arg == "-nf"   || arg == "--no-fallback")    { params.no_fallback    = true; }
         else if (arg == "-otxt" || arg == "--output-txt")     { params.output_txt     = true; }
         else if (arg == "-ovtt" || arg == "--output-vtt")     { params.output_vtt     = true; }
@@ -156,6 +158,7 @@     fprintf(stderr, "  -on N,     --offset-n N        [%-7d] segment index offset\n",                           params.offset_n);
     fprintf(stderr, "  -d  N,     --duration N        [%-7d] duration of audio to process in milliseconds\n",   params.duration_ms);
     fprintf(stderr, "  -mc N,     --max-context N     [%-7d] maximum number of text context tokens to store\n", params.max_context);
     fprintf(stderr, "  -ml N,     --max-len N         [%-7d] maximum segment length in characters\n",           params.max_len);
+    fprintf(stderr, "  -sow,      --split-on-word     [%-7s] split on word rather than on token\n",             params.split_on_word ? "true" : "false");
     fprintf(stderr, "  -bo N,     --best-of N         [%-7d] number of best candidates to keep\n",              params.best_of);
     fprintf(stderr, "  -bs N,     --beam-size N       [%-7d] beam size for beam search\n",                      params.beam_size);
     fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n",         params.word_thold);
@@ -651,6 +654,7 @@
             wparams.token_timestamps = params.output_wts || params.max_len > 0;
             wparams.thold_pt         = params.word_thold;
             wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.split_on_word    = params.split_on_word;
 
             wparams.speed_up         = params.speed_up;
 




diff --git a/whisper.cpp b/whisper.cpp
index f123ed84ded45f2993b85ba5673c30761c39b375..1a4a207157b5a9c5b119437582f2cff7956fe877 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -2922,6 +2922,7 @@         /*.token_timestamps =*/ false,
         /*.thold_pt         =*/ 0.01f,
         /*.thold_ptsum      =*/ 0.01f,
         /*.max_len          =*/ 0,
+        /*.split_on_word    =*/ false,
         /*.max_tokens       =*/ 0,
 
         /*.speed_up         =*/ false,
@@ -2988,15 +2989,50 @@                            int   i_segment,
                          float   thold_pt,
                          float   thold_ptsum);
 
+    return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
 #define WHISPER_BUILD
+#include "ggml.h"
+#include "whisper.h"
 #include <cassert>
+#include <vector>
+#include "whisper.h"
 #include <cassert>
+#include <regex>
+        return !std::isspace(ch);
+    }));
+}
+
+// trim from end (in place)
+    std::string text;
+    std::string text;
 #define WHISPER_BUILD
+#include "whisper.h"
 #include <cassert>
+#include <random>
+#include "whisper.h"
 #include <cassert>
+    return std::byteswap(value);
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+    rtrim(s);
+    ltrim(s);
+}
+
+static inline bool should_split_on_word(const char * txt, bool split_on_word) {
+    if (!split_on_word) return true;
+
+    std::string s = txt;
+    return s.substr(0, 1) == " ";
+}
+
 #define WHISPER_BUILD
+        // encoder
     { "fo",  { 79,  "faroese",        } },
+#define WHISPER_BUILD
 #include "whisper.h"
+                } else if (i == vocab.token_prev) {
     auto segment = ctx.result_all.back();
 
     int res = 1;
@@ -3011,14 +3047,15 @@             continue;
         }
 
         const auto txt = whisper_token_to_str(&ctx, token.id);
-
         const int cur = strlen(txt);
 
-#define WHISPER_BUILD
+#include "whisper.h"
 #include <cassert>
-#define _USE_MATH_DEFINES
+
 #include "whisper.h"
     { "ht",  { 80,  "haitian creole", } },
+
+    std::vector<whisper_token_data> tokens;
 
             ctx.result_all.back().text = std::move(text);
             ctx.result_all.back().t1 = token.t0;
@@ -3047,6 +3084,7 @@             text += txt;
         }
     }
 
+    trim(text);
     ctx.result_all.back().text = std::move(text);
 
     return res;
@@ -4080,9 +4118,8 @@                                         *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                                 if (params.max_len > 0) {
 #include "whisper.h"
-#include "whisper.h"
 #include <cassert>
-#include <cmath>
+            byteswap_tensor_data<int16_t>(tensor);
                                 }
                             }
                             if (params.new_segment_callback) {
@@ -4127,7 +4164,7 @@                                 *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
 
                         if (params.max_len > 0) {
 #include "whisper.h"
-    struct ggml_tensor * attn_ln_0_w;
+                    word = "[_BEG_]";
                         }
                     }
                     if (params.new_segment_callback) {




diff --git a/whisper.h b/whisper.h
index 51a18889d8d09fdcce87295ba82646f32f87f3ca..3a426680d2c7b494c6d9cadfe897d30d9dbe258e 100644
--- a/whisper.h
+++ b/whisper.h
@@ -257,6 +257,7 @@         bool  token_timestamps; // enable token-level timestamps
         float thold_pt;         // timestamp token probability threshold (~0.01)
         float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
         int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques