Tutorial (tutorial.ipynb) fails
Closed this issue · 3 comments
hotstone1993 commented
🐛 Describe the bug
After successfully building by following the linked instructions, I tried running the tutorial, but it fails with the following error:
Cell In[2], line 26, in Net.forward(self, input)
23 # RNN output shape is (seq_len, batch, input_size)
24 # Get last output of RNN
25 output = output[:, -1, :]
---> 26 output = self.norm(output)
27 output = self.dropout1(output)
28 output = self.fc1(output)
File ~/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
...
35 ) = args
37 if use_input_stats:
38 InstanceNorm.__optimize(inp, run_mean, run_var, momentum)
ValueError: too many values to unpack (expected 8)
In torch.autograd.Function, the context object used by the backward pass is passed as the first argument of forward, but the custom ops (e.g. InstanceNorm, ReLU, ...) do not account for it when unpacking their arguments.
Am I missing something?
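The mismatch can be reproduced without PyTorch at all. The sketch below is a hypothetical model, not PyTorch's real dispatch code: `FakeCtx`, `call_like_1_13`, and `call_like_2_0` are illustrative names. It assumes the library's `forward` follows the PyTorch 2.0 convention, where `ctx` is handled by a separate `setup_context` method and `forward` receives only the inputs. PyTorch 1.13 always passes a `ctx` first, so nine values reach an eight-name unpacking.

```python
# Hypothetical model of how each PyTorch version calls a user-defined forward().
# None of these names are real PyTorch APIs.


class FakeCtx:
    """Placeholder for the context object PyTorch would create."""


def forward(*args):
    # A forward() written in the PyTorch 2.0 style: no ctx parameter,
    # so it expects exactly the 8 inputs of InstanceNorm.
    (inp, run_mean, run_var, wgt, bis, use_input_stats, momentum, eps) = args
    return inp


def call_like_1_13(fn, *inputs):
    # Models PyTorch 1.13: a ctx object is always prepended to the inputs.
    return fn(FakeCtx(), *inputs)


def call_like_2_0(fn, *inputs):
    # Models PyTorch 2.0 with setup_context defined: only the inputs are passed.
    return fn(*inputs)


inputs = ("x", None, None, None, None, True, 0.1, 1e-5)

print(call_like_2_0(forward, *inputs))  # prints: x

try:
    call_like_1_13(forward, *inputs)
except ValueError as err:
    print(err)  # prints a "too many values to unpack" message
```

This is why adding `ctx` as the first unpacked name makes the code run on 1.13: the extra argument now has somewhere to go.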
💻 Requirements
- Platform: Ubuntu 22.04.2 LTS (WSL2)
- Version: Python 3.8.17, PyTorch 1.13.0
💬 Additional context
It does work if I modify the forward function as follows:
import torch


class InstanceNorm(torch.autograd.Function):
    @staticmethod
    def forward(*args, **kwargs):
        (
            ctx,  # This variable was added
            inp,
            run_mean,
            run_var,
            wgt,
            bis,
            use_input_stats,
            momentum,
            eps,
        ) = args

        # __optimize and __forward are the class's private helpers (not shown here).
        if use_input_stats:
            InstanceNorm.__optimize(inp, run_mean, run_var, momentum)
            run_mean = run_var = None

        return InstanceNorm.__forward(
            inp,
            run_mean,
            run_var,
            wgt,
            bis,
            eps,
        )
kakao-steve-ai commented
@hotstone1993 It looks like this is resolved in PyTorch 2.0. Could you try again with PyTorch 2.0?
hotstone1993 commented
Thank you! After upgrading PyTorch, it works well.
I also confirmed that requirements.txt specifies version 2.0.0 or higher.
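Since requirements.txt asks for 2.0.0 or higher, one option is to fail early with a readable message rather than the deep `ValueError` from the traceback. This is a minimal sketch, assuming the check runs once at import time; `torch_major_minor` and `require_torch_2` are hypothetical helpers, not part of the repository.

```python
import re


def torch_major_minor(version: str) -> tuple:
    """Extract (major, minor) from a version string such as "1.13.0+cu117"."""
    # Only the leading "major.minor" digits are read; suffixes are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def require_torch_2(version: str, minimum: tuple = (2, 0)) -> None:
    """Raise a clear error if the installed PyTorch is older than `minimum`."""
    if torch_major_minor(version) < minimum:
        raise RuntimeError(
            f"PyTorch {version} found, but >= {minimum[0]}.{minimum[1]} "
            "is required (see requirements.txt)"
        )
```

In practice it would be called as `require_torch_2(torch.__version__)`. Comparing `(major, minor)` tuples avoids the string-comparison trap where `"1.13" > "1.2"` would be evaluated character by character, and local build suffixes such as `+cu117` do not affect the result.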
daemyung commented
Hi @hotstone1993, we start to support from PyTorch 1.13. Happy coding! Thanks.