kakaobrain/trident

Tutorial (tutorial.ipynb) fails

Closed this issue · 3 comments

๐Ÿž Describe the bug

After successfully building with the help of the linked instructions, I tried running the tutorial, but it fails with the traceback below.

Cell In[2], line 26, in Net.forward(self, input)
     23 # RNN output shape is (seq_len, batch, input_size)
     24 # Get last output of RNN
     25 output = output[:, -1, :]
---> 26 output = self.norm(output)
     27 output = self.dropout1(output)
     28 output = self.fc1(output)

File ~/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
...
     35     ) = args
     37     if use_input_stats:
     38         InstanceNorm.__optimize(inp, run_mean, run_var, momentum)

ValueError: too many values to unpack (expected 8)

The context object used for the backward pass is passed as the first argument of torch.autograd.Function.forward, but the custom Function classes (e.g. InstanceNorm, ReLU, ...) unpack args with only the op arguments, so the extra context argument breaks the unpack.
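
For reference, here is a minimal toy example (Scale is a made-up op, not Trident's actual code) of the two forward conventions of torch.autograd.Function. The second form, where setup_context is separate from forward, is only supported from PyTorch 2.0; on 1.x the context is always prepended to the forward arguments:

import torch

# Classic convention: ctx is the first argument of forward (works on PyTorch 1.x and 2.x).
class ScaleV1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, factor):
        ctx.factor = factor
        return inp * factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None


# New convention: forward takes only the op arguments and the context is filled in
# setup_context (supported from PyTorch 2.0).
class ScaleV2(torch.autograd.Function):
    @staticmethod
    def forward(inp, factor):
        return inp * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None


x = torch.randn(4, requires_grad=True)
ScaleV1.apply(x, 2.0).sum().backward()  # works on any version
ScaleV2.apply(x, 2.0).sum().backward()  # on 1.x, forward would receive ctx as an extra first argument and break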

Am I missing something?

💻 Requirements

  • Platform: Ubuntu 22.04.2 LTS (WSL2)
  • Version: Python 3.8.17, PyTorch 1.13.0

💬 Additional context

Of course, it works well if the forward function is modified as follows:

class InstanceNorm(torch.autograd.Function):
    @staticmethod
    def forward(*args, **kwargs):
        (
            ctx, # This variable was added
            inp,
            run_mean,
            run_var,
            wgt,
            bis,
            use_input_stats,
            momentum,
            eps,
        ) = args

        if use_input_stats:
            InstanceNorm.__optimize(inp, run_mean, run_var, momentum)
            run_mean = run_var = None

        return InstanceNorm.__forward(
            inp,
            run_mean,
            run_var,
            wgt,
            bis,
            eps,
        )
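
(With ctx included in the unpack, the argument count matches what PyTorch 1.13 actually passes to forward, so the remaining eight values line up again.)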

@hotstone1993 It seems this can be solved by using PyTorch 2.0. Could you try again with PyTorch 2.0?

Thank you! After upgrading PyTorch, it works well.
I also confirmed that requirements.txt specifies 2.0.0 or higher.

Hi @hotstone1993, our support starts from PyTorch 1.13. Happy coding! Thanks.