takerum/vat_tf

Unstable training on MNIST


Hello.

I have implemented VAT in PyTorch and am running into the following issue.
The code works on the toy example (two moons dataset) and produces the correct decision boundary. On MNIST with 100 labeled samples, however, training becomes unstable and test accuracy oscillates throughout training. With 300 or more labeled samples the problem goes away: training is stable and there is a noticeable improvement over the supervised baseline, as expected.

Do you have any intuition as to why the above happens?

For reference, the network for MNIST consists of fully connected layers with batch normalization and dropout. Removing the batch norm layers and/or the dropout doesn't seem to affect the issue.
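
To make the setup concrete, here is a minimal PyTorch sketch of the kind of network and VAT regularization term described above. It is an illustration only, not the exact code that produced the behaviour in this issue: the class name `MLP`, the helper `l2_normalize`, the layer sizes, the dropout rate, and the hyperparameters `xi`, `eps`, and `n_power` are all assumptions chosen for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Fully connected classifier with batch norm and dropout (sizes are assumed)."""

    def __init__(self, in_dim=784, hidden=256, n_classes=10, p_drop=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden, hidden), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        # Flatten images of shape (B, 1, 28, 28) to vectors of shape (B, 784).
        return self.net(x.flatten(1))


def l2_normalize(d):
    # Scale each sample's perturbation to (approximately) unit L2 norm.
    norm = d.flatten(1).norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + 1e-8)


def vat_loss(model, x, xi=0.1, eps=2.0, n_power=1):
    # Hypothetical hyperparameter defaults; tune them for a real experiment.
    # Predictions on the clean inputs serve as a fixed target (no gradient).
    with torch.no_grad():
        p = F.softmax(model(x), dim=1)

    # Power iteration: estimate the direction that changes the output most.
    d = l2_normalize(torch.randn_like(x))
    for _ in range(n_power):
        d.requires_grad_(True)
        log_q = F.log_softmax(model(x + xi * d), dim=1)
        dist = F.kl_div(log_q, p, reduction="batchmean")
        grad = torch.autograd.grad(dist, d)[0]
        d = l2_normalize(grad.detach())

    # KL divergence between clean and adversarially perturbed predictions.
    log_q = F.log_softmax(model(x + eps * d), dim=1)
    return F.kl_div(log_q, p, reduction="batchmean")


if __name__ == "__main__":
    model = MLP()
    x_unlabeled = torch.randn(32, 1, 28, 28)  # stand-in for an unlabeled MNIST batch
    loss = vat_loss(model, x_unlabeled)
    loss.backward()  # gradients flow into the model parameters
```

In a semi-supervised step this term would typically be added to the cross-entropy loss on the labeled batch. Note that in training mode every forward pass inside `vat_loss` draws a fresh dropout mask and updates the batch norm running statistics, so the clean and perturbed predictions are not computed under identical conditions.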