Errors occur when loading the pretrained model
Hi @DK-Jang! Thank you for sharing this nice work :-)
I ran into trouble when loading the pre-trained model. PyTorch maps the checkpoint to the device set by the code below:
self.device = torch.cuda.current_device()
Line 33 in 7f1eca9
In my case it returns 0,
which triggers an error:
TypeError: 'int' object is not callable
When I change this line to self.device = torch.device('cuda:0'),
the error message changes to:
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.enc_content.edge_importance_j", ...
Unexpected key(s) in state_dict: "enc_content.edge_importance_j", ...
I think this is because the model was trained and saved with DataParallel; however, I cannot run on multiple GPUs.
Any help would be appreciated. Thanks in advance!
I faced the same problem.
I think the pretrained network was trained without DataParallel (or on CPU), so its keys lack the "module." prefix.
Modifying the code as follows worked for me:
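One way to check this is to compare the keys the model expects with the keys stored in the checkpoint. The sketch below is only an illustration: `classify_key_mismatch` is a hypothetical helper, not part of the repository, and the two key lists are the single entries visible in the error message (the real lists are longer). In practice the inputs would be something like `self.gen.state_dict().keys()` and `state_dict["gen"].keys()`.

```python
def classify_key_mismatch(model_keys, checkpoint_keys, prefix="module."):
    """Describe how the checkpoint keys relate to the keys the model expects."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    if model_keys == checkpoint_keys:
        return "match"
    # The model is wrapped (keys carry the prefix) but the checkpoint is not.
    if {prefix + k for k in checkpoint_keys} == model_keys:
        return "checkpoint lacks prefix"
    # The checkpoint came from a wrapped model but the current model is not wrapped.
    if {prefix + k for k in model_keys} == checkpoint_keys:
        return "model lacks prefix"
    return "unrelated mismatch"


# Keys copied from the error message in this issue.
expected = ["module.enc_content.edge_importance_j"]  # "Missing key(s)"
found = ["enc_content.edge_importance_j"]            # "Unexpected key(s)"
print(classify_key_mismatch(expected, found))        # checkpoint lacks prefix
```

A result of "checkpoint lacks prefix" is consistent with a checkpoint that was saved from a model not wrapped in DataParallel.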
In trainer.py
Line 15 in 7f1eca9
from collections import OrderedDict
Line 33 in 7f1eca9
self.device = torch.device("cuda:{}".format(torch.cuda.current_device()))
Lines 163 to 164 in 7f1eca9
gen_dict = OrderedDict()
for key, value in state_dict["gen"].items():
    if not key.startswith("module."):
        key = "module." + key
    gen_dict[key] = value
self.gen.load_state_dict(gen_dict)
gen_ema_dict = OrderedDict()
for key, value in state_dict["gen_ema"].items():
    if not key.startswith("module."):
        key = "module." + key
    gen_ema_dict[key] = value
self.gen_ema.load_state_dict(gen_ema_dict)
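The two loops above do the same thing, so they could be folded into one helper. This is a minimal sketch, assuming a hypothetical function name `add_module_prefix`; plain floats stand in for the tensors a real checkpoint would hold.

```python
from collections import OrderedDict


def add_module_prefix(state_dict, prefix="module."):
    """Return a copy whose keys all start with `prefix`, preserving order."""
    out = OrderedDict()
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            key = prefix + key
        out[key] = value
    return out


# Stand-in checkpoint: one key without the prefix, one that already has it.
checkpoint = OrderedDict([
    ("enc_content.edge_importance_j", 1.0),
    ("module.dec.weight", 2.0),
])
fixed = add_module_prefix(checkpoint)
print(list(fixed))
# ['module.enc_content.edge_importance_j', 'module.dec.weight']
```

With such a helper, the trainer.py change would reduce to calls like `self.gen.load_state_dict(add_module_prefix(state_dict["gen"]))`, and likewise for `gen_ema`.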
In test.py
Lines 128 to 131 in 7f1eca9
rec = rec.cpu().numpy()*std + mean
tra = tra.cpu().numpy()*std + mean
con_gt = con_gt.cpu().numpy()*std + mean
sty_gt = sty_gt.cpu().numpy()*std + mean
If you want to retrain the model, these changes must be reverted.
I had the same problem. Here is what I tried.
Change this part to:
Lines 34 to 35 in 52af967
self.gen = self.gen.to(self.device)
self.gen_ema = self.gen_ema.to(self.device)
And:
Line 162 in 52af967
state_dict = torch.load(model_path, map_location="cuda:0")
Also make the same change in test.py as described in KosukeFukazawa's comment.
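The workarounds in this thread go in opposite directions: one renames the checkpoint keys to match the model, the other changes how the model is moved to the device (presumably dropping a DataParallel wrapper, although the original lines are not shown here). A loader could instead adapt the keys to whatever the current model expects. The following is a rough sketch under that assumption; `adapt_keys` is a hypothetical name, and plain values stand in for tensors.

```python
from collections import OrderedDict


def adapt_keys(state_dict, target_keys, prefix="module."):
    """Add or strip `prefix` on each key so it matches `target_keys` when possible."""
    target = set(target_keys)
    out = OrderedDict()
    for key, value in state_dict.items():
        if key not in target:
            if prefix + key in target:
                # Model is wrapped, checkpoint is not: add the prefix.
                key = prefix + key
            elif key.startswith(prefix) and key[len(prefix):] in target:
                # Checkpoint came from a wrapped model: strip the prefix.
                key = key[len(prefix):]
            # Otherwise leave the key as-is so load_state_dict can report it.
        out[key] = value
    return out


# Checkpoint saved without the prefix, model expects it.
print(list(adapt_keys({"enc.w": 1}, ["module.enc.w"])))   # ['module.enc.w']
# Checkpoint saved with the prefix, model does not expect it.
print(list(adapt_keys({"module.enc.w": 1}, ["enc.w"])))   # ['enc.w']
```

In real use `target_keys` would be the model's own `state_dict().keys()`, so the same loading code works whether or not the model is wrapped.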