yl4579/StyleTTS2

Fine Tuning fails on multi GPU

Closed this issue · 3 comments

Tune Config

{'log_dir': 'Models/LJSpeech',
 'save_freq': 5,
 'log_interval': 10,
 'device': 'cuda',
 'epochs': 50,
 'batch_size': 8,
 'max_len': 400,
 'pretrained_model': 'Models/LibriTTS/epochs_2nd_00020.pth',
 'second_stage_load_pretrained': True,
 'load_only_params': True,
 'F0_path': 'Utils/JDC/bst.t7',
 'ASR_config': 'Utils/ASR/config.yml',
 'ASR_path': 'Utils/ASR/epoch_00080.pth',
 'PLBERT_dir': 'Utils/PLBERT/',
 'data_params': {'train_data': 'Data/train_list.txt',
  'val_data': 'Data/val_list.txt',
  'root_path': 'Data/wavs',
  'OOD_data': 'Data/OOD_texts.txt',
  'min_length': 50},
 'preprocess_params': {'sr': 24000,
  'spect_params': {'n_fft': 2048, 'win_length': 1200, 'hop_length': 300}},
 'model_params': {'multispeaker': True,
  'dim_in': 64,
  'hidden_dim': 512,
  'max_conv_dim': 512,
  'n_layer': 3,
  'n_mels': 80,
  'n_token': 178,
  'max_dur': 50,
  'style_dim': 128,
  'dropout': 0.2,
  'decoder': {'type': 'hifigan',
   'resblock_kernel_sizes': [3, 7, 11],
   'upsample_rates': [10, 5, 3, 2],
   'upsample_initial_channel': 512,
   'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
   'upsample_kernel_sizes': [20, 10, 6, 4]},
  'slm': {'model': 'microsoft/wavlm-base-plus',
   'sr': 16000,
   'hidden': 768,
   'nlayers': 13,
   'initial_channel': 64},
  'diffusion': {'embedding_mask_proba': 0.1,
   'transformer': {'num_layers': 3,
    'num_heads': 8,
    'head_features': 64,
    'multiplier': 2},
   'dist': {'sigma_data': 0.2,
    'estimate_sigma_data': True,
    'mean': -3.0,
    'std': 1.0}}},
 'loss_params': {'lambda_mel': 5.0,
  'lambda_gen': 1.0,
  'lambda_slm': 1.0,
  'lambda_mono': 1.0,
  'lambda_s2s': 1.0,
  'lambda_F0': 1.0,
  'lambda_norm': 1.0,
  'lambda_dur': 1.0,
  'lambda_ce': 20.0,
  'lambda_sty': 1.0,
  'lambda_diff': 1.0,
  'diff_epoch': 10,
  'joint_epoch': 30},
 'optimizer_params': {'lr': 0.0001, 'bert_lr': 1e-05, 'ft_lr': 0.0001},
 'slmadv_params': {'min_len': 400,
  'max_len': 500,
  'batch_percentage': 0.5,
  'iter': 10,
  'thresh': 5,
  'scale': 0.01,
  'sig': 1.5}}
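The audio settings in this config have to agree with each other. In a HiFi-GAN style decoder, each mel frame is upsampled by the product of `upsample_rates`, and that product should equal `hop_length`. Below is a quick sanity check that uses the values copied from the config above. It assumes that the top-level `max_len` counts mel frames. That is an assumption about `train_finetune.py`, not something this thread states.

```python
from math import prod

# Values copied from the fine-tuning config above.
sr = 24000                      # preprocess_params['sr']
hop_length = 300                # preprocess_params['spect_params']['hop_length']
upsample_rates = [10, 5, 3, 2]  # model_params['decoder']['upsample_rates']
max_len = 400                   # top-level 'max_len' (ASSUMED to count mel frames)

# A HiFi-GAN style generator emits prod(upsample_rates) waveform samples
# per input frame, so this product should match the STFT hop length.
samples_per_frame = prod(upsample_rates)
assert samples_per_frame == hop_length, (samples_per_frame, hop_length)

frames_per_second = sr / hop_length           # 80.0 frames per second
max_clip_seconds = max_len * hop_length / sr  # 5.0 s, if max_len is in frames
print(f"{samples_per_frame} samples/frame, {frames_per_second} frames/s, "
      f"longest training segment {max_clip_seconds} s")
```

With these numbers the check passes: 10 × 5 × 3 × 2 = 300 matches the hop length. At 24 kHz that gives 80 frames per second.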

nvidia-smi

Fri Dec  8 17:40:12 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1B.0 Off |                    0 |
| N/A   41C    P0    51W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:1C.0 Off |                    0 |
| N/A   39C    P0    51W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:1D.0 Off |                    0 |
| N/A   38C    P0    52W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   42C    P0    53W / 300W |      0MiB / 16160MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
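With `batch_size: 8` and the four V100s listed above, `nn.DataParallel` (the wrapper visible in the traceback below) splits each batch along dimension 0 before running the replicas. The following pure-Python sketch shows that split. It assumes the default scatter follows `torch.chunk` semantics, where the chunk size is ceil(batch / devices). `dataparallel_chunk_sizes` is a hypothetical helper and is not part of torch or StyleTTS2.

```python
def dataparallel_chunk_sizes(batch_size: int, num_devices: int) -> list:
    """Per-replica batch sizes, mimicking torch.chunk along dim 0.

    torch.chunk uses a chunk size of ceil(batch_size / num_devices); the
    last chunk takes the remainder, so fewer than num_devices chunks
    (and therefore fewer replicas) can come back.
    """
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    step = -(-batch_size // num_devices)  # ceiling division
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(step, remaining))
        remaining -= step
    return sizes


# batch_size 8 on 4 GPUs: each replica sees 2 samples.
print(dataparallel_chunk_sizes(8, 4))  # [2, 2, 2, 2]
```

Under this model, a batch of 5 on 4 GPUs becomes `[2, 2, 1]`, so only three replicas would run.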


installs

!pip install SoundFile torchaudio gdown tensorboard munch torch==1.11.0 pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git

Error:

!python train_finetune.py --config_path ./Configs/config_ft.yml
bert loaded
bert_encoder loaded
predictor loaded
decoder loaded
text_encoder loaded
predictor_encoder loaded
style_encoder loaded
diffusion loaded
text_aligner loaded
pitch_extractor loaded
mpd loaded
msd loaded
wd loaded
BERT AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9, 0.99)
    eps: 1e-09
    initial_lr: 1e-05
    lr: 1e-05
    max_lr: 2e-05
    max_momentum: 0.95
    maximize: False
    min_lr: 0
    weight_decay: 0.01
)
decoder AdamW (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.0, 0.99)
    eps: 1e-09
    initial_lr: 0.0001
    lr: 0.0001
    max_lr: 0.0002
    max_momentum: 0.95
    maximize: False
    min_lr: 0
    weight_decay: 0.0001
)
Traceback (most recent call last):
  File "/root/StyleTTS2/Colab/StyleTTS2/train_finetune.py", line 707, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/root/StyleTTS2/Colab/StyleTTS2/train_finetune.py", line 396, in main
    y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/opt/conda/lib/python3.10/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
IndexError: Caught IndexError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/StyleTTS2/Colab/StyleTTS2/Modules/hifigan.py", line 474, in forward
    x = self.generator(x, s, F0_curve)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/StyleTTS2/Colab/StyleTTS2/Modules/hifigan.py", line 329, in forward
    x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 477, in __getitem__
    idx = self._get_abs_string_index(idx)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 460, in _get_abs_string_index
    raise IndexError('index {} is out of range'.format(idx))
IndexError: index 0 is out of range

Note: I had to pin torch==1.11.0 to make it work with CUDA 11.4. Could that be related?

yl4579 commented

Your torch version may be too old. The code was tested with torch >= 2.0.

Makes sense. Upgrading to torch 2.0.1 (the cu117 build) works with the CUDA 11.4 driver:

>>> import torch
>>> torch.__version__
'2.0.1+cu117'
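The thread does not pin down the exact mechanism. One likely explanation, not confirmed here, is that older torch releases did not replicate `nn.ParameterList` entries onto `nn.DataParallel` replicas. If `self.alphas` in `Modules/hifigan.py` is a `ParameterList`, as the `container.py` frame in the traceback suggests, it would be empty on device 1. That matches `IndexError: index 0 is out of range`.

Either way, a cheap safeguard is to fail fast on a torch older than the version the maintainer tested with. The sketch below uses only the standard library. `parse_version`, `torch_is_supported` and `require_supported_torch` are hypothetical helpers, not part of StyleTTS2.

```python
import re

# The maintainer reports testing StyleTTS2 with torch >= 2.0.
MIN_TORCH = (2, 0)


def parse_version(version: str) -> tuple:
    """Parse '2.0.1+cu117' -> (2, 0, 1) and '1.11.0' -> (1, 11, 0).

    The local build tag after '+' (e.g. 'cu117') is dropped, and parsing
    stops at the first component without a leading number (e.g. 'dev...').
    """
    release = version.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def torch_is_supported(version: str, minimum: tuple = MIN_TORCH) -> bool:
    # Pad with zeros so that a short version like (2,) still counts as >= (2, 0).
    parsed = parse_version(version) + (0,) * len(minimum)
    return parsed >= minimum


def require_supported_torch(version: str) -> None:
    """Raise early instead of failing deep inside a DataParallel replica."""
    if not torch_is_supported(version):
        wanted = ".".join(map(str, MIN_TORCH))
        raise RuntimeError(
            f"torch {version} found; StyleTTS2 was tested with torch >= {wanted}"
        )
```

For example, you could call `require_supported_torch(torch.__version__)` near the top of `train_finetune.py`; this placement is hypothetical. It would have stopped the torch 1.11.0 run before any model was built, while `'2.0.1+cu117'` passes the check.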