lescientifik/open_brats2020

Usage of Multiple GPUs causes Error in State Dict

Harsh-Gill opened this issue · 1 comment

When training on multiple GPUs with torch.nn.DataParallel(model).cuda(), the saved state_dict no longer matches the one the loading code expects.

Wrapping the model in DataParallel adds a module. prefix to every key in the state_dict (module.xxx instead of xxx). The expected state_dict has no such prefix, so loading the checkpoint fails on mismatched keys.
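The key renaming described above can be seen directly, and one way to avoid it is to save the wrapped network's weights rather than the wrapper's. The sketch below is illustrative and assumes a recent PyTorch install. The helpers unwrap and save_checkpoint are hypothetical names, not part of open_brats2020, and the {"epoch", "state_dict"} checkpoint layout is an assumption.

```python
import torch
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the inner network if `model` is a multi-GPU wrapper."""
    # DataParallel and DistributedDataParallel both keep the real network in `.module`.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


def save_checkpoint(model: nn.Module, path: str, epoch: int) -> None:
    """Hypothetical helper: save weights without the "module." prefix.

    The {"epoch", "state_dict"} layout is an assumption for illustration.
    """
    torch.save({"epoch": epoch, "state_dict": unwrap(model).state_dict()}, path)


if __name__ == "__main__":
    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)  # on a CPU-only machine this just holds `net`
    print(list(wrapped.state_dict().keys()))          # ['module.weight', 'module.bias']
    print(list(unwrap(wrapped).state_dict().keys()))  # ['weight', 'bias']
```

Because the prefix comes from the wrapper registering the network as a submodule named module, saving unwrap(model).state_dict() produces keys that load into a plain single-GPU or CPU model without any renaming.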
I can open a pull request that fixes this by stripping the module. prefix from the keys of a checkpoint saved on multiple GPUs, so that it loads into the unwrapped model.
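Here is a minimal sketch of the load-side trick, written as a standalone helper. The name strip_module_prefix is hypothetical and not taken from the repository. It removes only a leading prefix, so a nested submodule that happens to be called module keeps its key. Recent PyTorch releases also appear to provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which performs a similar rename in place. Check that your version has it before relying on it.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of `state_dict` with a leading `prefix` removed.

    Only a *leading* prefix is stripped, so a key such as
    "encoder.module.weight" (a submodule literally named `module`) is left as-is.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


if __name__ == "__main__":
    saved = OrderedDict([("module.conv.weight", 1), ("module.conv.bias", 2)])
    print(list(strip_module_prefix(saved)))  # ['conv.weight', 'conv.bias']
```

With a checkpoint laid out as in the snippet below, usage would look like model.load_state_dict(strip_module_prefix(checkpoint['state_dict'])).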

Insert the function below into src/utils.py and loading should work.

def reload_ckpt_bis(ckpt, model, device='cuda', optimizer=None):
    if os.path.isfile(ckpt):
        print(f"=> loading checkpoint {ckpt}")
        try:
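             # map_location lets a GPU-saved checkpoint load onto the requested device
             checkpoint = torch.load(ckpt, map_location=device)
             # Strip only the leading "module." that nn.DataParallel adds to every key
             checkpoint = {k[len('module.'):] if k.startswith('module.') else k: v
                           for k, v in checkpoint['state_dict'].items()}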