improve network performance
bodokaiser opened this issue · 13 comments
These are the results of the current segnet implementation:
Since they are far from the paper's results, there must be a systematic error in our implementation. Here I want to collect ideas for what to improve.
- ignore background class in loss
- ignore border class in loss
- apply CRF to final result
- fine-tune the learning rate (extra lr for pretrained layers)
- zero initialize model layers (in the first 4 epochs they converge to zero so why not init them that way)
- use batch normalized vgg16 for segnet as proposed in the paper
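The "extra lr for pretrained layers" idea can be sketched with optimizer parameter groups. This is only a minimal illustration: `TinySegNet`, its `encoder`/`decoder` attribute names and the learning-rate values are hypothetical placeholders, not the code or settings from this repo.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in: `encoder` plays the role of the pretrained VGG16 layers."""

    def __init__(self, num_classes=21):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Conv2d(8, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = TinySegNet()

# One parameter group per part of the network. A group without its own
# "lr" entry falls back to the default passed to the optimizer.
optimizer = torch.optim.SGD(
    [
        {"params": model.encoder.parameters(), "lr": 1e-4},  # pretrained: small lr (assumed value)
        {"params": model.decoder.parameters()},              # fresh layers: default lr
    ],
    lr=1e-2,       # assumed default, tune for the real model
    momentum=0.9,
)
```

Whether a smaller encoder learning rate actually helps here would have to be checked experimentally.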
Hi, I just started watching your repo, as I'm keen on using it myself.
re: ignoring the background class, you can weight all the other classes except the background one via your loss, as done here: https://github.com/delta-onera/segnet_pytorch/blob/master/train.py#L77
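A minimal sketch of that idea in PyTorch, not copied from the linked file. It assumes Pascal VOC style labels (21 classes, background at index 0, border pixels labelled 255); adjust these to the actual dataset. The background gets a zero weight, and the border label is skipped via `ignore_index`.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 21   # assumption: 20 object classes + background
BACKGROUND = 0     # assumption: background class index
BORDER = 255       # assumption: label value of border pixels

# Weight 1 for every class except the background, which gets 0.
class_weights = torch.ones(NUM_CLASSES)
class_weights[BACKGROUND] = 0.0

criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=BORDER)

torch.manual_seed(0)
logits = torch.randn(2, NUM_CLASSES, 4, 4)          # (N, C, H, W) network output
target = torch.randint(1, NUM_CLASSES, (2, 4, 4))   # (N, H, W) class indices
target[:, 0, :] = BACKGROUND                        # first row: background
target[:, 1, :] = BORDER                            # second row: border

# Background and border pixels do not contribute to this value.
loss = criterion(logits, target)
```

Note that a batch containing only background and border pixels has zero total weight, so the mean-reduced loss is undefined for it.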
Hi, I added loss weights and am now retraining SegNet on AWS (this time for more than 31 epochs). I will post the results when they are ready.
Do you have any further suggestions on how to increase performance?
Unfortunately I can't post FCN results as I can't afford to train FCN8 for more than one day on AWS. To get PSPNet running I need to apply "atrous convolution" to ResNet101 which I don't totally understand at the moment. So for now I am more or less limited to SegNet.
Hm, when looking at your segnet implementation, I noticed that you aren't saving the maxpool indices but are instead concatenating the earlier encoding activations. Here's an example of using the maxpool indices: https://github.com/delta-onera/segnet_pytorch/blob/master/segnet.py#L57 (you'll have to modify the way you define the encoder layers rather than just rolling out the vgg layers).
Also, I'd recommend training for longer. Before you renamed segnet2, I got the following results by training for 100 epochs using adam (though I didn't notice much difference compared to using SGD) with a batch size of around 40 over two GPUs. My loss was still slowly but steadily decreasing at 100 epochs.
Wow the couch segmentation looks really good!
As far as I can see, the "fuse" problem can be solved with one of:
- Deconvolution (Transposed Convolution)
- Upsampling
- Unpooling
Have you run any more experiments with SegNet, and can you recommend one of them?
I tried to stay as close to the caffe reference (visualization) as possible and I think they just use Upsampling (+ Concat).
For unpooling, they don't use a concatenation, they use a "pool mask" (aka indices from the corresponding encoding pool layer [https://github.com/alexgkendall/SegNet-Tutorial/blob/master/Models/segnet_train.prototxt#L1443]). It does a similar thing to concatenating but uses a lot less memory (not sure how it affects performance).
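For reference, a minimal sketch of the pooling-indices mechanism using PyTorch's `nn.MaxPool2d(return_indices=True)` and `nn.MaxUnpool2d`. The tensor sizes are arbitrary, and a real decoder would follow the unpooling with convolutions.

```python
import torch
import torch.nn as nn

# Encoder side: keep the argmax locations along with the pooled values.
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
# Decoder side: put each value back at its recorded location, zeros elsewhere.
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)          # stand-in for an encoder feature map
pooled, indices = pool(x)            # both have shape (1, 3, 4, 4)
restored = unpool(pooled, indices, output_size=x.size())  # shape (1, 3, 8, 8)
```

Only `indices` has to be kept for the decoder, instead of the full encoder feature map that a concatenation needs.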
I'll try to implement a version with batch normalization and pooling indices and share it with you and let you know how it goes.
Ah good to know! Thank you very much.
I quickly added batch normalization and max pooling indices here; it's definitely still WIP code (https://github.com/ruthcfong/piwise/blob/master/piwise/network.py#L298). I haven't been able to get significantly better performance. Using just batch normalization seems to give similar results [see SegNet], and bn + max pooling indices seems to do worse than concatenating [see SegNet2].
So you already ran ~100 epochs of training on both variants?
Do you have more ideas on how to improve SegNet performance? Maybe apply smaller ("soft") learning rates for the pretrained layers?
Is it correct to pass the indices of the prior layer and not the indices of the "prior-prior" layer?
Yes, I ran 100+ epochs on both (without indices seems to work better).
I'm not sure what you mean by "prior-prior", but I think I'm passing the right indices to match the diagram.
From looking at the training section, a few ideas to try:
- Local Contrast Normalization for the input
- Using Median Frequency Balancing (https://arxiv.org/pdf/1411.4734v4.pdf) to determine class weights for loss: weight(c) = median_freq / freq(c)
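A rough sketch of how the median frequency weights could be computed from the training label maps. It assumes, as I read the linked paper, that freq(c) is the number of pixels of class c divided by the total number of pixels in the images where c occurs. The function name, the `ignore_label=255` default and the toy label maps are made up for illustration.

```python
import torch


def median_frequency_weights(label_maps, num_classes, ignore_label=255):
    """weight(c) = median_freq / freq(c); classes that never occur get weight 0."""
    pixel_count = torch.zeros(num_classes, dtype=torch.float64)   # pixels of class c
    image_pixels = torch.zeros(num_classes, dtype=torch.float64)  # pixels of images containing c
    for labels in label_maps:
        valid = labels[labels != ignore_label]  # 1-D tensor of the non-ignored labels
        counts = torch.bincount(valid, minlength=num_classes).double()
        pixel_count += counts
        image_pixels += (counts > 0).double() * valid.numel()

    present = pixel_count > 0
    freq = pixel_count[present] / image_pixels[present]
    weights = torch.zeros(num_classes, dtype=torch.float64)
    # torch.median returns the lower middle value for an even number of classes.
    weights[present] = freq.median() / freq
    return weights.float()


# Toy label maps (integer class indices) just to show the call.
toy_labels = [
    torch.tensor([[0, 0], [0, 1]]),
    torch.tensor([[0, 0], [2, 2]]),
]
weights = median_frequency_weights(toy_labels, num_classes=3)
```

The resulting tensor can be passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.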
It might also be valuable to try training on the datasets they use in the original paper (CamVid and SUN RGB-D).
"We evaluate the performance of SegNet on two scene segmentation
tasks, CamVid road scene segmentation [22] and SUN
RGB-D indoor scene segmentation [23]. Pascal VOC12 [21] has
been the benchmark challenge for segmentation over the years.
However, the majority of this task has one or two foreground
classes surrounded by a highly varied background. This implicitly
favours techniques used for detection as shown by the recent
work on a decoupled classification-segmentation network [18]
where the classification network can be trained with a large set of
weakly labelled data and the independent segmentation network
performance is improved."
I added you as a contributor. Feel free to update this repo with your suggestions. I won't be able to work on this over the next few weeks :(