facebookresearch/swav

Slow Training with SwAV on 2 GPUs using torch.distributed

Opened this issue · 1 comments

Hello,
I've set up a Conda environment, and here are the versions of the key libraries (they differ from the versions specified in the README):

  • Python 3.8.12
  • PyTorch 1.13.1 (py3.8_cuda11.6_cudnn8.3.2_0)
  • CUDA Toolkit 11.6.1
  • Torchvision 0.14.1
  • Apex (installed as suggested in the README)
    NB: all of these packages were installed through Conda (not pip).

I ran main_swav.py using torch.distributed with two GPUs (both GeForce RTX 3090, 24 GB each). The command used was:
python -m torch.distributed.launch --nproc_per_node=2 main_swav.py --local_rank=0
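For context, `torch.distributed.launch` spawns one process per GPU and passes each process its own `--local_rank` value. The sketch below (hypothetical code, not taken from main_swav.py) shows the argument handling a launched script is expected to implement:

```python
# Hypothetical sketch of --local_rank handling for a script started with
# torch.distributed.launch (this is NOT SwAV's actual code).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0,
                    help="injected by torch.distributed.launch")

# Simulate the launcher injecting rank 1 for the second worker process:
args = parser.parse_args(["--local_rank", "1"])

# A real training script would then pin the process to its GPU, e.g.:
# torch.cuda.set_device(args.local_rank)
print(args.local_rank)
```

With `--nproc_per_node=2`, each of the two workers is given a different local rank (0 and 1), so it may be worth checking whether fixing `--local_rank=0` on the command line conflicts with the per-process value the launcher injects.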

I'm working with four datasets, stored across two SSDs to speed up DataLoader throughput, for a total of 1M images. During each epoch, approximately 400K images (balanced with a Sampler) are processed with a batch size of 2048, resulting in around 220 iterations.
The issue is that each iteration takes approximately 400 seconds to complete, which makes training extremely slow: a single epoch takes about a day.
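One way to narrow down a slowdown like this is to measure how much of each iteration is spent waiting on the DataLoader versus running the training step. Below is a minimal, self-contained timing harness (hypothetical helper, not part of SwAV) illustrated with dummy stand-ins for the loader and the step:

```python
# Hypothetical timing harness to separate data-loading time from
# compute time per iteration (not part of the SwAV codebase).
import time

def profile_loop(loader, step_fn, max_iters=5):
    """Return (data_time, step_time) pairs for the first few iterations."""
    times = []
    it = iter(loader)
    for _ in range(max_iters):
        t0 = time.perf_counter()
        try:
            batch = next(it)      # time spent waiting on the DataLoader
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)            # stand-in for forward/backward/optimizer
        t2 = time.perf_counter()
        times.append((t1 - t0, t2 - t1))
    return times

# Usage with dummy stand-ins for a DataLoader and a training step:
dummy_loader = [list(range(1000))] * 3
stats = profile_loop(dummy_loader, step_fn=sum)
```

If the data-loading component dominates, the bottleneck is I/O or the DataLoader configuration (e.g. `num_workers`, `pin_memory`) rather than the GPUs themselves.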

In terms of computational time per epoch, the current training sessions are significantly slower (around 400 seconds per iteration) than previous ones (around 10 seconds per iteration) achieved in a different environment.
NB: that prior environment, set up and used approximately a year ago, followed the version requirements outlined in the README (Python 3.6, PyTorch 1.4, CUDA 10.1).

The issue now is that replicating a similar environment is no longer possible due to the following error when attempting to execute the command:
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
The error message received is as follows:

PackagesNotFoundError: The following packages are not available from current channels:
  - torchvision==0.5.0

Any insights into why the training process is significantly slower with a more recent version of Python and PyTorch? I would also appreciate any recommendations on how to resolve this issue.

Hello, I'd like to ask how you set up a training environment with versions other than those listed in the README. I always seem to encounter issues when installing Apex. Thank you.