DataParallel issues on SLURM with PyTorch Lightning
mbreuss opened this issue · 0 comments
mbreuss commented
Hi,
thanks for sharing your great work and the pretrained weights!
I encountered some issues when trying to run the model on my SLURM cluster with multiple GPUs using a PyTorch Lightning setup. The root cause was the unconditional torch.nn.DataParallel wrapper in this line:
https://github.com/penn-pal-lab/LIV/blob/main/liv/__init__.py#L48
To improve compatibility, I'd suggest adding a use_dataparallel argument to the load_liv function. Here's my adapted implementation, which worked for me:
import os
from collections import OrderedDict
from os.path import expanduser

import torch


def load_liv(modelid='resnet50', use_dataparallel: bool = True):
    assert modelid == 'resnet50'
    home = os.path.join(expanduser("~"), ".liv")
    [...]  # unchanged code from the original load_liv (defines modelpath, device, and the model rep)
    # Load the state dictionary onto the target device
    state_dict = torch.load(modelpath, map_location=torch.device(device))['liv']
    # Handle a possible 'module.' prefix mismatch left over from DataParallel training
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    rep.load_state_dict(new_state_dict)
    # Only wrap the model in DataParallel when explicitly requested
    if use_dataparallel:
        rep = torch.nn.DataParallel(rep)
    return rep
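
For completeness, here's a minimal sketch of how I use the patched function in my Lightning setup on SLURM (the variable name liv_encoder is just for illustration): the DataParallel wrapper is skipped so Lightning can handle device placement itself.

from liv import load_liv

# Skip the nn.DataParallel wrapper so that PyTorch Lightning
# (e.g. its DDP strategy on a multi-GPU SLURM job) can manage the devices itself.
liv_encoder = load_liv(modelid='resnet50', use_dataparallel=False)
liv_encoder.eval()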