iseekwonderful/HPA-singlecell-2nd-dual-head-pipeline

Multi-GPU training fails

I'm trying to train with 2 or 4 P100s. I had to tweak some of the argument/configuration parsing to get the code to accept multiple GPUs. Once it starts, though, it fails straight away with the following exception. Do you think there is an easy fix for this?

4
[ √ ] Landmark!
[0, 1, 2, 3]
[ √ ] Using #0,1,2,3 GPU
['0', '1', '2', '3']
[ ! ] Full fold coverage training! for fold: 0
[ √ ] Using transformation: s_0220/sin_256_final & None, image size: 256
[ i ] The length of train_dl is 7028, valid dl is 4391
Using cache found in /home/users/allstaff/thomas.e/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master
tf_efficientnet_b3
[ i ] Model: tf_efficientnet_b3, loss_func: bce, optimizer: Adam
parallel
4
[ ! ] pos weight: 0.1
[ √ ] Basic training
  0%|          | 0/7028 [00:12<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 172, in <module>
    basic_train(cfg, model, train_dl, valid_dl, loss_func, optimizer, result_path, scheduler, writer)
  File "/vast/scratch/users/thomas.e/HPA-singlecell-2nd-dual-head-pipeline/basic_train.py", line 93, in basic_train
    cell, exp = model(ipt, cfg.experiment.count)
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/stornext/HPCScratch/home/thomas.e/.conda/envs/rxrx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/vast/scratch/users/thomas.e/HPA-singlecell-2nd-dual-head-pipeline/models/efficient.py", line 64, in forward
    viewed_pooled = pooled.view(-1, cnt, pooled.shape[-1])
RuntimeError: shape '[-1, 16, 1536]' is invalid for input of size 36864

Which multi-GPU method are you using, DataParallel or DistributedDataParallel? I suspect the sequence count (cnt) is causing the issue; you could do some debugging or add a breakpoint at that line.
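The numbers in the traceback are consistent with that suspicion. The failing tensor has 36864 elements and a last dimension of 1536, so replica 0 received 24 pooled rows, and 24 is not a multiple of cnt = 16. Below is a rough sketch of how that could happen, assuming the batch reaches the model flattened to (samples * cnt) rows and that DataParallel splits dim 0 evenly across replicas in the way torch.chunk does. The 96-row batch and the `chunk_sizes` helper are hypothetical illustrations, not values or code from this repository.

```python
import math

def chunk_sizes(total_rows, num_replicas):
    """Row counts per replica when dim 0 is split torch.chunk-style:
    every chunk has ceil(total / replicas) rows, the last may be smaller."""
    size = math.ceil(total_rows / num_replicas)
    sizes = []
    remaining = total_rows
    while remaining > 0:
        take = min(size, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

# Values taken from the error message in the traceback.
FEATURES = 1536        # pooled.shape[-1]
CNT = 16               # cells per sample, passed to forward() as a plain int
FAILED_NUMEL = 36864   # "input of size 36864"

rows_on_replica0 = FAILED_NUMEL // FEATURES   # 36864 / 1536 = 24
leftover = rows_on_replica0 % CNT             # 24 % 16 = 8, so view(-1, 16, 1536) fails

# Hypothetical batch: 6 samples * 16 cells = 96 flattened rows over 4 GPUs.
per_replica = chunk_sizes(6 * CNT, 4)
broken = [rows for rows in per_replica if rows % CNT != 0]

print("rows on replica 0:", rows_on_replica0, "leftover:", leftover)
print("rows per replica:", per_replica, "not divisible by cnt:", broken)
```

If this is what is happening, the split only works when the number of samples per batch is a multiple of the number of GPUs, because cnt itself is not a tensor and is handed to every replica unchanged.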

I haven't substantially changed your code. It seems to be wrapping the model in torch.nn.DataParallel near line . I'm running in a batch environment, so interactive debugging is not possible. I'll put some print statements in models/efficient.py and see if I can pick up anything useful.
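For a batch job, one option is a small check that writes the offending shape to the log before the view call fails. This is only a sketch: `check_groupable` is a hypothetical helper, not something in the repository, and it assumes the tensor's last dimension is the feature size and all leading dimensions are rows to be grouped by cnt.

```python
def check_groupable(shape, cnt):
    """Return how many groups of `cnt` rows a tensor of `shape` contains.

    Raises ValueError with the full shape if the rows do not divide evenly,
    which is the situation that makes view(-1, cnt, features) fail.
    """
    rows = 1
    for dim in shape[:-1]:   # every dimension except the trailing feature dim
        rows *= dim
    if rows % cnt != 0:
        raise ValueError(
            f"cannot group {rows} rows into sets of {cnt}: "
            f"shape={tuple(shape)}, leftover={rows % cnt}"
        )
    return rows // cnt

# Works on plain tuples, so it can be tried without a GPU.
print(check_groupable((32, 1536), 16))
try:
    check_groupable((24, 1536), 16)   # the shape implied by the traceback
except ValueError as err:
    print(err)
```

Inside forward it could be called as `check_groupable(tuple(pooled.shape), cnt)` just before `pooled.view(...)`, so each replica reports the shape it actually received.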

There is a train_net.py that is meant for DDP training, but it needs some coding to make it work.

Hello, may I ask how to set the parameters for main.py? I could not find the train.sh file mentioned in the README.
Looking forward to your reply.

I spent some effort trying to get multi-GPU training running but couldn't get it to work. I decided it was too hard when I found that AMP breaks basic Python APIs by pulling methods out of objects and replacing them with functions of different types, which another part of the API depends on - really bizarre.