leftthomas/SimCLR

DataParallel to avoid memory issue

zhoubin-me opened this issue · 4 comments

Change the model creation at line 114 of main.py to the following:

    # model setup and optimizer config
    model = Model(feature_dim)
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),))
    flops, params = clever_format([flops, params])
    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    model = torch.nn.DataParallel(model).cuda()

Otherwise, you will face a CUDA out-of-memory error even if you have enough GPU cards to support a higher batch size.
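
For context, `DataParallel` helps here because it splits each batch along dimension 0 across the visible GPUs, so the per-card activation memory shrinks roughly in proportion to the number of cards. A minimal sketch, with the layer and batch size being illustrative rather than taken from this repo:

    import torch
    import torch.nn as nn

    # Minimal sketch: DataParallel scatters the batch along dim 0 across GPUs,
    # runs a model replica on each card, and gathers outputs on the default device.
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(nn.Linear(128, 10)).cuda()
        x = torch.randn(1000, 128).cuda()  # the full batch starts on cuda:0
        y = net(x)                         # each replica sees ~1000 / num_gpus rows
        print(y.shape)                     # torch.Size([1000, 10])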

To load the checkpoint, line 20 of linear.py should be modified accordingly:

    # requires `from collections import OrderedDict` at the top of linear.py
    model = Model()
    new_state_dict = OrderedDict()
    state_dict = torch.load(pretrained_path)
    for k, v in state_dict.items():
        name = k[7:]  # strip the 'module.' prefix that DataParallel adds to every key
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict, strict=False)
    # self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
    self.f = model.f
    self.fc = nn.Linear(2048, num_class, bias=True)
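
As a side note, if your PyTorch version is recent enough (1.9+, an assumption worth checking), the same prefix stripping can be done with a built-in helper instead of the hard-coded `k[7:]` slice:

    import torch
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    model = Model()
    state_dict = torch.load(pretrained_path)
    # Strips a leading 'module.' from every key in place; a no-op when the
    # checkpoint was saved from an unwrapped model.
    consume_prefix_in_state_dict_if_present(state_dict, 'module.')
    model.load_state_dict(state_dict, strict=False)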

@zhoubin-me, first of all, what you said is not accurate:

you will face a CUDA out-of-memory error even if you have enough GPU cards to support a higher batch size

If you have a V100 GPU (32 GB) or an A100 GPU (40 GB or 80 GB), it will work fine; if you have a 3090 GPU (24 GB), halving the default batch size works too.

Second, the solution you provide is incomplete. You only need to change this line:

SimCLR/main.py, line 138 at commit cee178b:

    torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))

to

    torch.save(model.module.state_dict(), 'results/{}_model.pth'.format(save_name_pre))

no need to change any code in linear.py.
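
If you want a save line that works in both the single-GPU and DataParallel setups, a defensive variant (a sketch, not what the repo currently does) would be:

    # Unwrap DataParallel before saving so checkpoint keys carry no 'module.' prefix.
    to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
    torch.save(to_save.state_dict(), 'results/{}_model.pth'.format(save_name_pre))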

I have 2 A6000 GPU cards with 45 GB of GPU memory each. With batch size = 1000 and no data parallelism, training raises `RuntimeError: CUDA out of memory.`; using a data-parallel model solves this issue.

@zhoubin-me I mentioned in the README that this code works fine with batch size = 512 trained on a V100 GPU. You changed the default hyper-parameters; that's another story and cannot be called an issue. You can certainly make other changes based on this code, but these are not problems.