serre-lab/Warped

grid_sample breaks for batched images


Error:
RuntimeError: grid_sampler(): expected grid and input to have same batch size, but got input with sizes [128, 3, 224, 224] and grid with sizes [1, 224, 224, 2]

Going by the error message, grid_sample does not broadcast the grid across the batch dimension: the grid's batch size (1) has to match the input's batch size (128). As a hacky workaround, I repeat the grid along the batch dimension for now. Opening this issue to check whether there is a better way to do it.

Fix:

import torch.nn.functional as F

grid = grid.to(img.device)
grid = grid.repeat(img.shape[0], 1, 1, 1)  # one copy of the grid per image in the batch
warped_img = F.grid_sample(img, grid, align_corners=True, mode='bilinear', padding_mode='zeros')
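
A possible alternative to `repeat`, as a minimal sketch rather than the repo's actual fix: `Tensor.expand` returns a view over the batch dimension instead of copying the grid once per image. The helper name `warp_batch` and the identity grid built with `F.affine_grid` are assumptions made up for this example. It also assumes `grid` has shape (1, H, W, 2), as in the error above.

```python
import torch
import torch.nn.functional as F


def warp_batch(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Apply one sampling grid of shape (1, H, W, 2) to every image in an (N, C, H_in, W_in) batch.

    Hypothetical helper for illustration only.
    """
    grid = grid.to(device=img.device, dtype=img.dtype)
    # expand() creates a view with stride 0 along the batch dimension,
    # so the grid data is not duplicated the way repeat() duplicates it.
    grid = grid.expand(img.shape[0], -1, -1, -1)
    return F.grid_sample(
        img, grid, mode='bilinear', padding_mode='zeros', align_corners=True
    )


# Usage: an identity grid should return the batch unchanged (up to float error).
img = torch.rand(4, 3, 8, 8)
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine transform
identity_grid = F.affine_grid(theta, size=(1, 3, 8, 8), align_corners=True)
warped = warp_batch(img, identity_grid)
```

Using `img.shape[0]` instead of a separate `batch_size` variable also keeps the grid in sync with a smaller final batch from a data loader.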