
Training with different batch size

eayumi opened this issue · 10 comments

eayumi commented

I noticed that training with --batch_size other than 1 does not work, among other reasons because of the assert B==1 in cotracker/models/core/cotracker/cotracker.py

Why is that so?
Can't I train with, say, batch_size 16? And how would I do that?

Hi @eayumi, the current version of the code supports only a batch size of 1, which was enough for training on a 32GB GPU with 256 trajectories. There's some logic with adding points that are tracked after the first sliding window that's not going to work with bigger batch sizes yet.
For now, if you have more GPU memory, you can increase traj_per_sample instead. We are working on the next version of the model and might add support for bigger batch sizes later.


If the number of points across different videos is fixed,
such as using 8 points for all videos, with the number fixed throughout each video,
can we perform batch inference on multiple different videos within a single GPU?

Hi @sfchen94, we currently initialize a point token only at the sliding window where it appears first. So, the number of tokens for each sliding window will be different if the batch size is bigger than 1. We will fix this in the next version of CoTracker that we plan to release in late November. It will support training and inference with different batch sizes.

Got it. Keep up the excellent work!

Could you briefly explain which part of the model has a logical issue if we use a batch size greater than 1?
I could not pinpoint exactly why adding points after the first sliding window is a logical fault.

Hi @zetaSaahil, we currently have a different number of tokens for every sample, so these samples can't be processed by the transformer in a batched way. This will be fixed soon with a more elegant solution.
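To make the token-count mismatch concrete, here is a minimal pure-Python sketch. It is not CoTracker's code: the helper `active_tokens_per_window`, the window length, and the stride are hypothetical, and it assumes a point gets a token in every window that ends after its query frame.

```python
# Hypothetical illustration, not taken from the CoTracker codebase.
def active_tokens_per_window(query_frames, num_frames, window_len, stride):
    """Count, for each sliding window, the points whose query frame
    lies before the (exclusive) end of that window."""
    counts = []
    for start in range(0, max(num_frames - window_len, 0) + 1, stride):
        end = start + window_len  # exclusive end of the window
        counts.append(sum(1 for t in query_frames if t < end))
    return counts

# Query frame of each tracked point in two different videos (made-up values).
video_a = [0, 0, 0, 10]
video_b = [0, 5, 12, 14]

counts_a = active_tokens_per_window(video_a, num_frames=16, window_len=8, stride=4)
counts_b = active_tokens_per_window(video_b, num_frames=16, window_len=8, stride=4)

print(counts_a)  # tokens per window for video A
print(counts_b)  # tokens per window for video B
```

Under these assumptions the two samples need a different number of tokens in the same window, so they cannot be stacked into one rectangular tensor without padding or masking.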

The problem is now fixed. The updated codebase supports varying batch sizes for both training and inference.


It seems that to allow multi-batch inference, this

        if add_support_grid:
            grid_pts = get_points_on_a_grid(
                self.support_grid_size, self.interp_shape, device=video.device
            )
            grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
            queries = torch.cat([queries, grid_pts], dim=1)

should be changed to

        if add_support_grid:
            grid_pts = get_points_on_a_grid(
                self.support_grid_size, self.interp_shape, device=video.device
            )
            grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
            # Repeat the support grid along the batch dimension so it matches queries.
            grid_pts = grid_pts.repeat(B, 1, 1)
            queries = torch.cat([queries, grid_pts], dim=1)

otherwise, the concatenation operation fails due to the non-matching first dimension.
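A small standalone check of the shape mismatch, as a sketch only. The random `grid_pts` tensor stands in for the output of get_points_on_a_grid, which is assumed here to have shape (1, M, 2); the sizes B, N, M are illustrative.

```python
import torch

B, N, M = 2, 8, 4  # batch size, user queries, support-grid points (illustrative)
queries = torch.rand(B, N, 3)   # (t, x, y) per query
grid_pts = torch.rand(1, M, 2)  # stand-in for get_points_on_a_grid, assumed (1, M, 2)

# Prepend a zero time coordinate, as in the snippet above: (1, M, 3).
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)

# torch.cat needs all non-concatenated dimensions to match, so B=2 vs 1 fails.
try:
    torch.cat([queries, grid_pts], dim=1)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# Repeating along the batch dimension makes the shapes compatible.
grid_pts = grid_pts.repeat(B, 1, 1)              # (B, M, 3)
merged = torch.cat([queries, grid_pts], dim=1)   # (B, N + M, 3)
print(cat_failed, merged.shape)
```

`grid_pts.expand(B, -1, -1)` should work here as well and avoids copying the grid before the concatenation.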

And this

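        mask = (arange < queries[:, None, :, 0].new_tensor(0) + queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) if False else (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)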

should be changed to

        mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)

for a similar reason.
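The broadcasting difference can be checked in isolation. This sketch assumes `arange` has shape (1, T, 1), which the thread does not show, and uses made-up query frames.

```python
import torch

B, T, N = 2, 5, 3  # batch size, frames, queries (illustrative)
queries = torch.zeros(B, N, 3)
# First channel holds the query frame of each point (made-up values).
queries[..., 0] = torch.tensor([[0.0, 2.0, 4.0], [1.0, 1.0, 3.0]])

arange = torch.arange(T)[None, :, None]  # assumed shape (1, T, 1)

# Old indexing: queries[None, :, :, 0] is (1, B, N), so T is broadcast
# against B and the comparison fails whenever B > 1 and B != T.
try:
    _ = arange < queries[None, :, :, 0]
    old_failed = False
except RuntimeError:
    old_failed = True

# New indexing: queries[:, None, :, 0] is (B, 1, N), giving (B, T, N).
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
print(old_failed, mask.shape)
```

With B == 1 both indexings give the same (1, T, N) result, which would explain why the original line went unnoticed.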

Hi @16lemoing, thank you for pointing this out!
Fixed it: f084a93

Thanks a lot!