speechbrain/speechbrain

Possible NCCL-level deadlock during checkpointing

kokamido opened this issue · 7 comments

Describe the bug

General description

Hi! I am not sure whether this is a bug. It's unclear to me how checkpointing should be used in DDP mode on a single machine with multiple GPUs. If I have misunderstood something, I would be grateful for an explanation, but as far as I can tell, with speechbrain 0.5.16 I have only the following options:

  1. Write the same checkpoint multiple times, because every DDP worker writes it. I don't think this is ideal: the repeated writes are redundant, and there may be a race condition somewhere in the checkpointing code.
  2. Get a deadlock.

The following text is based on the behavior of the repro setup provided in the "To Reproduce" section. I discuss several runs of this setup; their full logs are in the "Relevant Log Output" section. I implemented end-of-epoch checkpointing using this speechbrain recipe as a reference.

Multiple writings of the same checkpoint

To confirm that every DDP worker writes checkpoints, I modified the source code of speechbrain==0.5.16 by adding print(f'{os.environ.get("LOCAL_RANK")}\t{name}') here. If I run my repro as

rm -rf experiments/ && torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --number_of_epochs=1 --ckpt_behavior=="All threads" 2>&1 | tee log_1.txt

I can see that both workers write the same checkpoint parts (the full log is provided as "log_1" in the "Relevant Log Output" section):

100%|██████████| 160/160 [00:01<00:00, 147.62it/s, train_loss=0.68] 
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer

Test setups

For clarity, I ran the following setups without the added print.

1. Write intra-epoch checkpoints only

torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_interval_minutes=0.001 2>&1 | tee log_2.txt

FileNotFoundError: [Errno 2] No such file or directory: 'experiments/ddp_crash_repro/save/CKPT+2024-02-09+08-48-16+01'

The full log_2 is provided in the "Relevant Log Output" section. This situation is described in the issue.
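The traceback suggests a filesystem race: `list_checkpoints()` collects the `CKPT+*` directories and then opens each one with `iterdir()`, while the other process's `save_and_keep_only()` cleanup may delete one of them in between. Both processes saved within the same second, which is why the two names differ only in the `+00`/`+01` suffix (and only `+00` is left afterwards). The sketch below replays that interleaving deterministically in a single process with `pathlib`; it is an illustrative model, not SpeechBrain's code, and the helper names `make_ckpt` and `reproduce_race` are hypothetical.

```python
import tempfile
from pathlib import Path


def make_ckpt(save_dir, suffix):
    """Create a fake checkpoint directory containing one file."""
    d = Path(save_dir) / f"CKPT+2024-02-09+08-48-16+{suffix}"
    d.mkdir()
    (d / "counter").write_text("1")
    return d


def reproduce_race(save_dir):
    # Both processes save in the same second, so the names differ only by suffix.
    make_ckpt(save_dir, "00")
    make_ckpt(save_dir, "01")
    # Process B lists the checkpoint directories...
    seen = sorted(p for p in Path(save_dir).iterdir() if p.name.startswith("CKPT+"))
    # ...then process A's keep-only cleanup deletes one of them.
    victim = Path(save_dir) / "CKPT+2024-02-09+08-48-16+01"
    for f in victim.iterdir():
        f.unlink()
    victim.rmdir()
    # Process B now opens every directory it saw and hits the deleted one.
    try:
        for d in seen:
            list(d.iterdir())
    except FileNotFoundError as err:
        return str(err.filename)
    return None


with tempfile.TemporaryDirectory() as tmp:
    missing = reproduce_race(tmp)
    print(missing)  # ends with CKPT+2024-02-09+08-48-16+01
```

In a real run the interleaving depends on timing, which is consistent with the crash appearing only after a few epochs.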

2. Write end-of-epoch checkpoints in the main process only.

torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="Main thread only" 2>&1 | tee log_3.txt

Training hangs at this point:

100%|██████████| 160/160 [00:01<00:00, 149.47it/s, train_loss=0.68] 
100%|██████████| 40/40 [00:00<00:00, 1391.75it/s]
  0%|          | 0/160 [00:00<?, ?it/s, train_loss=0.625]

Training stays stuck until NCCL terminates it after the 7200-second operation timeout.

If I run this setup with TORCH_DISTRIBUTED_DEBUG=DETAIL,

TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="Main thread only" 2>&1 | tee log_4.txt

it crashes right after checkpointing with "Collectives differ in the following aspects: Op type: BROADCASTvs ALLREDUCE" (see log_4 for details). I think this broadcast is the cause. It was added in this commit, so only speechbrain==0.5.16 is affected.
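The DETAIL error says that at the same sequence number rank 0 issues a BROADCAST of a `Long` tensor of shape `[1]` while rank 1 issues an ALLREDUCE of a `Float` tensor of shape `[25]`. Below is a toy model of that comparison, assuming each rank records its collectives in order: rank 0 enters `save_checkpoint()`, whose `broadcast_object_list` first broadcasts the object size (as the traceback shows), while rank 1 skips the `if_main_process()` branch and goes on to the next backward pass. The 25-element all-reduce plausibly holds the gradients of the single `Linear(24, 1)` layer (24 weights + 1 bias). `Collective`, `trace`, and `first_mismatch` are hypothetical names; this is not how PyTorch implements the check, and the indices will not match the log's 485 because other collectives are ignored. Without DETAIL mode nothing compares the ops, so the ranks simply wait on each other until the timeout.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Collective:
    op: str       # collective type, e.g. "BROADCAST" or "ALLREDUCE"
    shape: tuple  # shape of the communicated tensor
    dtype: str    # element type, e.g. "Long" or "Float"


GRAD_ALLREDUCE = Collective("ALLREDUCE", (25,), "Float")  # DDP gradient sync
SIZE_BROADCAST = Collective("BROADCAST", (1,), "Long")    # object-size broadcast


def trace(rank, steps_per_epoch, main_only_save):
    """Collectives issued by one rank: one epoch, the save, then the next step."""
    ops = [GRAD_ALLREDUCE] * steps_per_epoch
    if rank == 0 or not main_only_save:
        ops.append(SIZE_BROADCAST)  # save_checkpoint() broadcasts from rank 0
    ops.append(GRAD_ALLREDUCE)      # backward pass of the next batch
    return ops


def first_mismatch(traces):
    """Return (index, {rank: op}) of the first disagreement, or None.

    Only compares up to the length of the shortest trace.
    """
    for seq, ops in enumerate(zip(*traces.values())):
        if len(set(ops)) > 1:
            return seq, dict(zip(traces.keys(), ops))
    return None


main_only = first_mismatch({r: trace(r, 160, main_only_save=True) for r in (0, 1)})
all_ranks = first_mismatch({r: trace(r, 160, main_only_save=False) for r in (0, 1)})
print(main_only)  # mismatch at index 160: BROADCAST on rank 0, ALLREDUCE on rank 1
print(all_ranks)  # None
```

When every rank calls `save_checkpoint()`, the broadcast is issued symmetrically and the sequences stay aligned.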

3. Write end-of-epoch checkpoints in all processes.

I ran this setup with the additional print described above:

torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="All threads" --number_of_epochs=3 2>&1 | tee log_5.txt

100%|██████████| 160/160 [00:01<00:00, 151.04it/s, train_loss=0.68] 
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 40/40 [00:00<00:00, 1410.13it/s]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer

Concurrent writes do not cause a training error for end-of-epoch checkpointing, but they still look redundant and rather dangerous, because two processes may write to the same file.
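Writing from a single process is the real fix, but if several processes may write the same file, a common general-purpose mitigation (a swapped-in technique, not something SpeechBrain is known to do here) is to write to a temporary file in the same directory and then `os.replace()` it into place. On POSIX filesystems the rename is atomic, so a reader sees either the old file or a complete new one, never a half-written mix. `atomic_write_bytes` is a hypothetical helper name used for this sketch.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(path, data):
    """Write data so that readers see either the old file or the complete new one."""
    path = Path(path)
    # Temp file in the same directory, so os.replace() stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name + ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes hit the disk before the rename
        os.replace(tmp_name, path)  # atomic rename on POSIX
    except BaseException:
        os.unlink(tmp_name)  # do not leave a stray temp file behind
        raise


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "counter"
    atomic_write_bytes(target, b"1")  # e.g. written by one process
    atomic_write_bytes(target, b"2")  # e.g. overwritten by another process
    final = target.read_bytes()
    leftovers = sorted(p.name for p in Path(tmp).iterdir())
    print(final, leftovers)
```

The last writer still wins, and the duplicated work remains; the pattern only removes torn files.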

Expected behaviour

I think a checkpoint should be written exactly once, even though multiple DDP workers are running at the same time on my single machine. As far as I know, that is how it worked in speechbrain 0.5.12, but now there is no main-process condition here.

To Reproduce

ckpt_repro.py

import sys

import speechbrain as sb
import torch
import torch.nn as nn
from hyperpyyaml import load_hyperpyyaml
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from speechbrain.utils.distributed import if_main_process, run_on_main
from torch.distributed.elastic.multiprocessing.errors import record


class TestClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(torch.nn.Linear(24, 1))
        

    def forward(self, x):
        x = x.squeeze(1)
        for i, layer in enumerate(self.layers):
            x = layer(x)
        return x



class TestBrain(sb.Brain):
    def __init__(self, modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None, deadlock=False):
        super().__init__(modules, opt_class, hparams, run_opts, checkpointer)
        self.loss = hparams['loss']
        self.ckpt_behavior = hparams['ckpt_behavior']

    @record
    def fit(self,
        epoch_counter,
        train_set,
        valid_set=None,
        progressbar=None,
        train_loader_kwargs={},
        valid_loader_kwargs={},
    ):
        super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        if self.ckpt_behavior == 'Main thread only':
            if if_main_process():
                self.checkpointer.save_checkpoint({'test': 'test'})
        elif self.ckpt_behavior == 'All threads':
            self.checkpointer.save_checkpoint({'test': 'test'})


    def compute_objectives(self, predictions, batch, stage):
        _, labels = batch
        return self.loss(predictions, labels.to(self.device))

    def compute_forward(self, batch, stage):
        data, _ = batch
        return self.modules['model'](data.to(self.device)).squeeze()


def get_loaders():
    seed = int(hparams['seed'])
    X, y = make_classification(hparams['dataset_samples_count'], hparams['dataset_features_count'],
                               shuffle=False, random_state=seed)

    X_train, X_test, y_train, y_test = train_test_split(X[:, None, :], y, test_size=0.2, shuffle=True,
                                                        random_state=seed)

    train_loader = DataLoader(TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train)),
                              batch_size=hparams['batch_size'], shuffle=False)
    test_loader = DataLoader(TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test)),
                             batch_size=hparams['batch_size'], shuffle=False)
    return train_loader, test_loader


if __name__ == "__main__":
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    train_loader, test_loader = get_loaders()

    modules = {'model': TestClassifier()}
    brain = TestBrain(modules, hparams['opt_class'], hparams, run_opts, hparams['checkpointer'])

    brain.fit(hparams['epoch_counter'], train_loader, test_loader)

ckpt_repro.yaml

name: ddp_crash_repro
output_folder: !ref experiments/<name>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/<name>_log.txt

batch_size: 64
seed: 3456
number_of_epochs: 10
ckpt_interval_minutes: 9999
ckpt_behavior: None

__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: !ref <train_log>

dataset_samples_count: 12800
dataset_features_count: 24
dataset_features_informative: 15

opt_class: !name:torch.optim.Adam


loss: !new:torch.nn.modules.loss.BCEWithLogitsLoss

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: !ref <number_of_epochs>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: !ref <save_folder>
  recoverables:
    counter: !ref <epoch_counter>

Environment Details

GPU: 2xV100

OS: Ubuntu 22.04.3 LTS

Python: 3.10.12

CUDA: Cuda compilation tools, release 12.1, V12.1.105, Build cuda_12.1.r12.1/compiler.32688072_0

torch.cuda.nccl.version(): (2, 18, 1)

Dependencies:

  • torch==2.1.2
  • speechbrain==0.5.16

Relevant Log Output

log_1

root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# rm -rf experiments/ && torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --number_of_epochs=1 --ckpt_behavior=="All threads" 2>&1 | tee log_1.txt
[2024-02-09 08:45:57,971] torch.distributed.run: [WARNING] 
[2024-02-09 08:45:57,971] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 08:45:57,971] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 08:45:57,971] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 08:46:01 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
100%|██████████| 160/160 [00:01<00:00, 147.62it/s, train_loss=0.68] 
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 40/40 [00:00<00:00, 1391.91it/s]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer

log_2

root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_interval_minutes=0.001 2>&1 | tee log_2.txt
[2024-02-09 08:48:08,092] torch.distributed.run: [WARNING] 
[2024-02-09 08:48:08,092] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 08:48:08,092] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 08:48:08,092] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 08:48:11 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
 41%|████▏     | 66/160 [00:00<00:00, 113.92it/s, train_loss=0.72]Traceback (most recent call last):
  File "/root/speechbraindebugexample/ckpt_repro.py", line 93, in <module>
    brain.fit(hparams['epoch_counter'], train_loader, test_loader)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
    super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
    self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1212, in _fit_train
    self._save_intra_epoch_ckpt()
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1386, in _save_intra_epoch_ckpt
    self.checkpointer.save_and_keep_only(
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 685, in save_and_keep_only
    self.delete_checkpoints(
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 988, in delete_checkpoints
    self.find_checkpoints(
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 825, in find_checkpoints
    ckpts = self.list_checkpoints()
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 914, in list_checkpoints
    return self._construct_checkpoint_objects(self._list_checkpoint_dirs())
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 1064, in _construct_checkpoint_objects
    for ckptfile in ckpt_dir.iterdir():
  File "/usr/lib/python3.10/pathlib.py", line 1017, in iterdir
    for name in self._accessor.listdir(self):
FileNotFoundError: [Errno 2] No such file or directory: 'experiments/ddp_crash_repro/save/CKPT+2024-02-09+08-48-16+01'
[2024-02-09 08:48:23,119] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 170033 closing signal SIGTERM
[2024-02-09 08:48:23,384] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 170034) of binary: /usr/bin/python3.10
[2024-02-09 08:48:23,392] torch.distributed.elastic.multiprocessing.errors.error_handler: [ERROR] no error file defined for parent, to copy child error file (/tmp/torchelastic_ft1cfo45/none_xcnrnd0c/attempt_0/1/error.json)
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
ckpt_repro.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-02-09_08:48:16
  host      : sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 170034)
  error_file: /tmp/torchelastic_ft1cfo45/none_xcnrnd0c/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
      return f(*args, **kwargs)
    File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
      super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
      self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1212, in _fit_train
      self._save_intra_epoch_ckpt()
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1386, in _save_intra_epoch_ckpt
      self.checkpointer.save_and_keep_only(
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 685, in save_and_keep_only
      self.delete_checkpoints(
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 988, in delete_checkpoints
      self.find_checkpoints(
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 825, in find_checkpoints
      ckpts = self.list_checkpoints()
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 914, in list_checkpoints
      return self._construct_checkpoint_objects(self._list_checkpoint_dirs())
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 1064, in _construct_checkpoint_objects
      for ckptfile in ckpt_dir.iterdir():
    File "/usr/lib/python3.10/pathlib.py", line 1017, in iterdir
      for name in self._accessor.listdir(self):
  FileNotFoundError: [Errno 2] No such file or directory: 'experiments/ddp_crash_repro/save/CKPT+2024-02-09+08-48-16+01'
  
============================================================
root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# ls experiments/ddp_crash_repro/save/
CKPT+2024-02-09+08-48-16+00

log_3

root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="Main thread only" 2>&1 | tee log_3.txt
[2024-02-09 08:51:23,862] torch.distributed.run: [WARNING] 
[2024-02-09 08:51:23,862] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 08:51:23,862] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 08:51:23,862] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 08:51:27 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
100%|██████████| 160/160 [00:01<00:00, 149.47it/s, train_loss=0.68] 
100%|██████████| 40/40 [00:00<00:00, 1391.75it/s]
  0%|          | 0/160 [00:00<?, ?it/s, train_loss=0.625]

log_4

root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="Main thread only" 2>&1 | tee log_4.txt
[2024-02-09 09:09:00,140] torch.distributed.run: [WARNING] 
[2024-02-09 09:09:00,140] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 09:09:00,140] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 09:09:00,140] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 09:09:03 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
100%|██████████| 160/160 [00:01<00:00, 88.35it/s, train_loss=0.68]  
Traceback (most recent call last):
  File "/root/speechbraindebugexample/ckpt_repro.py", line 93, in <module>
    brain.fit(hparams['epoch_counter'], train_loader, test_loader)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
    super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
    self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1218, in _fit_train
    self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
  File "/root/speechbraindebugexample/ckpt_repro.py", line 51, in on_stage_end
    self.checkpointer.save_checkpoint({'test': 'test'})
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 586, in save_checkpoint
    torch.distributed.broadcast_object_list(communication_list, src=0)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 2603, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1906, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=ALLREDUCE, TensorShape=[25], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Op type: BROADCASTvs ALLREDUCE  Tensor Tensor shapes: 1vs 25  Tensor Tensor dtypes: Longvs Float
Traceback (most recent call last):
  File "/root/speechbraindebugexample/ckpt_repro.py", line 93, in <module>
    brain.fit(hparams['epoch_counter'], train_loader, test_loader)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
    super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
    self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1193, in _fit_train
    loss = self.fit_batch(batch)
  File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1066, in fit_batch
    (loss / self.grad_accumulation_factor).backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=ALLREDUCE, TensorShape=[25], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Op type: ALLREDUCEvs BROADCAST  Tensor Tensor shapes: 25vs 1  Tensor Tensor dtypes: Floatvs Long
[2024-02-09 09:09:15,164] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 170581) of binary: /usr/bin/python3.10
[2024-02-09 09:09:15,172] torch.distributed.elastic.multiprocessing.errors.error_handler: [ERROR] no error file defined for parent, to copy child error file (/tmp/torchelastic_ttv1bvv7/none_8kh569f4/attempt_0/0/error.json)
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
ckpt_repro.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-02-09_09:09:10
  host      : sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 170582)
  error_file: /tmp/torchelastic_ttv1bvv7/none_8kh569f4/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
      return f(*args, **kwargs)
    File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
      super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
      self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1193, in _fit_train
      loss = self.fit_batch(batch)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1066, in fit_batch
      (loss / self.grad_accumulation_factor).backward()
    File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 492, in backward
      torch.autograd.backward(
    File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
      Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=ALLREDUCE, TensorShape=[25], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Op type: ALLREDUCEvs BROADCAST  Tensor Tensor shapes: 25vs 1  Tensor Tensor dtypes: Floatvs Long
  
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-02-09_09:09:10
  host      : sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 170581)
  error_file: /tmp/torchelastic_ttv1bvv7/none_8kh569f4/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
      return f(*args, **kwargs)
    File "/root/speechbraindebugexample/ckpt_repro.py", line 46, in fit
      super(TestBrain, self).fit(epoch_counter, train_set, valid_set, progressbar, train_loader_kwargs, valid_loader_kwargs)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1366, in fit
      self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/core.py", line 1218, in _fit_train
      self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
    File "/root/speechbraindebugexample/ckpt_repro.py", line 51, in on_stage_end
      self.checkpointer.save_checkpoint({'test': 'test'})
    File "/usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py", line 586, in save_checkpoint
      torch.distributed.broadcast_object_list(communication_list, src=0)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 2603, in broadcast_object_list
      broadcast(object_sizes_tensor, src=src, group=group)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1906, in broadcast
      work = default_pg.broadcast([tensor], opts)
  RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=485, OpType=ALLREDUCE, TensorShape=[25], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Op type: BROADCASTvs ALLREDUCE  Tensor Tensor shapes: 1vs 25  Tensor Tensor dtypes: Longvs Float

log_5

root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="All threads" --number_of_epochs=3 2>&1 | tee log_5.txt
[2024-02-09 09:22:30,687] torch.distributed.run: [WARNING] 
[2024-02-09 09:22:30,687] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 09:22:30,687] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 09:22:30,687] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 09:22:34 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# rm -rf experiments/
root@sbx-60283d040ccf4433b126ad86e96ba6ac-5ff484847d-kcvm5:~/speechbraindebugexample# torchrun --nnodes=1 --nproc-per-node=2 ckpt_repro.py ckpt_repro.yaml --ckpt_behavior="All threads" --number_of_epochs=3 2>&1 | tee log_5.txt
[2024-02-09 09:22:49,895] torch.distributed.run: [WARNING] 
[2024-02-09 09:22:49,895] torch.distributed.run: [WARNING] *****************************************
[2024-02-09 09:22:49,895] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-09 09:22:49,895] torch.distributed.run: [WARNING] *****************************************
[NeMo W 2024-02-09 09:22:53 optimizers:65] Could not import distributed_fused_adam optimizer from Apex
100%|██████████| 160/160 [00:01<00:00, 151.04it/s, train_loss=0.68] 
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 40/40 [00:00<00:00, 1410.13it/s]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 160/160 [00:00<00:00, 391.79it/s, train_loss=0.55] 
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 40/40 [00:00<00:00, 1409.52it/s]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 160/160 [00:00<00:00, 361.50it/s, train_loss=0.467]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer
100%|██████████| 40/40 [00:00<00:00, 1422.81it/s]
0       counter
0       brain
1       counter
1       brain
1       optimizer
0       optimizer



Hello @kokamido, thanks for opening this issue! Could you please let us know whether your speechbrain version is from the main branch or the develop branch? How did you install SpeechBrain? Through pip install speechbrain or git clone? Thanks.

Pinging @pplantinga again, as this is a very important issue.

I installed speechbrain==0.5.16 via pip.
To add the print described in the "Multiple writings of the same checkpoint" section, I modified the /usr/local/lib/python3.10/dist-packages/speechbrain/utils/checkpoints.py file of the pip-installed speechbrain package.

Could you please try with the SpeechBrain version available in the develop branch and get back to me with the results? We fixed several issues with DDP in this new version.

You can install it with the following command:

pip install git+https://github.com/speechbrain/speechbrain.git@develop

I tested the develop version of the speechbrain package, installed with pip install git+https://github.com/speechbrain/speechbrain.git@develop

1. Write intra-epoch checkpoints only

Seems fixed. With speechbrain==0.5.16 from pip it crashes within a few epochs, but with the develop version it ran fine for 100 epochs, so I think this issue is fixed in the develop branch.

2. Write end-of-epoch checkpoints in the main process only.

No change. Both setups (with and without TORCH_DISTRIBUTED_DEBUG=DETAIL) behave as described in the issue.

3. Write end-of-epoch checkpoints in all processes.

No change. According to the output of print(f'{os.environ.get("LOCAL_RANK")}\t{ckpt_dir}/{name}') injected into this line, both DDP workers write a checkpoint:

100%|██████████| 160/160 [00:01<00:00, 153.53it/s, train_loss=0.68] 
0       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/counter
0       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/brain
1       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/counter
1       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/brain
1       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/optimizer
0       experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00/optimizer

Hi, thanks for your very detailed investigation of this issue, this makes it much easier to debug and fix on our side. To address these three issues, let me respond below:

  1. Yes this was an issue and we have fixed it.
  2. This approach should be unnecessary; it should "just work", because the default saving function is marked with @main_process_only (see this line). However, based on this feedback I have opened PR #2404 to make this approach work too, though you'd have to use a @main_process_only function rather than if_main_process.
  3. I don't think this is the right place for the print statement. Try putting it inside the default saving function instead (same line as above). The issue should no longer occur; if it does, please let us know.
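A pure-Python model of points 2 and 3, assuming the structure described above: `save_checkpoint` runs on every rank (so any collective inside it, such as sharing rank 0's directory name, lines up across ranks), while the per-recoverable saving hook is guarded so that only the main process writes. `main_process_only`, `default_save`, and `save_checkpoint` here are hypothetical stand-ins written for this sketch, not SpeechBrain's implementation; the two ranks are simulated one after the other in a single process, and the broadcast is modelled by passing the same directory name to both.

```python
import functools


def main_process_only(fn):
    """Hypothetical decorator: run the wrapped hook on rank 0 only, return None elsewhere."""
    @functools.wraps(fn)
    def wrapper(rank, *args, **kwargs):
        return fn(rank, *args, **kwargs) if rank == 0 else None
    return wrapper


written = []  # (rank, path) for every file actually written
visited = []  # (rank, path) for every place the injected print would fire


@main_process_only
def default_save(rank, path):
    # Stand-in for serializing one recoverable to disk.
    written.append((rank, path))


def save_checkpoint(rank, ckpt_dir, recoverables):
    # Every rank calls this, so collectives inside it are issued symmetrically.
    for name in recoverables:
        path = f"{ckpt_dir}/{name}"
        visited.append((rank, path))  # print placed in the per-recoverable loop
        default_save(rank, path)      # only rank 0 actually writes


# Directory name as chosen by rank 0 and shared with the other rank.
shared_dir = "experiments/ddp_crash_repro/save/CKPT+2024-02-10+13-30-56+00"
for rank in (0, 1):  # simulate the two DDP workers sequentially
    save_checkpoint(rank, shared_dir, ["counter", "brain", "optimizer"])

print(len(visited), len(written))  # 6 3
```

This mirrors the six lines printed in the develop-branch log: the loop runs on both ranks, but under this model only rank 0's three writes reach the disk.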

Thanks for the clarification. Now I understand how the checkpoints should be saved, and I have no more questions.

Solved in #2404