Sxela/WarpFusion

Tensor shape error on 'Generate optical flow and consistency maps' step

Chloe-323 opened this issue · 2 comments

Hello!

I am trying to set up stable_warpfusion_v0_8_6_stable (the public one) to experiment and figure out how it works. I've followed the instructions for running it locally. During installation, I get several deprecation warnings:

DEPRECATION: pytorch-lightning 1.7.7 has a non-standard dependency specifier torch>=1.9.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063

DEPRECATION: torchsde 0.2.5 has a non-standard dependency specifier numpy>=1.19.*; python_version >= "3.7". pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of torchsde or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
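(For what it's worth, these warnings come from the metadata of the pinned dependencies, not from your install. PEP 440 only allows the `.*` wildcard suffix with `==` and `!=`, so a specifier like `torch>=1.9.*` is non-conforming; the equivalent conforming form would be:

```
# non-conforming specifier shipped in pytorch-lightning 1.7.7 metadata:
torch>=1.9.*
# PEP 440-conforming equivalent:
torch>=1.9
```

Nothing to fix on the user side; newer releases of those packages correct it.)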

These don't seem too serious, so I keep going. When I reach the 'Generate optical flow and consistency maps' step, I face the following error after about 30 seconds:
RuntimeError Traceback (most recent call last)
Cell In[21], line 225
223 p = Pool(threads)
224 for i,batch in enumerate(tqdm(dl)):
--> 225 flow_batch(i, batch, p)
226 p.close()
227 p.join()

Cell In[21], line 170, in flow_batch(i, batch, pool)
    168 pool.apply_async(np.save, (out_flow21_fn, flow21))
    169 if check_consistency:
--> 170   _, flow12 = raft_model(frame_1, frame_2)
    171   flow12 = flow12[0].permute(1, 2, 0).detach().cpu().numpy()
    172   if flow_save_img_preview:

File ~\WarpFusion\WarpFusion\env\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: shape '[0, 1, 4, 2, 5, 3]' is invalid for input of size 524288
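One observation about the error itself: that view can never succeed no matter what input it gets, because the target shape contains a 0 and therefore holds zero elements, while the input has 524288 = 2 × 512 × 512 elements (which happens to be exactly one two-channel 512×512 flow field). So the 0 looks like a dimension being miscomputed inside the jit-compiled graph rather than anything about the frames. A minimal sketch of the arithmetic:

```python
from math import prod

# Shape and size taken from the TorchScript error above.
target_shape = [0, 1, 4, 2, 5, 3]
input_numel = 524288

# A view target containing a 0 can only hold zero elements,
# so no non-empty input can ever be viewed into it.
print(prod(target_shape))            # 0
print(input_numel == 2 * 512 * 512)  # True: one 2-channel 512x512 flow field
```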

I've tried a number of different things: restarting my computer, reinstalling everything, clearing the pip cache and then reinstalling all the dependencies, using different video files, shelling into the pipenv to manually reinstall PyTorch, disabling the lq option, and basically everything else I can think of. Nothing works. Printing the shapes of the two frames gives torch.Size([1, 3, 512, 512]) and torch.Size([1, 3, 512, 512]), which looks correct, so the invalid shape in the error doesn't seem to come from the input frames. I have no idea how to proceed from here. Is there anything I can do to address this?

Thank you!

Sxela commented

Hi, this is probably an older release that uses a jit-compiled RAFT model. It was compiled for PyTorch v1.12, so try downgrading to that version.
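A quick way to sanity-check the environment before re-running the flow step. This is a sketch under the assumption that the bundled checkpoint expects the 1.12.x series; the helper name and the pip command in the comment are illustrative, not part of WarpFusion:

```python
# Sketch: check whether an installed torch version string (torch.__version__)
# matches the 1.12.x series the jit-compiled RAFT checkpoint was traced with.
# To downgrade inside the env, something like:
#   pip install torch==1.12.1 torchvision==0.13.1
# (pick the wheel matching your CUDA setup).
def raft_compatible(version: str) -> bool:
    base = version.split("+")[0]       # drop local tag, e.g. '+cu118'
    major, minor = base.split(".")[:2]
    return (major, minor) == ("1", "12")

print(raft_compatible("2.0.1+cu118"))   # False: the kind of env that hits this error
print(raft_compatible("1.12.1+cu113"))  # True: what the checkpoint expects
```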

Chloe-323 commented

After some tweaking, this seemed to work. Thank you so much! Awesome tool, btw!