zilongzhong/SSTN

About train.py for SSRN

nyg919 opened this issue

Thanks to the author for sharing the code, but I don't fully understand parts of it. When I used the SSRN model to train on my own dataset, the following error occurred:
```
(base) root@autodl-container-cdc411aaac-ce90372b:~/code1/SSTN-main/SSTN-main# python train_longkou.py
Experiment dir : longkou-train-model-SSRN-arch-AEAE-20230311-200059-lr0.002
(270, 220000)
Traceback (most recent call last):
  File "train_longkou.py", line 385, in
    outputs = net(inputs.float())
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/code1/SSTN-main/SSTN-main/myNetworksBlocks.py", line 306, in forward
    x = self.bn4(F.leaky_relu(self.layer4(x)))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 135, in forward
    self._check_input_dim(input)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 408, in _check_input_dim
    raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
ValueError: expected 4D input (got 5D input)
```
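The last frames show that `self.bn4` raised the error from `_check_input_dim`, which is the check `nn.BatchNorm2d` applies: it only accepts 4D tensors shaped `(N, C, H, W)`. A 5D tensor `(N, C, D, H, W)` is what a 3D convolution produces, so a likely (but unconfirmed, since `myNetworksBlocks.py` is not shown here) explanation is that `layer4` is an `nn.Conv3d` feeding a `BatchNorm2d`. The minimal sketch below reproduces the same error under that assumption; the layer names mirror the traceback, while the channel counts, kernel size, and input shape are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the failing step in forward():
#   x = self.bn4(F.leaky_relu(self.layer4(x)))
# Assumption: layer4 is a 3D convolution; all sizes are illustrative only.
layer4 = nn.Conv3d(in_channels=1, out_channels=24, kernel_size=(7, 3, 3))
bn4 = nn.BatchNorm2d(24)  # BatchNorm2d only accepts (N, C, H, W)

# Input laid out as (N, C, D, H, W): batch, channels, spectral bands, height, width
x = torch.randn(2, 1, 30, 7, 7)
y = F.leaky_relu(layer4(x))
print(y.shape)  # torch.Size([2, 24, 24, 5, 5]), i.e. still 5D

error_message = None
try:
    bn4(y)  # 5D tensor into a 2D batch norm
except ValueError as err:
    error_message = str(err)
print(error_message)  # expected 4D input (got 5D input)
```

If printing `x.shape` just before line 306 of `myNetworksBlocks.py` also shows five dimensions, this mismatch is the cause.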
Could you point out the cause of this error and what needs to be changed? Thanks!
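Two common ways to make the shapes agree are sketched below, again assuming `layer4` is an `nn.Conv3d` and using made-up sizes. Which one matches the intended model depends on whether the spectral depth axis should survive past `layer4`; this is a hypothetical illustration, not the repository's actual fix. Fix 1 swaps in `nn.BatchNorm3d`, which normalizes 5D input per channel. Fix 2 folds the depth axis into the channel axis so a `BatchNorm2d` with `C * D` features can be used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical layer and input; sizes are illustrative assumptions only.
layer4 = nn.Conv3d(1, 24, kernel_size=(7, 3, 3))
x = torch.randn(2, 1, 30, 7, 7)  # (N, C, D, H, W)
y = F.leaky_relu(layer4(x))      # shape (2, 24, 24, 5, 5)

# Fix 1: keep the 5D tensor and normalize it with BatchNorm3d,
# which accepts (N, C, D, H, W) and uses C = 24 features.
bn4_3d = nn.BatchNorm3d(24)
out_3d = bn4_3d(y)

# Fix 2: merge the depth axis into the channel axis to get a 4D tensor,
# then normalize with BatchNorm2d over C * D features.
n, c, d, h, w = y.shape
y_4d = y.reshape(n, c * d, h, w)  # shape (2, 576, 5, 5)
bn4_2d = nn.BatchNorm2d(c * d)
out_2d = bn4_2d(y_4d)

print(out_3d.shape)  # torch.Size([2, 24, 24, 5, 5])
print(out_2d.shape)  # torch.Size([2, 576, 5, 5])
```

Fix 1 is the smaller change when later layers also expect 5D input. Fix 2 suits a design where `layer4` is meant to hand a 2D feature map to the following layers; if the kernel already spans every band so that `D == 1`, `y.squeeze(2)` gives the same 4D tensor with only `C` features. Whichever is chosen, the `num_features` of the batch-norm layer must equal the size of dimension 1 of its input.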