wasidennis/AdaptSegNet

Understanding the weights for multi-level adversarial learning

Closed this issue · 5 comments

Hello,
It is great work. I saw in the readme that, for the pytorch-0.4 version, the multi-level learning needs the adversarial-learning weights to be 0.0005 and 0.00005, respectively. I would like to know how you came up with these values. I am asking because I am working on a different dataset, and it would be very helpful to understand how to choose reasonable values for adversarial learning.

For the original version, we set the weight on the first (lower-level) output smaller than the weight used when training only the single-level model. In the pytorch-0.4 version, we have not tuned the weights heavily, but found experimentally that decreasing them improves performance.
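If it helps to see where those numbers would go, here is a minimal sketch of a weighted multi-level generator objective. The names, the 0/1 label convention, and the 0.1 auxiliary segmentation weight are illustrative assumptions, not copied from the repo's training script:

```python
import torch
import torch.nn.functional as F

# Illustrative pytorch-0.4 values from the readme: larger weight on the final
# output, smaller weight on the auxiliary (lower-level) output.
lambda_adv_main = 0.0005    # adversarial weight for the final prediction
lambda_adv_aux = 0.00005    # adversarial weight for the auxiliary prediction
source_label = 0.0          # label the discriminators assign to source-domain inputs

def generator_loss(pred_aux_src, pred_main_src, labels_src,
                   pred_aux_tgt, pred_main_tgt, D_aux, D_main):
    """Supervised segmentation loss on source + weighted adversarial losses on target."""
    # supervised segmentation loss (auxiliary output down-weighted, here by 0.1)
    loss_seg = (F.cross_entropy(pred_main_src, labels_src)
                + 0.1 * F.cross_entropy(pred_aux_src, labels_src))

    # adversarial terms: push target predictions to look "source-like"
    # to each (frozen) discriminator
    d_main = D_main(F.softmax(pred_main_tgt, dim=1))
    d_aux = D_aux(F.softmax(pred_aux_tgt, dim=1))
    loss_adv_main = F.binary_cross_entropy_with_logits(
        d_main, torch.full_like(d_main, source_label))
    loss_adv_aux = F.binary_cross_entropy_with_logits(
        d_aux, torch.full_like(d_aux, source_label))

    return (loss_seg
            + lambda_adv_main * loss_adv_main
            + lambda_adv_aux * loss_adv_aux)
```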

As a general guideline for training on different datasets, in addition to parameter tuning, you can observe the loss behavior so that adversarial learning reaches a balance between the generator (the segmentation network) and the discriminator.

For example, you can first set the weight very small, in which case you may observe that it does nothing. Then you can gradually increase the weight until it becomes too strong and the balance breaks down too quickly (i.e., the best model appears in the early iterations and degrades significantly after that).
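One way to follow that procedure systematically is a coarse sweep, increasing the weight by roughly an order of magnitude each time and noting when the best checkpoint starts landing in the early iterations. The helper below is a hypothetical placeholder for your own training/validation loop:

```python
def train_and_validate(lambda_adv, max_iters):
    """Hypothetical placeholder: run your training loop with this adversarial
    weight and return (best_val_score, iteration_at_which_it_occurred)."""
    raise NotImplementedError

# Start from "too small to matter" and move up by roughly an order of magnitude.
for w in [1e-6, 1e-5, 1e-4, 1e-3]:
    best_score, best_iter = train_and_validate(lambda_adv=w, max_iters=20000)
    print(f"lambda_adv={w:g}  best score={best_score:.3f}  at iter {best_iter}")
    # If the best checkpoint keeps occurring very early and performance
    # degrades afterwards, the weight is probably already too strong.
```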

Wow, that is good guidance. I do find my model achieving its best in the initial iterations and degrading after that, so I will lower my adversarial weights and see. I also find that the discriminator loss does not decrease after a few iterations, and when I use BatchNorm in the discriminator, as in PatchGAN, the discriminator loss reaches zero very quickly. I saw in your other issues that the discriminator loss should gradually decrease and the adversarial loss gradually increase. Any guidance on this would also be very helpful.

Huge thanks.

Yes, if the discriminator loss drops to near zero too quickly, it basically means that adversarial learning is not working (it is easy to separate the source and target domains).

Yes, ideally, the discriminator loss should gradually decrease and the adversarial loss gradually increase. Eventually, they will both fluctuate a bit, which means they keep trying to find the balance.
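To make those two curves concrete, here is a minimal sketch of the discriminator loss you would log alongside the adversarial loss from the generator update. The 0/1 label convention and names are assumptions for the sketch, not the repo's exact code:

```python
import torch
import torch.nn.functional as F

source_label, target_label = 0.0, 1.0   # arbitrary 0/1 convention for the sketch

def discriminator_loss(D, pred_src, pred_tgt):
    """Discriminator loss to log alongside the adversarial (generator) loss.

    pred_src / pred_tgt are segmentation logits from the source and target
    batches; they are detached so this update does not touch the generator.
    """
    bce = F.binary_cross_entropy_with_logits
    d_src = D(F.softmax(pred_src.detach(), dim=1))
    d_tgt = D(F.softmax(pred_tgt.detach(), dim=1))
    loss_D = (bce(d_src, torch.full_like(d_src, source_label))
              + bce(d_tgt, torch.full_like(d_tgt, target_label))) / 2
    return loss_D

# Log both curves each iteration:
#   loss_D collapsing to near zero after a few iterations -> the discriminator
#   separates the domains trivially and the adversarial signal does little;
#   loss_D drifting down while the adversarial loss drifts up, then both
#   fluctuating around a level -> the two networks are holding each other
#   in balance.
```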

Thanks for the info. While experimenting with different weights for the adversarial loss, did you train the model from scratch each time, or did you lower the weights once the segmentation network (generator) performance degraded, while keeping the discriminator loaded with previously saved parameters?

To keep the training scheme simple, we use the same weights throughout the entire training process, as shown in our training script.