DataParallel to avoid memory issue
zhoubin-me opened this issue · 4 comments
Change the model creation at line 114 in main.py to the following:
# model setup and optimizer config
model = Model(feature_dim)
flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),))
flops, params = clever_format([flops, params])
print('# Model Params: {} FLOPs: {}'.format(params, flops))
model = torch.nn.DataParallel(model).cuda()
Otherwise, you will face a CUDA out-of-memory error even if you have enough GPU cards to support a higher batch size.
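For illustration, here is a minimal, self-contained sketch (a toy network and an illustrative batch size, not this repo's exact training loop) of why DataParallel helps: it splits each input batch along dim 0 across all visible GPUs, so each card only has to hold its share of the activations plus a replica of the model.
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # stand-in for Model(feature_dim), for illustration only
    def __init__(self):
        super().__init__()
        self.f = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.f(x).mean(dim=(2, 3))

if torch.cuda.is_available():
    model = nn.DataParallel(TinyNet()).cuda()
    x = torch.randn(1000, 3, 32, 32).cuda()  # the full batch is staged on GPU 0 ...
    out = model(x)                           # ... but the forward pass runs on per-GPU chunks
    print(out.shape)                         # torch.Size([1000, 8])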
For loading the checkpoint in linear.py, line 20 should be modified accordingly:
# needs `from collections import OrderedDict` at the top of linear.py
model = Model()
state_dict = torch.load(pretrained_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the `module.` prefix added by DataParallel
    new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=False)
# self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
self.f = model.f
self.fc = nn.Linear(2048, num_class, bias=True)
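A slightly more defensive variant of the prefix stripping (my own sketch, not code from the repo) only removes `module.` when it is actually present, so the same loader works whether the checkpoint was saved from a DataParallel-wrapped model or a plain one:
from collections import OrderedDict
import torch

def strip_module_prefix(state_dict):
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[len('module.'):] if k.startswith('module.') else k] = v
    return cleaned

# usage: model.load_state_dict(strip_module_prefix(torch.load(pretrained_path)), strict=False)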
@zhoubin-me, first of all, what you said is not accurate:
you will face a CUDA out-of-memory error even if you have enough GPU cards to support a higher batch size
If you have a V100 GPU (32 GB) or an A100 GPU (40 GB or 80 GB), it will work fine; with a 3090 GPU (24 GB), halving the default batch size also works.
Second, the solution you provide is incomplete; you just need to change this line:
Line 138 in cee178b
to
torch.save(model.module.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
There is no need to change any code in linear.py.
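A quick illustration (using a hypothetical toy model, not this repo's Model) of why saving model.module.state_dict() is enough: the inner module's keys carry no `module.` prefix, so the original load in linear.py works unchanged.
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 2))
wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))         # ['module.0.weight', 'module.0.bias']
print(list(wrapped.module.state_dict().keys()))  # ['0.weight', '0.bias']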
I have 2 A6000 GPU cards, each with 45 GB of GPU memory. With batch-size = 1000, not using DataParallel causes a RuntimeError: CUDA out of memory error; using DataParallel solves this issue.
@zhoubin-me I mentioned in the README that this code is trained with batch size = 512 on a V100 GPU and works fine. You changed the default hyper-parameters; that is another story and cannot be called an issue. You can certainly make other changes based on this code, but these are not problems.