claudiom4sir/StableVSR

Question about the code

stillbetter opened this issue · 5 comments

Hi! Thanks for the code. I'm still confused about one point.
I'm curious which specific operation implements Bidirectional Sampling. Is it the random reverse, or lines 928-934 in train.py? I also can't figure out what lines 928-934 do.

Hi, the bidirectional sampling is implemented at inference time, so you can find the operations in stablevsr_pipeline.py.
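As a conceptual sketch only (assumed behavior, not the actual `stablevsr_pipeline.py` code; the helper name `sampling_order` is hypothetical): bidirectional sampling alternates the temporal processing order across denoising steps, so information propagates forward in time on some steps and backward on others.

```python
# Conceptual sketch, NOT the actual stablevsr_pipeline.py implementation:
# alternate the temporal order of the frames across denoising steps so
# information propagates in both directions over the course of sampling.
def sampling_order(frame_indices, step):
    """Forward order on even denoising steps, reverse order on odd ones."""
    return list(frame_indices) if step % 2 == 0 else list(reversed(frame_indices))

frames = [0, 1, 2, 3, 4]
print(sampling_order(frames, 0))  # [0, 1, 2, 3, 4]
print(sampling_order(frames, 1))  # [4, 3, 2, 1, 0]
```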

In the train.py script,

random_t = [round(random.random()) * 2 for _ in range(b)] # <- decide t-1 or t+1 (index 0 or 2 of the 3-frame segment)
gt_prev = torch.stack([gt[i, t] for i, t in enumerate(random_t)]) # per-sample conditioning frame
upscaled_lq_prev = torch.stack([upscaled_lq[i, t] for i, t in enumerate(random_t)])
lq_prev = torch.stack([lq[i, t] for i, t in enumerate(random_t)])
gt = gt[:, t // 2, ...] # the target is the middle frame (t = number of frames in the segment)
lq = lq[:, t // 2, ...]
upscaled_lq_cur = upscaled_lq[:, t // 2, ...]

These operations randomly select the previous or the next frame for conditioning at training time. A sequence of three frames is considered and the target is always the middle one, so one of the two adjacent frames is randomly selected for conditioning. This simply prepares the network to receive the frame sequence in either forward or reverse order at inference time.
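As a quick sanity check, `round(random.random()) * 2` can only produce 0 or 2, so the conditioning frame is always the first or the last index of the 3-frame segment, never the middle target at index 1:

```python
import random

# round(random.random()) maps [0, 1) to {0, 1}; multiplying by 2 gives
# {0, 2}: the first or last frame of a 3-frame segment, never index 1.
samples = [round(random.random()) * 2 for _ in range(1000)]
assert set(samples) <= {0, 2}
assert 1 not in samples
```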

Great, I got it. I'll treat the random selection and the random reverse as an augmentation that mimics bidirectional sampling when the segment has just 3 frames.

Yes, it can be interpreted that way too. You can also increase the number of frames in the segment, to 5 for example. Just edit the REDS configuration file and use an odd number of frames.

But using 5 frames seems useless? Because this code only picks the t-1 or t+1 frame as the previous frame:

random_t = [round(random.random()) * 2 for _ in range(b)] # <- decide t-1 or t+1
gt_prev = torch.stack([gt[i, t] for i, t in enumerate(random_t)])
upscaled_lq_prev = torch.stack([upscaled_lq[i, t] for i, t in enumerate(random_t)])
lq_prev = torch.stack([lq[i, t] for i, t in enumerate(random_t)])

Hi @GraceZhuuu,
oh, you are right, sorry.
This code should fix the problem:
random_t = random.choices([i for i in range(t) if i != t // 2], k=b) # any frame index except the middle (target) one
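With a 5-frame segment (indices 0 through 4, middle index 2), the corrected line draws each per-sample conditioning index from all non-target frames. A quick check, using hypothetical values `t = 5` and `b = 4` for illustration:

```python
import random

# Hypothetical values for illustration: t frames per segment, batch size b.
t, b = 5, 4

# The corrected sampling: any frame index except the middle (target) one.
candidates = [i for i in range(t) if i != t // 2]
random_t = random.choices(candidates, k=b)

print(candidates)  # [0, 1, 3, 4]
```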