penn-pal-lab/LIV

DataParallel issues on SLURM with PyTorch Lightning


Hi,

thanks for sharing your great work and the pretrained weights!

I encountered some issues when trying to run the model on my SLURM cluster with multiple GPUs using a PyTorch Lightning setup. The root cause was the line:

https://github.com/penn-pal-lab/LIV/blob/main/liv/__init__.py#L48

To improve compatibility, I'd suggest adding a use_dataparallel argument to the load_liv function. Here's my adapted implementation, which worked for me:

import os
from collections import OrderedDict
from os.path import expanduser

import torch

def load_liv(modelid='resnet50', use_dataparallel: bool = True):
    assert modelid == 'resnet50'
    home = os.path.join(expanduser("~"), ".liv")

    [...]

    # Load the state dictionary
    state_dict = torch.load(modelpath, map_location=torch.device(device))['liv']

    # Handle a possible 'module.' prefix mismatch: checkpoints saved from a
    # DataParallel-wrapped model prefix every key with 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    rep.load_state_dict(new_state_dict)

    # Only wrap in DataParallel when requested, so Lightning (or any other
    # framework that manages device placement itself) can use the bare model
    if use_dataparallel:
        rep = torch.nn.DataParallel(rep)

    return rep
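For reference, the key-renaming step can be isolated into a small helper and checked without loading any weights. This is just an illustrative sketch (strip_module_prefix is a hypothetical name, not part of LIV), showing the same prefix handling on a toy state dict:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    # torch.nn.DataParallel stores parameters under 'module.<name>',
    # so checkpoints saved from a wrapped model need the prefix removed
    # before they can be loaded into an unwrapped model.
    return OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items()
    )

# Toy checkpoint mixing prefixed and unprefixed keys:
ckpt = OrderedDict([("module.conv1.weight", 0), ("fc.bias", 1)])
print(list(strip_module_prefix(ckpt).keys()))
# → ['conv1.weight', 'fc.bias']
```

This makes loading robust regardless of whether the checkpoint was saved from a wrapped or unwrapped model.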