sunshineatnoon/PytorchWCT

Fix: Read Lua weights with PyTorch > 1.0


`from torch.utils.serialization import load_lua` doesn't work in current PyTorch versions, because the `torch.utils.serialization` module was removed in PyTorch 1.0.

Here is a possible fix using the `torchfile` package:

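```python
import torchfile


class pytorch_lua_wrapper:
    def __init__(self, lua_path):
        # torchfile returns the Lua nn.Sequential as a TorchObject;
        # its fields (including `modules`) live in `_obj`
        self.lua_model = torchfile.load(lua_path)

    def get(self, idx):
        # Return the fields of the idx-th layer (0-based on the Python side)
        return self.lua_model._obj.modules[idx]._obj
```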
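To see the attribute path that `get` walks, here is a small stand-in that needs neither `torchfile` nor a `.t7` file. It assumes torchfile wraps each Lua object so that its fields sit under `_obj`, and that `modules` has already been turned into a 0-based Python list (so Python index 0 is Lua's `modules[1]`). The `SimpleNamespace` objects, `fake_lua_object`, and `get_layer` are hypothetical, for illustration only.

```python
from types import SimpleNamespace


def fake_lua_object(**fields):
    # Hypothetical stand-in for a torchfile TorchObject: the Lua fields sit under `_obj`.
    return SimpleNamespace(_obj=SimpleNamespace(**fields))


def get_layer(lua_model, idx):
    # Same attribute path as pytorch_lua_wrapper.get:
    # container -> _obj -> modules list -> idx-th layer -> its fields.
    return lua_model._obj.modules[idx]._obj


# Fake two-layer nn.Sequential; `modules` is already a 0-based Python list.
# Real torchfile weights would be numpy arrays, not nested lists.
sequential = fake_lua_object(modules=[
    fake_lua_object(weight=[[1.0]], bias=[0.0]),
    fake_lua_object(weight=[[2.0]], bias=[0.5]),
])

print(get_layer(sequential, 0).bias)    # [0.0]
print(get_layer(sequential, 1).weight)  # [[2.0]]
```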

Now you can replace this line:

```python
vgg1 = load_lua(args.vgg1)
```

with

```python
vgg1 = pytorch_lua_wrapper(args.vgg1)
```

and this line:

```python
self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
```

with

```python
# torchfile returns numpy arrays, so convert to a torch tensor first
self.conv1.weight = torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())
```

Thanks, but I get the following error:

```
  File "G:\Project\A_NST\PytorchWCT-master\util.py", line 28, in __init__
    vgg1 = pytorch_lua_wrapper(args.vgg1)
  File "G:\Project\A_NST\PytorchWCT-master\util.py", line 18, in __init__
    self.lua_model = torchfile.load(lua_path)
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 424, in load
    return reader.read_obj()
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 370, in read_obj
    obj._obj = self.read_obj()
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 385, in read_obj
    k = self.read_obj()
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 386, in read_obj
    v = self.read_obj()
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 370, in read_obj
    obj._obj = self.read_obj()
  File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 387, in read_obj
    obj[k] = v
TypeError: unhashable type: 'list'
```
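The last frame shows where it fails: while reading a Lua table, torchfile does `obj[k] = v` with a key `k` that is a Python `list`, and lists are unhashable. A likely source of that list is torchfile's list heuristic, which turns Lua tables with consecutive integer keys into Python lists. The sketch below reproduces the failing step in plain Python. One thing that may be worth trying, assuming the installed torchfile version forwards a `use_list_heuristic` keyword from `torchfile.load`, is `torchfile.load(lua_path, use_list_heuristic=False)`; tables would then stay dict-like with Lua's 1-based integer keys, so `get` would need `idx + 1` instead of `idx`. The `get_module` helper is hypothetical and has not been tested against the actual VGG `.t7` files.

```python
def reproduce_unhashable_key():
    # Mimics torchfile's `obj[k] = v` when `k` is a Lua table that has
    # already been converted into a Python list.
    table = {}
    key = [1.0, 2.0]
    try:
        table[key] = "value"
    except TypeError as exc:
        return str(exc)
    return None


def get_module(modules, idx):
    """Return the idx-th (0-based) module of a Lua nn container.

    Hypothetical helper: with the list heuristic on, `modules` is a 0-based
    Python list; with it off, it is assumed to be a dict keyed by Lua's
    1-based integers.
    """
    if isinstance(modules, list):
        return modules[idx]
    return modules[idx + 1]


if __name__ == "__main__":
    print(reproduce_unhashable_key())  # message mentions "unhashable type: 'list'"
    as_list = ["conv0", "conv1"]
    as_dict = {1: "conv0", 2: "conv1"}
    print(get_module(as_list, 0), get_module(as_dict, 0))  # conv0 conv0
```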