mseitzer/srgan

Runtime error with channels

aradhyamathur opened this issue · 1 comment

I tried to run eval.py on a test image, but the following error occurred. Could you point out what might be causing it? Does it expect images to be preprocessed in a specific way?

Here is the traceback:
Running on GPU 0
Restored checkpoint from resources/pretrained/srgan.pth
Traceback (most recent call last):
  File "./eval.py", line 152, in <module>
    main(sys.argv[1:])
  File "./eval.py", line 117, in main
    data = runner.infer(loader)
  File "/home/thor/PythonProjects/ML/Pytorch/srgan/training/base_runner.py", line 128, in infer
    _, data = self._val_step(loader, compute_metrics=False)
  File "/home/thor/PythonProjects/ML/Pytorch/srgan/training/adversarial_runner.py", line 294, in _val_step
    prediction = self.gen(inp)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/thor/PythonProjects/ML/Pytorch/srgan/models/srresnet.py", line 193, in forward
    initial = self.initial_conv(x)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/modules/container.py", line 67, in forward
    input = module(input)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 277, in forward
    self.padding, self.dilation, self.groups)
  File "/home/thor/anaconda3/envs/srgan/lib/python3.6/site-packages/torch/nn/functional.py", line 90, in conv2d
    return f(input, weight, bias)
RuntimeError: Given groups=1, weight[64, 3, 9, 9], so expected input[1, 4, 1088, 1088] to have 3 channels, but got 4 channels instead
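The shapes in the RuntimeError follow PyTorch's conventions: a Conv2d weight is laid out as (out_channels, in_channels / groups, kernel_height, kernel_width), and the input batch as (N, C, H, W). The minimal sketch below restates the channel check that fails here, using plain tuples; the helper name conv2d_channels_match is hypothetical and not part of PyTorch or of this repository.

```python
def conv2d_channels_match(weight_shape, input_shape, groups=1):
    """Restate the input-channel check a 2D convolution enforces.

    weight_shape: (out_channels, in_channels // groups, kernel_h, kernel_w)
    input_shape:  (batch, channels, height, width)
    Returns (ok, expected_channels, actual_channels).
    """
    # Each group sees weight_shape[1] input channels, so the input
    # must supply weight_shape[1] * groups channels in total.
    expected = weight_shape[1] * groups
    actual = input_shape[1]
    return expected == actual, expected, actual


# Shapes copied from the RuntimeError in the traceback.
ok, expected, actual = conv2d_channels_match((64, 3, 9, 9), (1, 4, 1088, 1088))
print(ok, expected, actual)  # False 3 4
```

With groups=1 the generator's first convolution needs exactly 3 input channels, so any 4-channel image fails before the network does any work.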

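This is probably because your input image has an alpha channel, i.e. it is in RGBA format. The generator's first convolution expects 3 input channels (weight[64, 3, 9, 9]), but the image tensor has 4 (input[1, 4, 1088, 1088]). If you convert the image to RGB first, it should work. At some point I should implement something to do that automatically.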
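Until that conversion happens automatically, one way to preprocess the image is with Pillow. The sketch below is an assumption, not code from this repository: the function name to_rgb and the white default background are illustrative choices. Rather than just dropping the alpha channel, it composites transparent regions onto a solid background so they do not turn into arbitrary colors.

```python
from PIL import Image


def to_rgb(img, background=(255, 255, 255)):
    """Return a 3-channel RGB copy of a PIL image (hypothetical helper).

    Images with an alpha band (RGBA, LA, PA) or palette transparency are
    alpha-composited onto a solid background; everything else is converted
    to RGB directly.
    """
    has_alpha = "A" in img.getbands() or "transparency" in img.info
    if has_alpha:
        rgba = img.convert("RGBA")
        # Opaque canvas in the background color, same size as the image.
        canvas = Image.new("RGBA", rgba.size, tuple(background) + (255,))
        return Image.alpha_composite(canvas, rgba).convert("RGB")
    return img.convert("RGB")
```

For example, to_rgb(Image.open("input.png")).save("input_rgb.png") writes a 3-channel copy (both filenames are placeholders) that can then be passed to eval.py in place of the original image.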