yl4579/StyleTTS2

Strange Loss Behavior During Stage Two Training - Not Decreasing after Diff Epoch

ethan-digi opened this issue · 3 comments

As stated in the title, I'm pre-training the model on a custom dataset, and I've noticed that after epoch 10 of second-stage training (multi-speaker, so the LibriTTS epoch config), the loss is not decreasing as I'd expect; it remains effectively flat. First-stage training went without any issues, as did the first 10 epochs of the second stage.

For a lot of the losses, this makes sense: the vast majority of the optimization happened before diffusion training began (the first 10 epochs), so further gains are unlikely to be very perceptible. I also know that the generator and discriminator losses tend not to decrease, by the nature of adversarial training.

However, neither the style loss nor the diffusion loss has decreased (see the final loss chart screenshot), despite neither being adversarial and despite both being introduced only at epoch 10. I'm not sure whether this is expected behavior, but I suspect not. When I run inference on model checkpoints, fidelity is a bit below what I'd expect, which furthers my suspicion that training is not proceeding correctly.
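For context on that timeline: as I read train_second.py, both terms are hard zeros before diff_epoch, roughly like this (paraphrased, names approximate, not verbatim):

import torch
import torch.nn.functional as F

def style_and_diffusion_losses(epoch, diff_epoch, s_trg, s_preds, diffusion, embedding):
    # Paraphrase of my reading of train_second.py: before diff_epoch both
    # terms contribute nothing, so their curves only start at epoch 10.
    if epoch < diff_epoch:
        return torch.tensor(0.0), torch.tensor(0.0)
    # score-matching loss on the target style vectors
    loss_diff = diffusion(s_trg.unsqueeze(1), embedding=embedding).mean()
    # L1 between styles produced by the sampler and the encoder's styles
    loss_sty = F.l1_loss(s_preds, s_trg.detach())
    return loss_sty, loss_diff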

I've attached losses and my config below:

[Loss chart screenshots: lm_gen, norm_f0, dur_ce, loss_disc, discLM_genLM, style_diff]
log_dir: "Models/LibriTTS"
first_stage_path: ""
save_freq: 1
log_interval: 10
device: "cuda"
epochs_1st: 50 # number of epochs for first stage training (pre-training)
epochs_2nd: 30 # number of epochs for second stage training (joint training)
batch_size: 14
max_len: 300 # maximum number of frames
pretrained_model: "Models/LibriTTS/epoch_2nd_00009.pth"
second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters

F0_path: "Utils/JDC/bst.t7"
ASR_config: "Utils/ASR/config.yml"
ASR_path: "Utils/ASR/epoch_00080.pth"
PLBERT_dir: 'Utils/PLBERT/'

data_params:
  train_data: "eleven_dataset/train_list.txt"
  val_data: "eleven_dataset/val_list.txt"
  root_path: "/local/wavs"
  OOD_data: "Data/OOD_texts.txt"
  min_length: 50 # sample until texts with this size are obtained for OOD texts

preprocess_params:
  sr: 24000
  spect_params:
    n_fft: 2048
    win_length: 1200
    hop_length: 300

model_params:
  multispeaker: true

  dim_in: 64 
  hidden_dim: 512
  max_conv_dim: 512
  n_layer: 3
  n_mels: 80

  n_token: 178 # number of phoneme tokens
  max_dur: 50 # maximum duration of a single phoneme
  style_dim: 128 # style vector size
  
  dropout: 0.2

  # config for decoder
  decoder: 
      type: 'hifigan' # either hifigan or istftnet
      resblock_kernel_sizes: [3,7,11]
      upsample_rates :  [10,5,3,2]
      upsample_initial_channel: 512
      resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
      upsample_kernel_sizes: [20,10,6,4]
      
  # speech language model config
  slm:
      model: 'microsoft/wavlm-base-plus'
      sr: 16000 # sampling rate of SLM
      hidden: 768 # hidden size of SLM
      nlayers: 13 # number of layers of SLM
      initial_channel: 64 # initial channels of SLM discriminator head
  
  # style diffusion model config
  diffusion:
    embedding_mask_proba: 0.1
    # transformer config
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2

    # diffusion distribution config
    dist:
      sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
      estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
      mean: -3.0
      std: 1.0
  
loss_params:
    lambda_mel: 5. # mel reconstruction loss
    lambda_gen: 1. # generator loss
    lambda_slm: 1. # slm feature matching loss
    
    lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
    lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
    TMA_epoch: 5 # TMA starting epoch (1st stage)

    lambda_F0: 1. # F0 reconstruction loss (2nd stage)
    lambda_norm: 1. # norm reconstruction loss (2nd stage)
    lambda_dur: 1. # duration loss (2nd stage)
    lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
    lambda_sty: 1. # style reconstruction loss (2nd stage)
    lambda_diff: 1. # score matching loss (2nd stage)
    
    diff_epoch: 10 # style diffusion starting epoch (2nd stage)
    joint_epoch: 15 # joint training starting epoch (2nd stage)

optimizer_params:
  lr: 0.0001 # general learning rate
  bert_lr: 0.00001 # learning rate for PLBERT
  ft_lr: 0.00001 # learning rate for acoustic modules
  
slmadv_params:
  min_len: 400 # minimum length of samples
  max_len: 500 # maximum length of samples
  batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
  iter: 20 # update the discriminator every this iterations of generator update
  thresh: 5 # gradient norm above which the gradient is scaled
  scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
  sig: 1.5 # sigma for differentiable duration modeling

Config is above. Highlights: batch size is 14 now; it was 21 for epochs 1-10. Yes, I have seven (7) GPUs. I don't think that's related to this, but 14 and 21 would be odd batch sizes otherwise. LR is at the default values.
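To spell out the second-stage schedule those loss_params imply (going by the comments in the config itself), here's a hypothetical helper, not code from the repo, just the timeline as I understand it:

# Hypothetical helper spelling out the second-stage schedule implied by my
# config (diff_epoch=10, joint_epoch=15, epochs_2nd=30); not repo code.
def phase(epoch, diff_epoch=10, joint_epoch=15):
    if epoch < diff_epoch:
        return "reconstruction + duration/prosody losses only"
    if epoch < joint_epoch:
        return "adds style diffusion + style reconstruction losses"
    return "adds joint training with SLM adversarial losses"

for e in (0, 10, 15, 29):
    print(e, phase(e))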

I've made very slight modifications to the data-loading code, but nothing that should impact performance. I also modified one line in the diffusion module in sampler.py, but it's just a NaN check on the config value:
sigma_data = self.sigma_data if not torch.isnan(torch.tensor(self.sigma_data)) else 0.2
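(Equivalently, assuming self.sigma_data is a plain Python float at that point, the tensor round-trip could be skipped:)

import math
sigma_data = 0.2 if math.isnan(self.sigma_data) else self.sigma_data  # same guard, no tensor round-trip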

Here's the config for the most recent epoch:

{ASR_config: Utils/ASR/config.yml, ASR_path: Utils/ASR/epoch_00080.pth, F0_path: Utils/JDC/bst.t7,
  PLBERT_dir: Utils/PLBERT/, batch_size: 14, data_params: {OOD_data: Data/OOD_texts.txt,
    min_length: 50, root_path: /local/wavs, train_data: eleven_dataset/train_list.txt,
    val_data: eleven_dataset/val_list.txt}, device: cuda, epochs_1st: 50, epochs_2nd: 30,
  first_stage_path: '', load_only_params: false, log_dir: Models/LibriTTS, log_interval: 10,
  loss_params: {TMA_epoch: 5, diff_epoch: 10, joint_epoch: 15, lambda_F0: 1.0, lambda_ce: 20.0,
    lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0, lambda_mono: 1.0,
    lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 300,
  model_params: {decoder: {resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3,
          5]], resblock_kernel_sizes: [3, 7, 11], type: hifigan, upsample_initial_channel: 512,
      upsample_kernel_sizes: [20, 10, 6, 4], upsample_rates: [10, 5, 3, 2]}, diffusion: {
      dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.308815117286219,
        std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2,
        num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512,
    max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178,
    slm: {hidden: 768, initial_channel: 64, model: microsoft/wavlm-base-plus, nlayers: 13,
      sr: 16000}, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05,
    lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 300, n_fft: 2048,
      win_length: 1200}, sr: 24000}, pretrained_model: Models/LibriTTS/epoch_2nd_00009.pth,
  save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
    iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
root@f9834fda57cf:/workspace/tts/Models/LibriTTS# cat config_libritts_second.yml | grep sig
      dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.308815117286219,
    iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
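For what it's worth, the sigma_data: 0.308815... in that dump is the batch-estimated value rather than the 0.2 placeholder. My understanding of what estimate_sigma_data: true does, paraphrased with stand-in data (names and shapes approximate, not repo code verbatim):

import numpy as np
import torch

running_std = []
for _ in range(100):                         # stand-in for the training batches
    s_trg = torch.randn(14, 128)             # [batch_size, style_dim] target style vectors
    sigma = s_trg.std(dim=-1).mean().item()  # per-sample std, averaged over the batch
    running_std.append(sigma)                # each step also overwrites diffusion.sigma_data
print(float(np.mean(running_std)))           # running mean: the value written into the saved config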

I would appreciate some insight, or loss curves/configs from anyone who has successfully trained from scratch.

Have you tried lowering the learning rate?


Yes, I have thought of lowering the learning rate 😆. I'm just looking for some loss curves known to produce desirable results, so I'm going at this from an angle other than 'down good', lol. But I do appreciate the input; I expect the solution will involve LR adjustments to some extent.