Liu Song’s Projects


~/Projects/WhisperSpeech

git clone https://code.lsong.org/WhisperSpeech

Commit
1a528a167e1480083b77a4f777dd98734830efc8
Author
Jakub Piotr Cłapa <[email protected]>
Date
2023-06-20 17:01:40 +0000
Diffstat
 nbs/B1. Training.ipynb | 6 ++++++
 nbs/B2. Training (Lightning).ipynb | 6 +++++-
 spear_tts_pytorch/train.py | 6 ++++++
 spear_tts_pytorch/train_multi.py | 6 +++++-

Added support for gradient clipping


diff --git a/nbs/B1. Training.ipynb b/nbs/B1. Training.ipynb
index edb15ed2661de3de320e8e33d1cb1827d7fbcc08..5fbeeb51c4039c5c1c64e75aa35fb87d4848830e 100644
--- a/nbs/B1. Training.ipynb
+++ b/nbs/B1. Training.ipynb
@@ -192,6 +192,12 @@     "                        ps, loss = model(*args)\n",
     "\n",
     "                with record_function(\"backward\"):\n",
     "                    scaler.scale(loss).backward()\n",
+    "\n",
+    "                    if clip_gradient_norm:\n",
+    "                        scaler.unscale_(optimizer)\n",
+    "                        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:\n",
+    "                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)\n",
+    "\n",
     "                    scaler.step(optimizer)\n",
     "                    scaler.update()\n",
     "\n",


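For reference, the unscale-then-clip pattern this commit adds to the manual training loops, shown in isolation: with torch.cuda.amp the loss (and hence the gradients) are scaled, so the optimizer's gradients must be unscaled before clipping, otherwise the norm threshold would apply to the scaled values. The sketch below is a minimal stand-alone example rather than code from this repository; the model, data, and the clip value of 1.0 are placeholders.

import torch

model = torch.nn.Linear(10, 1).cuda()                      # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
clip_gradient_norm = 1.0                                    # placeholder threshold; None/0 leaves clipping off

for step in range(100):                                     # placeholder data and loop
    x = torch.randn(32, 10, device='cuda')
    y = torch.randn(32, 1, device='cuda')
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast('cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    if clip_gradient_norm:
        scaler.unscale_(optimizer)                          # gradients are now at their true scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
    scaler.step(optimizer)                                  # skips the step if inf/nan gradients were found
    scaler.update()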


diff --git a/nbs/B2. Training (Lightning).ipynb b/nbs/B2. Training (Lightning).ipynb
index d887bb8765e898a29e4d669c1189924020000069..bce4b14393cf3465ad13da8506e25fc833fd7b3d 100644
--- a/nbs/B2. Training (Lightning).ipynb
+++ b/nbs/B2. Training (Lightning).ipynb
@@ -77,7 +77,7 @@     "                if m.bias is not None:\n",
     "                    wd_params.add(m.bias)\n",
     "        no_wd_params = all_params - wd_params\n",
     "\n",
-    "        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95), fused=True,\n",
+    "        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95),\n",
     "            params=[\n",
     "                {\"params\": list(wd_params), \"weight_decay\": self.model_hparams['weight_decay']},\n",
     "                {\"params\": list(no_wd_params), \"weight_decay\": 0.0},\n",
@@ -196,6 +196,8 @@     "parser.add_argument('--epochs', type=int, default=10, help='total training epochs')\n",
     "parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')\n",
     "parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')\n",
+    "parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')\n",
     "parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')\n",
     "\n",
     "args = parser.parse_args().__dict__\n",
@@ -211,6 +213,7 @@     "\n",
     "hyp_params = {}\n",
     "hyp_params['warmup_steps'] = args['warmup_steps']\n",
     "hyp_params['weight_decay'] = args['weight_decay']\n",
+    "hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']\n",
     "hyp_params['lr0'] = args['lr0']\n",
     "hyp_params['epochs'] = epochs"
    ]
@@ -270,6 +273,7 @@     "trainer = pl.Trainer(max_epochs=hyp_params['epochs'],\n",
     "                  accelerator=\"gpu\",\n",
     "                  profiler=\"simple\",\n",
     "                  precision='16-mixed',\n",
+    "                  gradient_clip_val=hyp_params['clip_gradient_norm'],\n",
     "                  val_check_interval=1/10,\n",
     "                  enable_checkpointing=True,\n",
     "                  logger=wandb_logger,\n",



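In the Lightning-based notebook and train_multi.py, clipping is delegated to the trainer via gradient_clip_val instead of being done by hand; Lightning leaves clipping off when the value is None, which matches the new flag's default. A minimal sketch of just that wiring, with a placeholder LightningModule and clip value (not code from this repository):

import pytorch_lightning as pl

clip_gradient_norm = 1.0        # placeholder; None leaves clipping disabled

trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu',
    precision='16-mixed',
    gradient_clip_val=clip_gradient_norm,   # Lightning clips the global gradient norm to this value
)
# trainer.fit(model)            # `model` would be a LightningModule; omitted here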

diff --git a/spear_tts_pytorch/train.py b/spear_tts_pytorch/train.py
index 0bf9cab0d93c93a5346ccf3fb8e1d06fc8fc5308..d3170e95cd94401ceb538af0d143788a6251e72b 100644
--- a/spear_tts_pytorch/train.py
+++ b/spear_tts_pytorch/train.py
@@ -151,6 +151,12 @@                         ps, loss = model(*args)
 
                 with record_function("backward"):
                     scaler.scale(loss).backward()
+
+                    if clip_gradient_norm:
+                        scaler.unscale_(optimizer)
+                        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
+
                     scaler.step(optimizer)
                     scaler.update()
 




diff --git a/spear_tts_pytorch/train_multi.py b/spear_tts_pytorch/train_multi.py
index a3ca211e9aa389ff3f3b5d232491e633e5fce70c..d3691249d992f1c7704f33af87f426bde5804200 100644
--- a/spear_tts_pytorch/train_multi.py
+++ b/spear_tts_pytorch/train_multi.py
@@ -44,7 +44,7 @@                 if m.bias is not None:
                     wd_params.add(m.bias)
         no_wd_params = all_params - wd_params
 
-        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95), fused=True,
+        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95),
             params=[
                 {"params": list(wd_params), "weight_decay": self.model_hparams['weight_decay']},
                 {"params": list(no_wd_params), "weight_decay": 0.0},
@@ -125,6 +125,7 @@ parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
 parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
 parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
 parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
+parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
 parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
 
 args = parser.parse_args().__dict__
@@ -141,6 +142,8 @@ hyp_params = {}
 hyp_params['warmup_steps'] = args['warmup_steps']
 hyp_params['weight_decay'] = args['weight_decay']
+hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
 hyp_params['lr0'] = args['lr0']
 hyp_params['epochs'] = epochs
 
@@ -192,6 +195,7 @@ trainer = pl.Trainer(max_epochs=hyp_params['epochs'],
                   accelerator="gpu",
                   profiler="simple",
                   precision='16-mixed',
+                  gradient_clip_val=hyp_params['clip_gradient_norm'],
                   val_check_interval=1/10,
                   enable_checkpointing=True,
                   logger=wandb_logger,
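
As a usage note, argparse maps the new --clip-gradient-norm flag onto the args['clip_gradient_norm'] key read above (dashes become underscores), and the default of None keeps clipping disabled unless the flag is given on the command line. A small stand-alone sketch, with an example value of 1.0:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--clip-gradient-norm', type=float, default=None,
                    help='enable gradient norm clipping')

args = parser.parse_args(['--clip-gradient-norm', '1.0']).__dict__
print(args['clip_gradient_norm'])   # 1.0; without the flag this would be None

So a run of spear_tts_pytorch/train_multi.py would enable clipping by adding, e.g., --clip-gradient-norm 1.0 to its usual arguments.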