sniklaus/softmax-splatting

Having difficulty training the model

Closed this issue · 3 comments

Hi,

Great work and thanks for providing the implementation!

I have implemented the network as described in the paper and in several issues, but my training is not working: the loss is not decreasing and validation performance is not improving. I have tried various learning rates, but nothing worked. I am having difficulty troubleshooting, so I am wondering if you could quickly check my implementation at a high level. Below is the pseudo-code of my implementation, where *_ds is short for "*_downsampled". Thank you very much in advance for your help!

Input: frame0, frame2
Output: frame1
Initialise: flow_estimator = PWCNet()
               feature_extractor = CNNDescribedinPaper()
               synthesis = GridNet()
               compute_metric = Metric()
               softsplat = ModuleSoftsplat(strType='softmax')

1. normedframe0, normedframe2 = InstanceNorm(frame0, frame2)
### splatting 0->2
2. tenFlow_ds4 = flow_estimator(frame0, frame2)
3. tenFlow_ds2 = 2 * upsample(tenFlow_ds4)
4. tenFlow = 2 * upsample(tenFlow_ds2)
5. metric = compute_metric(frame0, frame2, tenFlow)
6. metric_ds2 = downsample(metric)
7. metric_ds4 = downsample(metric_ds2)
8. feature0, feature0_ds2, feature0_ds4 = feature_extractor(normedframe0)
9. normedframe0_warped = softsplat(normedframe0, 0.5*tenFlow, metric)
10. feature0_warped = softsplat(feature0, 0.5*tenFlow, metric)
11. feature0_ds2_warped = softsplat(feature0_ds2, 0.5*tenFlow_ds2, metric_ds2)
12. feature0_ds4_warped = softsplat(feature0_ds4, 0.5*tenFlow_ds4, metric_ds4)
### splatting 2->0
13. tenFlow_ds4 = flow_estimator(frame2, frame0)
14. tenFlow_ds2 = 2 * upsample(tenFlow_ds4)
15. tenFlow = 2 * upsample(tenFlow_ds2)
16. metric = compute_metric(frame2, frame0, tenFlow)
17. metric_ds2 = downsample(metric)
18. metric_ds4 = downsample(metric_ds2)
19. feature2, feature2_ds2, feature2_ds4 = feature_extractor(normedframe2)
20. normedframe2_warped = softsplat(normedframe2, 0.5*tenFlow, metric)
21. feature2_warped = softsplat(feature2, 0.5*tenFlow, metric)
22. feature2_ds2_warped = softsplat(feature2_ds2, 0.5*tenFlow_ds2, metric_ds2)
23. feature2_ds4_warped = softsplat(feature2_ds4, 0.5*tenFlow_ds4, metric_ds4)
### synthesis
24. frame1 = synthesis(
                cat(normedframe0_warped, feature0_warped, normedframe2_warped, feature2_warped),
                cat(feature0_ds2_warped, feature2_ds2_warped),
                cat(feature0_ds4_warped, feature2_ds4_warped)
             )
25. frame1 = reverseInstanceNorm(frame1)
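For clarity, steps 3-4 and 6-7 correspond to something like the sketch below (bilinear interpolation and average pooling are assumptions on my side; the key point is the factor of 2 that rescales the displacement values whenever the flow resolution doubles):

```python
import torch
import torch.nn.functional as F

def upsample_flow(tenFlow):
    # Doubling the spatial resolution doubles every pixel displacement,
    # so the flow values are scaled by 2 alongside the resize.
    return 2.0 * F.interpolate(tenFlow, scale_factor=2, mode='bilinear', align_corners=False)

def downsample_metric(tenMetric):
    # The metric is a per-pixel scalar map, so plain pooling suffices.
    return F.avg_pool2d(tenMetric, kernel_size=2)
```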

PS: some more about my settings: PWCNet obtained from your repo and frozen; compute_metric taken from here; ADAMax optimizer with lr=1e-4; loss is Lap loss with REDUCE and EXPAND factors as you described here.

Problem solved - my implementation of Lap loss was flawed. I forgot to append to my laplacian pyramid the last blurred version of the input image, and as a result my pyramid only contained the differences between levels of Gaussian blurring.
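For reference, a correct pyramid construction might look like the sketch below (average pooling stands in for the Gaussian blur here, and the exact REDUCE/EXPAND kernels from the paper are not reproduced); the crucial bit is the final append of the coarsest Gaussian level:

```python
import torch
import torch.nn.functional as F

def laplacian_pyramid(img, num_levels=5):
    # Build a Gaussian pyramid by repeated 2x downsampling (average
    # pooling stands in for a proper Gaussian blur in this sketch).
    gaussian = [img]
    for _ in range(num_levels - 1):
        gaussian.append(F.avg_pool2d(gaussian[-1], kernel_size=2))

    # Each Laplacian level is the difference between a Gaussian level
    # and the upsampled next-coarser level.
    pyramid = []
    for fine, coarse in zip(gaussian[:-1], gaussian[1:]):
        up = F.interpolate(coarse, size=fine.shape[-2:], mode='bilinear', align_corners=False)
        pyramid.append(fine - up)

    # The step I had forgotten: append the coarsest Gaussian level
    # itself, so the pyramid is a complete, invertible representation
    # of the image rather than only the between-level differences.
    pyramid.append(gaussian[-1])
    return pyramid
```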

Now the network is learning as expected, but it would still be great if you could confirm the algorithm above is correct. Many thanks for your time and help!

I am happy you were able to find the issue in your loss function! I only glanced at it but didn't see anything obviously wrong. Some thoughts though:

  1. You call flow_estimator(...) twice, and you could be more efficient there: each call to PWC-Net creates feature pyramids of its inputs, so you are doing that feature pyramid extraction twice.
  2. You can merge lines 9 and 10 by concatenating normedframe0 and feature0 before calling softsplat; that also simplifies line 24 a little (same for lines 20 and 21).
  3. Make sure to keep PWC-Net fixed (not trained end-to-end) until the synthesis network has converged, then fine-tune the entire pipeline end-to-end. You can call .detach() on the estimated flow to avoid supervising PWC-Net (it stops backpropagation at that point).
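Points 2 and 3 can be sketched as follows; the function name is just a placeholder, and softsplat is passed in as a callable with the (input, flow, metric) signature used in the pseudo-code above:

```python
import torch

def splat_frame_and_features(normedframe0, feature0, tenFlow, metric, softsplat):
    # Point 3: detach the flow so no gradients reach PWC-Net while
    # only the synthesis network is being trained.
    tenFlow = tenFlow.detach()

    # Point 2: splat the frame and its features in a single call by
    # concatenating along the channel dimension, then split again.
    tenIn = torch.cat([normedframe0, feature0], dim=1)
    tenOut = softsplat(tenIn, 0.5 * tenFlow, metric)
    num = normedframe0.shape[1]
    return tenOut[:, :num], tenOut[:, num:]
```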

Thank you very much for your reply and advice - it is very helpful :)