Unstable training on MNIST
Hello.
I have implemented VAT (virtual adversarial training) in PyTorch and I am facing the following issue:
The code works on the toy example (the two-moons dataset) and produces the expected decision boundary. However, on MNIST with 100 labeled samples, training becomes unstable and the test accuracy oscillates throughout training. Increasing the number of labeled samples to 300 or more mitigates the issue: training becomes stable and there is a noticeable improvement over the supervised baseline, as expected.
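For context, my VAT loss follows the standard one-step power-iteration scheme from Miyato et al. Below is a minimal sketch of that computation; the function names and hyperparameter values (`xi`, `eps`) are illustrative, not my exact code:

```python
import torch
import torch.nn.functional as F

def _l2_normalize(d):
    # Normalize the perturbation per sample over all non-batch dims.
    d_flat = d.view(d.size(0), -1)
    norm = d_flat.norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + 1e-8)

def vat_loss(model, x, xi=1e-6, eps=2.5, n_power=1):
    # Clean predictions are treated as constants (no gradient).
    with torch.no_grad():
        pred = F.softmax(model(x), dim=1)

    # Start the power iteration from a random unit direction.
    d = _l2_normalize(torch.randn_like(x))
    for _ in range(n_power):
        d.requires_grad_()
        pred_hat = F.log_softmax(model(x + xi * d), dim=1)
        adv_dist = F.kl_div(pred_hat, pred, reduction="batchmean")
        d = _l2_normalize(torch.autograd.grad(adv_dist, d)[0].detach())

    # LDS term: KL between clean predictions and predictions at
    # the virtually adversarial point x + eps * d.
    pred_hat = F.log_softmax(model(x + eps * d), dim=1)
    return F.kl_div(pred_hat, pred, reduction="batchmean")
```

(One caveat with a sketch like this: the extra forward passes also update batch-norm running statistics, which some implementations freeze during the adversarial passes.)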
Do you have any intuition as to why the above happens?
For reference, the network for MNIST consists of fully connected layers with batch normalization and dropout (see the sketch below). Removing the batch norm layers and/or the dropout doesn't seem to affect the issue.
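For concreteness, the model is of this general shape; the layer widths and dropout rate here are placeholders, not my exact values:

```python
import torch.nn as nn

# Illustrative fully connected MNIST classifier with batch norm
# and dropout, as described above.
mnist_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1200),
    nn.BatchNorm1d(1200),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1200, 600),
    nn.BatchNorm1d(600),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(600, 10),
)
```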