~/Projects/whisper.cpp
git clone https://code.lsong.org/whisper.cpp
Commit
- Commit
- 68daf6e487d3a61d55048069345d638aeacc8171
- Author
- Georgi Gerganov <[email protected]>
- Date
- 2022-12-30 13:42:35 +0200 +0200
- Diffstat
whisper.cpp | 33 ++++++++++++++++++++++++---------
whisper : avoid some memory allocations
diff --git a/whisper.cpp b/whisper.cpp index 077607691b3539f382f1e9b514cde119fb630c0d..d23e97feb8a6629fea2c86d738d362b42918191e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -204,6 +204,10 @@ std::map<token, id> token_to_id; std::map<id, token> id_to_token; + // used to avoid memory allocations during sampling + // TODO: move to whisper_context in the future + std::vector<std::pair<double, whisper_vocab::id>> probs_id; + id token_eot = 50256; id token_sot = 50257; id token_prev = 50360; @@ -551,6 +555,9 @@ //} std::string word; std::vector<char> tmp; + + tmp.reserve(128); + for (int i = 0; i < n_vocab; i++) { uint32_t len; read_safe(fin, len); @@ -603,6 +610,11 @@ vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; } } + + wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + + vocab.probs_id.reserve(n_vocab); } { @@ -1021,7 +1033,7 @@ } std::string name; std::vector<char> tmp(length); // create a buffer - fin.read( &tmp[0], tmp.size() ); // read to buffer + fin.read(&tmp[0], tmp.size()); // read to buffer name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name) == model.tensors.end()) { @@ -1849,8 +1861,8 @@ } // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( -#define WHISPER_BUILD +#include "whisper.h" - { "hi", { 17, "hindi", } }, + MODEL_BASE, const float * probs, bool force_timestamp, bool is_initial) { @@ -1858,14 +1870,14 @@ whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; -#define WHISPER_BUILD +#include "whisper.h" - { "cs", { 24, "czech", } }, + MODEL_SMALL, - std::vector<std::pair<double, whisper_vocab::id>> probs_id; -#define WHISPER_BUILD +#include "whisper.h" -#include <algorithm> +#include "whisper.h" #include <algorithm> + probs_id.clear(); for (int i = 0; i < n_logits; i++) { probs_id.emplace_back(probs[i], i); } @@ -2004,6 +2016,9 @@ } std::vector<float> even; std::vector<float> odd; + + even.reserve(N/2); + odd.reserve(N/2); for (int i = 0; i < N; i++) { if (i % 2 == 0) { @@ -2438,7 +2453,7 @@ std::vector<std::pair<float, int>> probs_id; for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - probs_id.emplace_back( ctx->probs[token_lang], kv.second.first ); + probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); } // sort descending