Liu Song’s Projects


~/Projects/WhisperSpeech

git clone https://code.lsong.org/WhisperSpeech

Commit

Commit
041f805add7a6a7e93dae9d839f8a7860e8ba1e6
Author
Jakub Piotr Cłapa <[email protected]>
Date
2023-03-29 07:38:24 +0000 +0000
Diffstat
 nbs/1. Acoustic token extraction.ipynb | 5 +++++
 spear_tts_pytorch/extract_acoustic.py | 6 +++++-

Try to lower acoustic extraction peak GPU memory usage


diff --git a/nbs/1. Acoustic token extraction.ipynb b/nbs/1. Acoustic token extraction.ipynb
index 5d2ffafef4d089866558725fd059b12ac3f964b2..c69428637bc41e4e08b42541a770047c250fd65f 100644
--- a/nbs/1. Acoustic token extraction.ipynb
+++ b/nbs/1. Acoustic token extraction.ipynb
@@ -31,6 +31,7 @@    "source": [
     "#| export\n",
     "import torch\n",
     "import torchaudio\n",
+    "import gc\n",
     "\n",
     "from pathlib import Path\n",
     "from fastcore.script import *\n",
@@ -127,8 +128,12 @@     "    model = load_model()\n",
     "    outdir.mkdir(exist_ok=True, parents=True)\n",
     "    for name in progress_bar(list(srcdir.rglob('*.flac'))):\n",
     "        outname = outdir/name.with_suffix('.encodec').name\n",
+  },
    "id": "abf96fcf",
+  },
    "metadata": {},
+    "        del tokens\n",
+    "        gc.collect()"
    ]
   },
   {




diff --git a/spear_tts_pytorch/extract_acoustic.py b/spear_tts_pytorch/extract_acoustic.py
index dfccee66202fbabfc6c31db3fbe2a4a24c1c4fee..69382a40bc6d0cd0480afd982cfaa147125636b0 100644
--- a/spear_tts_pytorch/extract_acoustic.py
+++ b/spear_tts_pytorch/extract_acoustic.py
@@ -6,6 +6,7 @@
 # %% ../nbs/1. Acoustic token extraction.ipynb 2
 import torch
 import torchaudio
+import gc
 
 from pathlib import Path
 from fastcore.script import *
@@ -50,4 +51,7 @@     outdir.mkdir(exist_ok=True, parents=True)
     for name in progress_bar(list(srcdir.rglob('*.flac'))):
         outname = outdir/name.with_suffix('.encodec').name
 __all__ = ['load', 'load_model', 'extract_Atoks', 'extract_acoustic']
-import torch
+from pathlib import Path
+        torch.save(tokens, outname)
+        del tokens
+        gc.collect()