JusperLee/Deep-Encoder-Decoder-Conv-TasNet

Error message on Asteroid - list output

Closed this issue · 1 comment

jvel07 commented

Hi! Nice repo, @JusperLee :)
I am trying to use Conv-Tasnet-Deep-w-dilation.py with Asteroid, but I get the error below when training. Edit: Solved this by returning torch.unsqueeze(s[0], dim=1) when num_spks=1.

  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 90, in launch
    return function(*args, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 360, in validation_step
    return self.model(*args, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/media/jvel/data/repos/asteroid/asteroid/engine/system.py", line 130, in validation_step
    loss = self.common_step(batch, batch_nb, train=False)
  File "/media/jvel/data/repos/asteroid/asteroid/engine/system.py", line 102, in common_step
    loss = self.loss_func(est_targets, targets)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/jvel/data/repos/asteroid/asteroid/losses/pit_wrapper.py", line 100, in forward
    pw_losses = self.loss_func(est_targets, targets, **kwargs)
  File "/home/jvel/anaconda3/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/jvel/data/repos/asteroid/asteroid/losses/sdr.py", line 46, in forward
    if targets.size() != est_targets.size() or targets.ndim != 3:
AttributeError: 'list' object has no attribute 'size'
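
The fix in the edit gives the model's output the (batch, n_src, time) layout that Asteroid's SDR loss checks for, as seen in the `targets.ndim != 3` test at the bottom of the traceback. Below is a minimal sketch. It assumes the forward pass builds a list `s` of per-speaker estimates, each shaped (batch, time). `to_asteroid_output` is a hypothetical helper name and is not part of either repo.

```python
import torch


def to_asteroid_output(s, num_spks):
    """Hypothetical helper: turn a list of per-speaker estimates into one tensor.

    Assumes `s` holds `num_spks` tensors, each shaped (batch, time), and
    returns a tensor shaped (batch, num_spks, time) for Asteroid's losses.
    """
    if num_spks == 1:
        # The fix from the edit: add a source axis of size 1 at dim=1.
        return torch.unsqueeze(s[0], dim=1)
    # Several speakers: stack along a new source axis at dim=1.
    return torch.stack(s, dim=1)


# Usage with random stand-in waveforms (batch=4, time=16000).
one = to_asteroid_output([torch.randn(4, 16000)], num_spks=1)
two = to_asteroid_output([torch.randn(4, 16000), torch.randn(4, 16000)], num_spks=2)
print(one.shape)  # torch.Size([4, 1, 16000])
print(two.shape)  # torch.Size([4, 2, 16000])
```

On its own, `torch.stack(s, dim=1)` would also handle the one-speaker case, because stacking a single tensor is equivalent to unsqueezing it. The explicit branch is there only to mirror the edit.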

I think this is caused by the model's output being a list. You can check the output of the model first.
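
Checking the output before it reaches the loss makes this kind of failure obvious. Below is a minimal sketch. It assumes estimates and targets should both be (batch, n_src, time) tensors, as implied by the `sdr.py` check in the traceback. `output_problem` is a hypothetical helper, not an Asteroid API.

```python
from typing import Optional

import torch


def output_problem(est_targets, targets) -> Optional[str]:
    """Hypothetical helper: describe why `est_targets` would fail the shape
    check in Asteroid's sdr.py, or return None if it looks fine."""
    if not isinstance(est_targets, torch.Tensor):
        # The case in the traceback: a Python list has no .size() method.
        return f"model returned {type(est_targets).__name__}, not a torch.Tensor"
    if est_targets.ndim != 3:
        return f"expected 3 dims (batch, n_src, time), got {est_targets.ndim}"
    if est_targets.shape != targets.shape:
        return f"shape {tuple(est_targets.shape)} != targets {tuple(targets.shape)}"
    return None


targets = torch.randn(2, 1, 8000)          # (batch, n_src, time)
list_out = [torch.randn(2, 8000)]          # what a list-returning forward gives
print(output_problem(list_out, targets))   # model returned list, not a torch.Tensor
print(output_problem(torch.stack(list_out, dim=1), targets))  # None
```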