uzh-rpg/RVT

TBPTT


Hi Magehrig, could you tell me where the TBPTT dataloader is in your code? I couldn't find it.

This function creates the 'streaming' datasets.

If you look further into the code, you can find the DataModule that is used by PyTorch Lightning to handle the data loading. The DataModule that I wrote can return two dataloaders during training: one returns random samples (for BPTT) and the other streams sequences (for TBPTT). For evaluation, we only use the streaming dataloader.
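To make the difference between the two training dataloaders concrete, here is a minimal, framework-free sketch of the two sampling modes. It is not RVT's implementation. The function names `random_windows` and `streaming_chunks`, the fixed `window` length, and the representation of a sequence as a plain list of frames are illustrative assumptions. In random mode, each window is drawn independently and is meant to start from a fresh recurrent state (plain BPTT over the window). In streaming mode, each batch slot walks through whole sequences in order and flags when it switches to a new sequence, so the recurrent state can be reset there and carried over otherwise (TBPTT).

```python
import random
from typing import Iterator, List, Sequence, Tuple

Frame = int  # hypothetical stand-in for one input/label pair


def random_windows(
    sequences: Sequence[Sequence[Frame]],
    window: int,
    batch_size: int,
    rng: random.Random,
) -> List[List[Frame]]:
    """Random mode: sample independent windows of `window` consecutive frames.

    Each window is meant to be processed on its own, starting from a fresh
    recurrent state, so gradients flow through the whole window (plain BPTT).
    """
    eligible = [s for s in sequences if len(s) >= window]
    batch = []
    for _ in range(batch_size):
        seq = rng.choice(eligible)
        start = rng.randrange(len(seq) - window + 1)
        batch.append(list(seq[start:start + window]))
    return batch


def streaming_chunks(
    sequences: Sequence[Sequence[Frame]],
    window: int,
    batch_size: int,
) -> Iterator[List[Tuple[List[Frame], bool]]]:
    """Streaming mode: each batch slot walks through whole sequences in order.

    Every step yields, per slot, the next chunk of up to `window` frames and a
    flag that is True when the chunk is the first of a new sequence, i.e. the
    recurrent state for that slot should be reset before processing it.
    Otherwise the (detached) state is carried over, which is what TBPTT uses.
    """
    # Distribute sequences round-robin over the batch slots (an assumption;
    # a real loader might balance slots by total sequence length instead).
    slots = [list(sequences[i::batch_size]) for i in range(batch_size)]
    # Split each slot's sequences into consecutive (chunk, is_first) items.
    streams = []
    for slot in slots:
        items = []
        for seq in slot:
            for k, start in enumerate(range(0, len(seq), window)):
                items.append((list(seq[start:start + window]), k == 0))
        streams.append(items)
    # Stop when the shortest stream runs out so every batch is full
    # (this sketch simply drops the leftover chunks of longer streams).
    n_steps = min(len(items) for items in streams)
    for step in range(n_steps):
        yield [items[step] for items in streams]
```

For example, with three sequences, `window=2` and `batch_size=2`, the first streamed batch contains the first chunk of two different sequences (both flagged `True`), and the second batch continues those same sequences (flagged `False`), so a recurrent model can keep its state between the two steps.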
Finally, in the training loop, we merge the batches from both dataloaders here.
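One simple way to merge the two batches is to concatenate them sample-wise and record where each sample came from. The sketch below assumes a hypothetical batch layout (a dict whose values are lists with one entry per sample) and a hypothetical `from_stream` marker; it is not how RVT stores its batches. With real tensors, the concatenation would happen along the batch dimension instead.

```python
from typing import Any, Dict, List

Batch = Dict[str, List[Any]]  # hypothetical layout: key -> per-sample list


def merge_batches(random_batch: Batch, stream_batch: Batch) -> Batch:
    """Concatenate two batches sample-wise and tag each sample's origin.

    Both batches must share the same keys. The added 'from_stream' list marks
    streamed samples, so a training step could reset the recurrent state for
    random samples and carry it over only for streamed ones.
    """
    if random_batch.keys() != stream_batch.keys():
        raise ValueError("batches must have the same keys")
    merged: Batch = {k: random_batch[k] + stream_batch[k] for k in random_batch}
    # Infer the number of samples in each batch from any of its lists.
    n_rand = len(next(iter(random_batch.values()), []))
    n_stream = len(next(iter(stream_batch.values()), []))
    merged["from_stream"] = [False] * n_rand + [True] * n_stream
    return merged
```

Keeping an explicit origin marker, rather than relying on sample order alone, makes it harder to accidentally carry recurrent state into a randomly sampled window.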

Let me know if I answered your question.

Thanks a lot for your reply! I'll close the issue.