dajes/frame-interpolation-pytorch

Batch inference error

Closed this issue · 1 comment

Hello,

Thanks for this port; the results are great so far!

One issue I've run into is that batched inference does not work. Running the test script from the README with a batch size greater than one raises the error below:

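```python
import torch

device = torch.device('cuda')
precision = torch.float16

# Load the fp16 TorchScript export on the CPU, then move it to the GPU.
model = torch.jit.load('film_net_fp16.pt', map_location='cpu')
model.eval().to(device=device, dtype=precision)

# Batch size 2: any batch size above 1 triggers the error below.
img1 = torch.rand(2, 3, 720, 1080).to(precision).to(device)
img3 = torch.rand(2, 3, 720, 1080).to(precision).to(device)
# One interpolation time step per sample, shape (B, 1).
dt = img1.new_full((2, 1), .5)

with torch.no_grad():
    img2 = model(img1, img3, dt)
```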
```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/util.py", line 21, in forward
  for _5 in range(torch.len(pyramid)):
    image = pyramid[_5]
    _6 = torch.append(_4, torch.mul(image, scalar))
                          ~~~~~~~~~ <--- HERE
  return _4
def concatenate_pyramids(pyramid1: List[Tensor],

Traceback of TorchScript, original code (most recent call last):
  File "C:\Users\Danylo\PycharmProjects\frame-interpolation-pytorch\util.py", line 105, in forward
    # multiplied with a batch of scalars, then we transpose back to the standard
    # BxHxWxC form.
    return [image * scalar for image in pyramid]
            ~~~~~~~~~~~~~~ <--- HERE
RuntimeError: The size of tensor a (512) must match the size of tensor b (2) at non-singleton dimension 3
```
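The traceback points at `image * scalar`, where each pyramid level is multiplied by a batch of scalars (presumably derived from `dt`). The internal tensor layout of the TorchScript export is not shown in this issue, so the following is only a sketch assuming B x C x H x W levels (PyTorch's usual convention). Under that assumption, a (B, 1) tensor is right-aligned against the trailing H x W dimensions and raises a similar size-mismatch `RuntimeError`, whereas reshaping it to (B, 1, 1, 1) broadcasts one scalar per sample. `scale_pyramid` is a hypothetical helper, not the repository's actual fix.

```python
import torch


def scale_pyramid(pyramid, scalar):
    """Multiply every pyramid level by a per-sample scalar.

    Hypothetical helper: assumes each level is laid out as B x C x H x W
    and that `scalar` holds one value per sample, e.g. shape (B, 1).
    """
    batch = scalar.shape[0]
    # Reshape (B, 1) -> (B, 1, 1, 1) so it broadcasts over C, H and W
    # instead of being right-aligned against the trailing H x W dims.
    scalar = scalar.reshape(batch, 1, 1, 1)
    return [image * scalar for image in pyramid]


if __name__ == "__main__":
    pyramid = [torch.ones(2, 3, 8, 16), torch.ones(2, 3, 4, 8)]
    dt = torch.tensor([[0.25], [0.75]])

    # Naive broadcasting of a (2, 1) tensor against (2, 3, 8, 16) fails,
    # because 8 (H) and 2 (B) cannot be broadcast together.
    try:
        pyramid[0] * dt
    except RuntimeError as err:
        print("naive multiply failed:", err)

    scaled = scale_pyramid(pyramid, dt)
    print([tuple(level.shape) for level in scaled])
```

Reshaping is cheap here because it only creates a view; the multiplication itself is unchanged, so the same pattern works for any batch size, including 1.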

The easy workaround is to split the batch and run each sample individually, but I assume batched inference would be somewhat faster (not that this model seems particularly heavy).
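The split-and-run workaround mentioned above can be sketched as follows for releases that predate the fix. `interpolate_per_sample` is a hypothetical helper; it assumes the model takes `(img1, img3, dt)` and returns a single image tensor, as in the repro. The `fake_model` below is only a stand-in (a linear blend) so the sketch runs without the `film_net_fp16.pt` export.

```python
import torch


def interpolate_per_sample(model, img1, img3, dt):
    """Run a batch through `model` one sample at a time.

    Hypothetical workaround helper: `model` is any callable with the
    (img1, img3, dt) signature used in the repro that returns a tensor.
    """
    outputs = []
    with torch.no_grad():
        for i in range(img1.shape[0]):
            # Slicing with i:i + 1 keeps the batch dimension, so each call
            # sees shapes (1, C, H, W) and (1, 1), which already worked.
            outputs.append(model(img1[i:i + 1], img3[i:i + 1], dt[i:i + 1]))
    # Stitch the per-sample results back into a single batch.
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    # Stand-in for the real TorchScript model: linear blend between frames.
    def fake_model(a, b, t):
        return a + (b - a) * t.view(-1, 1, 1, 1)

    img1 = torch.zeros(2, 3, 4, 4)
    img3 = torch.ones(2, 3, 4, 4)
    dt = img1.new_full((2, 1), 0.5)
    print(interpolate_per_sample(fake_model, img1, img3, dt).shape)
```

Slicing rather than indexing (`img1[i]`) is deliberate: it preserves the leading batch dimension the model expects, at the cost of one forward pass per sample.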

dajes commented

Hi @JCBrouwer. This is fixed as of release 1.0.2.