~/Projects/WhisperSpeech
git clone https://code.lsong.org/WhisperSpeech
Commit:  26d6f024fbdda784422d8efd78e906f2cf38222b
Author:  Jakub Piotr Cłapa <[email protected]>
Date:    2023-07-14 16:00:32 +0000

Diffstat
 nbs/B1. Training.ipynb             | 2 ++
 nbs/B2. Training (Lightning).ipynb | 8 +++++---
 spear_tts_pytorch/train.py         | 2 ++
 spear_tts_pytorch/train_multi.py   | 8 +++++---

4 files changed, 14 insertions(+), 6 deletions(-)
Fixed lr_scale being overwritten by the learning rate scheduler
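The bug: OneCycleLR accepts either a single max_lr or one value per parameter group. When a scalar lr0 is passed, the scheduler broadcasts it to every group and recomputes each group's lr from that one peak value, silently discarding the per-group learning rates that lr_scale produced. Passing a list built from the groups themselves keeps the scaling. Below is a minimal sketch of the failure and the fix; the Linear model, the group split, and the step counts are invented for illustration and are not from this repository.

import torch

model = torch.nn.Linear(4, 4)
lr = 1e-3
param_groups = [
    # this group carries a pre-scaled lr (what lr_scale produces)
    {"params": [model.weight], "lr": lr * 0.1},
    {"params": [model.bias]},  # falls back to the optimizer default lr
]
optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95))

# Broken: a scalar max_lr is broadcast to all groups, so OneCycleLR
# rewrites both groups' lr from the same peak value and the 0.1x
# scaling is lost as soon as the scheduler is constructed:
#   torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, ...)

# Fixed: one max_lr per group preserves the relative scaling.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[pg.get("lr", lr) for pg in param_groups],
    steps_per_epoch=100, epochs=10,
    final_div_factor=25)

# Each group now ramps to its own peak; the 10x ratio stays intact.
print([pg["lr"] for pg in optimizer.param_groups])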
diff --git a/nbs/B1. Training.ipynb b/nbs/B1. Training.ipynb
index 8cf086e34b1c3e522faf882655830a108246f216..30f4e0cc0bde2175a3def1247d53c7c307582702 100644
--- a/nbs/B1. Training.ipynb
+++ b/nbs/B1. Training.ipynb
@@ -177,8 +177,10 @@
     "\n",
     "    optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=param_groups)\n",
     "    scaler = torch.cuda.amp.GradScaler(enabled=half)\n",
     "    scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
+    "        max_lr=[pg.get('lr', lr) for pg in param_groups],\n",
     "        final_div_factor=25)\n",
     "    \n",
     "    it = 0\n",
diff --git a/nbs/B2. Training (Lightning).ipynb b/nbs/B2. Training (Lightning).ipynb
index d145cb851de5f0d0653def92786af2aa71c29351..104c078d5db6ab6f642b68847e3e8ec9534ac941 100644
--- a/nbs/B2. Training (Lightning).ipynb
+++ b/nbs/B2. Training (Lightning).ipynb
@@ -68,6 +68,9 @@
     "        self.model_hparams = model_hparams\n",
     "    \n",
     "    def configure_optimizers(self):\n",
     "        \"\"\" Initialize AdamW optimizer\"\"\"\n",
+    "        lr = self.model_hparams['lr0']\n",
+    "        weight_decay = self.model_hparams['weight_decay']\n",
+    "        \n",
     "        all_params = set(model.parameters())\n",
     "        customized_params = set()\n",
     "        groups = []\n",
@@ -91,8 +94,7 @@
     "        param_groups = groups + [\n",
     "            {\"names\": [\"other\"], \"params\": list(other_params), \"weight_decay\": weight_decay },\n",
     "        ]\n",
     "\n",
-    "        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95),\n",
-    "                                      fused=True, params=param_groups)\n",
+    "        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)\n",
     "        \n",
     "        # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319\n",
     "        def num_steps_per_epoch() -> int:\n",
@@ -111,7 +113,7 @@
     "\n",
     "        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
     "            optimizer,\n",
     "            pct_start=self.model_hparams['pct_start'],\n",
-    "            max_lr=self.model_hparams['lr0'],\n",
+    "            max_lr=[pg.get('lr', lr) for pg in param_groups],\n",
     "            steps_per_epoch=num_steps_per_epoch(),\n",
     "            epochs=self.model_hparams['epochs'],\n",
     "            final_div_factor=25\n",
diff --git a/spear_tts_pytorch/train.py b/spear_tts_pytorch/train.py
index 6645f077cc1d7184efb0cc80df0d4ff8f2344f1d..90eda992a50debaf2084c4ac94181fd5fc5be407 100644
--- a/spear_tts_pytorch/train.py
+++ b/spear_tts_pytorch/train.py
@@ -137,7 +137,9 @@
     optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=param_groups)
     scaler = torch.cuda.amp.GradScaler(enabled=half)
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        max_lr=[pg.get('lr', lr) for pg in param_groups],
         final_div_factor=25)
 
     it = 0
diff --git a/spear_tts_pytorch/train_multi.py b/spear_tts_pytorch/train_multi.py
index b5369462f5c012d9b3c5d6477a0639f8fae247ae..20476b0a0765c3ebcde6255666f85c20e0e9644d 100644
--- a/spear_tts_pytorch/train_multi.py
+++ b/spear_tts_pytorch/train_multi.py
@@ -35,6 +35,9 @@
         self.model_hparams = model_hparams
 
     def configure_optimizers(self):
         """ Initialize AdamW optimizer"""
+        lr = self.model_hparams['lr0']
+        weight_decay = self.model_hparams['weight_decay']
+
         all_params = set(model.parameters())
         customized_params = set()
         groups = []
@@ -58,8 +61,7 @@
         param_groups = groups + [
             {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
         ]
 
-        optimizer = torch.optim.AdamW(lr=self.model_hparams['lr0'], betas=(0.9, 0.95),
-                                      fused=True, params=param_groups)
+        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)
 
         # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
@@ -79,7 +81,7 @@
         lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
             optimizer,
             pct_start=self.model_hparams['pct_start'],
-            max_lr=self.model_hparams['lr0'],
+            max_lr=[pg.get('lr', lr) for pg in param_groups],
             steps_per_epoch=num_steps_per_epoch(),
             epochs=self.model_hparams['epochs'],
             final_div_factor=25
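Note the pg.get('lr', lr) fallback in the fix: groups that set an explicit lr (the ones produced by lr_scale) keep their own peak learning rate, while the catch-all "other" group, which sets none, ramps to the plain lr0. The removal of fused=True from the AdamW call also rides along in this commit; fused AdamW is CUDA-only, which presumably makes the unfused default the more portable choice for the multi-GPU Lightning path.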