Simple L2 reconstruction loss?
SilenceMonk opened this issue · 5 comments
Thx for the excellent work! Currently, it seems that we are using some obscure DiscMixLogistic reconstruction loss. Is there any guide on using a simple L2 reconstruction loss instead? Do I need to change the model architecture for that?
Hello @SilenceMonk, thanks for the interest you show in our work!
tl;dr: Using the L2 reconstruction loss to train a VAE is sub-optimal and should be avoided, as it causes the VAE to train in an over-regularized regime, which makes its outputs blurry. If over-regularization is actually desirable (to encode less information in the latent space), then changing the output filters to the number of channels in the image (3 if RGB) and changing the loss function to L2 should be enough for training. During inference, there is no sampling process and the logits are the actual image.
First and foremost, it should be pointed out that a VAE shouldn't be trained with an RGB output layer and an L2 reconstruction loss (MSE). As beautifully explained in section 5.1 (page 12) of this paper, training the VAE with MSE instead of
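When you consider the ELBO of a (standard) VAE:

$$\log p_{\theta}(x) \geq \mathbb{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right] - D_{KL}\left(q_{\phi}(z|x) \,\|\, p(z)\right)$$

maximizing the right-hand side of the inequality (or equivalently minimizing its opposite) decomposes into a reconstruction loss term $-\mathbb{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right]$ and a regularization term $D_{KL}\left(q_{\phi}(z|x) \,\|\, p(z)\right)$.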
To deal with the reconstruction term
The decoder's output layer can model essentially any distribution: Gaussian, logistic, etc. I will not go too deep in this comment into why we tend to use a mixture of discretized logistics when modeling pixels (or audio), but the GitHub answer I shared earlier covers that. The short answer to "why use a mixture of discretized logistics" is:
- Mixture, because a mixture of distributions is more expressive than a single distribution: you can model more complicated distributions with a mixture of Gaussians than with a single Gaussian, for example. In fact, a mixture of Gaussians can be multi-modal while a single Gaussian is unimodal.
- Discretized, because we are trying to predict discrete pixel values in the integer range [0, 255], i.e. values in the range [2.5, 3.5) should all be considered equal to 3, with no penalty between them.
- Logistic, because the logistic distribution has a simple cumulative distribution function (CDF): the sigmoid function. This helps both in making a stable, easy-to-compute training loss and in having a simple inverse-transform sampling formula during inference.
- Why not a classification layer such as a softmax distribution? Because pixel values are ordinal, not categorical: 100 is much closer to 101 than it is to 200, so the order of the "classes" (discrete pixel values) matters. Softmax outputs have been explored in image generation in the past, but they are usually outperformed by mixture-of-logistics outputs.
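To make the three ingredients above concrete, here is a minimal NumPy sketch of the negative log-likelihood of a discretized mixture of logistics for a single channel, together with inverse-CDF sampling. It assumes pixels are rescaled from [0, 255] to [-1, 1] (so every integer value owns a bin of half-width 1/255) and that the decoder predicts mixture logits, means and log-scales of shape (N, K). The function names are hypothetical; this is an illustration of the technique, not the repository's DiscMixLogistic implementation.

```python
import numpy as np


def sigmoid(x):
    # Logistic CDF; the tanh form avoids overflow warnings for large |x|.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def discretized_mix_logistic_nll(x, logit_probs, means, log_scales):
    """NLL of integer pixels x (shape (N,), values in [0, 255]) under a mixture
    of K discretized logistics with parameters of shape (N, K)."""
    # Assumed convention: pixels rescaled to [-1, 1], bin half-width 1/255.
    x = (x.astype(np.float64) / 127.5 - 1.0)[:, None]
    inv_s = np.exp(-log_scales)
    cdf_plus = sigmoid(inv_s * (x - means + 1.0 / 255.0))  # CDF at upper bin edge
    cdf_min = sigmoid(inv_s * (x - means - 1.0 / 255.0))   # CDF at lower bin edge
    # Interior pixels get the logistic mass inside their bin; the edge values
    # 0 and 255 absorb the whole left/right tail so all 256 bins sum to one.
    prob = np.where(x < -0.999, cdf_plus,
                    np.where(x > 0.999, 1.0 - cdf_min, cdf_plus - cdf_min))
    log_prob = np.log(np.maximum(prob, 1e-12))
    # Mixture weights via a log-softmax over the K components.
    log_pi = logit_probs - np.logaddexp.reduce(logit_probs, axis=-1, keepdims=True)
    # -log sum_k pi_k * P_k(x), computed in log space.
    return -np.logaddexp.reduce(log_pi + log_prob, axis=-1)


def sample_discretized_mix_logistic(logit_probs, means, log_scales, rng):
    """Draw one integer pixel per row: pick a component, then invert its CDF."""
    n, k = logit_probs.shape
    rows = np.arange(n)
    # Gumbel-max trick: argmax(logits + Gumbel noise) samples a component.
    gumbel = -np.log(-np.log(rng.uniform(1e-5, 1.0 - 1e-5, size=(n, k))))
    comp = np.argmax(logit_probs + gumbel, axis=-1)
    mu = means[rows, comp]
    s = np.exp(log_scales[rows, comp])
    # Inverse CDF of the logistic distribution: mu + s * logit(u).
    u = rng.uniform(1e-5, 1.0 - 1e-5, size=n)
    x = mu + s * (np.log(u) - np.log1p(-u))
    # Map back from [-1, 1] to the integer range [0, 255].
    return np.clip(np.rint((x + 1.0) * 127.5), 0, 255).astype(np.uint8)
```

Because the bins telescope, evaluating the NLL for all 256 pixel values with fixed parameters and summing `exp(-nll)` should give 1, which is a quick sanity check that the discretization is right.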
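Optimizing the MSE instead of performing an explicit maximum likelihood estimation (MLE) on a model distribution is, however, sub-optimal even when the goal is to maximize the likelihood of the data under a Gaussian distribution. In fact, the reconstruction loss of the ELBO under a decoder modeling a single Gaussian distribution with mean $\mu_{\theta}(z)$ and scale $\sigma$ over $D$ dimensions can be written as:

$$-\log p_{\theta}(x|z) = \frac{1}{2\sigma^{2}}\lVert x - \mu_{\theta}(z)\rVert^{2} + D\log\sigma + \frac{D}{2}\log(2\pi)$$

You can notice that this reconstruction loss is equivalent to the MSE loss under the assumption that the Gaussian scale $\sigma$ is fixed rather than learned: the last two terms are then constants, and $\frac{1}{2\sigma^{2}}$ only rescales the squared error.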
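The link between the Gaussian reconstruction term and the MSE can be checked numerically. The NumPy sketch below (function names are hypothetical) computes $-\log \mathcal{N}(x; \mu, \sigma^2 I)$: for any fixed $\sigma$ it is an affine function of the MSE, and if $\sigma$ is learned instead, its closed-form optimum is $\sigma^2 = \text{MSE}$. Fixing $\sigma$ therefore also fixes how strongly the reconstruction term is weighted against the KL term.

```python
import numpy as np


def gaussian_nll(x, mu, sigma):
    """-log N(x; mu, sigma^2 I), summed over the D entries of x."""
    d = x.size
    sq_err = np.sum((x - mu) ** 2)
    return sq_err / (2.0 * sigma**2) + d * np.log(sigma) + 0.5 * d * np.log(2.0 * np.pi)


def mse(x, mu):
    """Mean squared error over the D entries of x."""
    return np.mean((x - mu) ** 2)


def optimal_sigma(x, mu):
    """Closed-form minimizer of gaussian_nll over sigma: sigma^2 = MSE."""
    return np.sqrt(mse(x, mu))
```

Setting the derivative of `gaussian_nll` with respect to $\sigma$ to zero gives $-\lVert x - \mu\rVert^2/\sigma^3 + D/\sigma = 0$, i.e. $\sigma^2 = \lVert x - \mu\rVert^2 / D$, which is what `optimal_sigma` returns.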
Now, if over-regularization is still desirable (side note: over-regularization can also be achieved with a beta-VAE), then changing the output filters to the number of channels in the image (3 if RGB) and changing the loss function to L2 should be enough for training. The output layer can be a simple linear layer with 3 filters for RGB. During inference, there is no sampling process and the logits are the actual image. It is worth pointing out that the logits in this case are floating-point values: they should be mapped back to the pixel scale (if the inputs were normalized), then discretized and clipped to the integer range [0, 255].
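A minimal sketch of the L2 variant's loss and inference-time post-processing, assuming the targets were rescaled to [-1, 1] before training (that normalization is an assumption; undo whatever scaling your own pipeline uses). The function names are hypothetical.

```python
import numpy as np


def l2_reconstruction_loss(target_scaled, out):
    """MSE between the target image (assumed rescaled to [-1, 1]) and the raw
    3-channel decoder output, both of shape (H, W, 3)."""
    return np.mean((out - target_scaled) ** 2)


def output_to_image(out):
    """Undo the assumed [-1, 1] scaling, round to the nearest integer,
    and clip to valid uint8 pixel values."""
    pixels = np.rint((out + 1.0) * 127.5)
    return np.clip(pixels, 0, 255).astype(np.uint8)
```

The clipping matters because a linear output layer is unbounded: nothing stops it from predicting values slightly outside the training range.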
If the need arises at any time, the output layer and the MoL loss can also be changed to work with any other distribution that fits the need (Gaussian, Cauchy, etc.).
I hope this long comment answers your question and also gives some extra information to help reduce the obscurity around DiscMixLogistic! :)
Let me know if there are any concerns or extra questions about this topic!
Thank you very much for reaching out.
Rayhane.
Wow, I did not expect such a quick, loooooong and detailed guide on MoL loss! I've certainly learnt a lot from it, and I'll definitely check the GitHub answer later on. Great thx again!
Hi there, I have a little follow-up question on the MoL loss: what if I am dealing with discrete data where the order doesn't matter, like words? Would it be an issue to use the MoL loss in this case?
Hello again @SilenceMonk!
When you are dealing with discrete data where the order doesn't matter, you change the model's output distribution from a mixture of (discretized) logistics to a categorical (also called Multinoulli) distribution. Lucky for us, the negative log of the Multinoulli PMF is exactly the cross entropy loss function (derivation below). This means that minimizing the cross entropy loss is equivalent to performing a maximum likelihood estimation on a categorical distribution.
So, when dealing with a purely categorical output (where the order doesn't matter), having a simple softmax output layer and using a cross entropy loss function is sufficient. As far as I am aware, this is the common practice when dealing with categorical data.
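For reference, a numerically stable softmax + cross entropy for integer class targets (e.g. word indices in a vocabulary of size K) can be sketched in NumPy as below; the function names are hypothetical. In a framework you would normally use the built-in loss instead, e.g. PyTorch's `torch.nn.functional.cross_entropy`, which takes unnormalized logits and integer targets.

```python
import numpy as np


def log_softmax(logits):
    """Log-probabilities over the last axis, shifted by the max for stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, targets):
    """Mean of -log p_theta(target | z) over a batch.
    logits: (N, K) unnormalized scores; targets: (N,) integer class ids in [0, K)."""
    log_p = log_softmax(logits)
    return -log_p[np.arange(len(targets)), targets].mean()
```

Working with log-probabilities directly (rather than taking the log of a softmax output) avoids `log(0)` when one logit dominates the others.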
What follows is the derivation of the cross entropy loss from the Multinoulli PMF:
- Assume there is a categorical distribution over K classes that you want to model (K is the vocabulary size for words for example).
- Consider the Multinoulli (categorical) distribution PMF, modeled by the output layer of our model: $p_{\theta}(x|z) = \prod_{i} p_{i}^{x_i}$, where $x_i \in \{0, 1\}$ is the (one-hot) target for class $i$ and $p_{i}$ is the probability the model assigns to class $i$, $i \in \{1, \dots, K\}$, i.e. $p_{i} = p_{\theta}(X = i | z)$ with $X$ being the categorical random variable.
- The probabilities $p_i$ must be strictly positive and satisfy $\sum_{i} p_i = 1$. In practice, this means $p_i$ is pretty much the output of your softmax layer.
- The reconstruction term $-\log p_{\theta}(x|z)$ then computes exactly to the cross entropy loss:
$$-\log p_{\theta}(x|z) = -\log \prod_{i} p_{i}^{x_i} = -\sum_{i} \log p_{i}^{x_i} = -\sum_{i} x_i \log p_{i}$$
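The derivation can be checked numerically with a one-hot target: the negative log of the Multinoulli PMF, the cross entropy $-\sum_i x_i \log p_i$, and $-\log p_{\text{target}}$ all coincide, since every factor with $x_i = 0$ equals 1. A small NumPy sketch with made-up probabilities:

```python
import numpy as np

K = 5
p = np.array([0.1, 0.2, 0.4, 0.2, 0.1])  # made-up model probabilities (e.g. a softmax output)
target = 2
x = np.eye(K)[target]                     # one-hot target: x_i = 1 only for the target class

neg_log_pmf = -np.log(np.prod(p ** x))    # -log prod_i p_i^{x_i}
cross_entropy = -np.sum(x * np.log(p))    # -sum_i x_i log p_i
neg_log_target = -np.log(p[target])       # only the target class survives

print(neg_log_pmf, cross_entropy, neg_log_target)
```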
Hope this answers the question :)
Feel free to reach out with any more related/unrelated questions!
Rayhane.
Wow great thanks again @Rayhane-mamah! I guess I'll go with softmax+cross entropy when dealing with discrete data and see what I'll get.