Liu Song’s Projects


~/Projects/whisper.cpp

git clone https://code.lsong.org/whisper.cpp

Commit

Commit
68daf6e487d3a61d55048069345d638aeacc8171
Author
Georgi Gerganov <[email protected]>
Date
2022-12-30 13:42:35 +0200
Diffstat
 whisper.cpp | 33 ++++++++++++++++++++++++---------

whisper : avoid some memory allocations


diff --git a/whisper.cpp b/whisper.cpp
index 077607691b3539f382f1e9b514cde119fb630c0d..d23e97feb8a6629fea2c86d738d362b42918191e 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -204,6 +204,10 @@
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
+    // used to avoid memory allocations during sampling
+    // TODO: move to whisper_context in the future
+    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+
     id token_eot  = 50256;
     id token_sot  = 50257;
     id token_prev = 50360;
@@ -551,6 +555,9 @@         //}
 
         std::string word;
         std::vector<char> tmp;
+
+        tmp.reserve(128);
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             read_safe(fin, len);
@@ -603,6 +610,11 @@                 vocab.token_to_id[word] = i;
                 vocab.id_to_token[i] = word;
             }
         }
+
+        wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+        wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
+
+        vocab.probs_id.reserve(n_vocab);
     }
 
     {
@@ -1021,7 +1033,7 @@             }
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            fin.read(&tmp[0], tmp.size()); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@@ -1849,8 +1861,8 @@ }
 
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
-        const whisper_vocab & vocab,
+              whisper_vocab & vocab,
         const float * probs,
               bool force_timestamp,
               bool is_initial) {
@@ -1858,14 +1870,14 @@     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };
 
-    int n_logits = vocab.n_vocab;
+    const int n_logits = vocab.n_vocab;
 
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
+    auto & probs_id = vocab.probs_id;
 
 
+    probs_id.clear();
     for (int i = 0; i < n_logits; i++) {
         probs_id.emplace_back(probs[i], i);
     }
@@ -2004,6 +2016,9 @@     }
 
     std::vector<float> even;
     std::vector<float> odd;
+
+    even.reserve(N/2);
+    odd.reserve(N/2);
 
     for (int i = 0; i < N; i++) {
         if (i % 2 == 0) {
@@ -2438,7 +2453,7 @@
     std::vector<std::pair<float, int>> probs_id;
     for (const auto & kv : g_lang) {
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
-        probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
+        probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
     }
 
     // sort descending