Jackson-Kang/Pytorch-VAE-tutorial

Question about the distance calculation inside VQEmbeddingEMA

KinWaiCheuk opened this issue · 1 comment

First of all, thank you for making this tutorial. I learned a lot from it. There is one thing I would like to double check.

In cell 5 of your Jupyter notebook, you calculate all pairwise distances between the flattened continuous latents x_flat and the codebook embedding as

torch.addmm(torch.sum(embedding ** 2, dim=1) +
            torch.sum(x_flat ** 2, dim=1),
            x_flat, embedding.t(),
            alpha=-2.0, beta=1)

However, I don't think this is the correct way to calculate the distance between $z_e(x)_i$ and $e_j$ when $i \neq j$: the first argument to torch.addmm is a 1-D tensor, which gets broadcast across the rows of the matrix product, so entry $(i, j)$ of the result is $\lVert e_j \rVert^2 + \lVert z_e(x)_j \rVert^2 - 2 \, z_e(x)_i \cdot e_j$, i.e. it uses $\lVert z_e(x)_j \rVert^2$ where $\lVert z_e(x)_i \rVert^2$ is needed.
Let's look at the following example.

For simplicity, let's assume there are only two entries $e_1$ and $e_2$ in the codebook.
Let's set $e_1=(-0.8567, 1.1006, -1.0712)$ and $e_2=(0.1227, -0.5663, 0.3731)$.
Similarly, let's set $z_e(x)_1 = (0.4033, 0.8380, -0.7193)$ and $z_e(x)_2 = (-0.4033, -0.5966, 0.1820)$.

import torch

embedding = torch.tensor(
    [[-0.8567,  1.1006, -1.0712],
     [ 0.1227, -0.5663,  0.3731]]
    )
x_flat = torch.tensor(
    [[ 0.4033,  0.8380, -0.7193],
     [-0.4033, -0.5966,  0.1820]]
    )

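# the notebook's addmm-based computation (note: the bias term here is 1-D)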
dist_kang = torch.addmm(torch.sum(embedding ** 2, dim=1) +
                        torch.sum(x_flat ** 2, dim=1),
                        x_flat, embedding.t(),
                        alpha=-2.0, beta=1)

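# reference computation: torch.cdist gives pairwise Euclidean distances;
# the leading minus sign is squared away, leaving plain squared distances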
dist_lucidrain = (-torch.cdist(x_flat, embedding, p=2)) ** 2

Your implementation (dist_kang) returns

tensor([[1.7804, 2.4136],
        [5.4872, 0.3141]])

The correct implementation (dist_lucidrain, from lucidrain) returns

tensor([[1.7804, 3.2441],
        [4.6566, 0.3141]])

As you can see, your implementation returns the correct results for the $z_e(x)_i$ and $e_j$ pairs when $i = j$ (the diagonal elements), but the $i \neq j$ entries are all off.
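The reason is how torch.addmm broadcasts its first argument. Below is a minimal sketch, continuing from the snippet above (the names bias and manual are mine): addmm treats the 1-D bias as a row vector, so every row of the result shares the same bias entries. (The 1-D sum only broadcasts at all here because the number of latents happens to equal the number of codebook entries.)

# 1-D bias: entry j is ||e_j||^2 + ||x_j||^2
bias = torch.sum(embedding ** 2, dim=1) + torch.sum(x_flat ** 2, dim=1)

# addmm broadcasts the bias as a row vector across the matrix product,
# so dist_kang[i, j] = ||e_j||^2 + ||x_j||^2 - 2 * x_i . e_j
manual = bias.unsqueeze(0) - 2.0 * (x_flat @ embedding.t())
print(torch.allclose(manual, dist_kang))  # True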

Here is how to verify the correct values for both the diagonal and off-diagonal pairs:

# e1-z1 and e2-z2 distance pairs (the diagonal entries)
torch.sum(embedding**2 - 2*embedding*x_flat + x_flat**2, dim=1)
>>> tensor([1.7804, 0.3141])

# e2-z1 and e1-z2 distance pairs (the off-diagonal entries)
torch.sum(embedding.flip(0)**2 - 2*embedding.flip(0)*x_flat + x_flat**2, dim=1)
>>> tensor([3.2441, 4.6566])

So the correct value for the $e_2$, $z_e(x)_1$ pair should be 3.2441, and for the $e_1$, $z_e(x)_2$ pair it should be 4.6566.
Therefore, torch.cdist should be used to calculate the distances instead of torch.addmm.
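For reference, here is a minimal sketch of the two possible fixes (the function names sq_dist_cdist and sq_dist_addmm are mine, not from the tutorial): either square torch.cdist, or keep torch.addmm but pass keepdim=True so that the $\lVert z_e(x)_i \rVert^2$ term becomes a column vector and broadcasts per row:

import torch

def sq_dist_cdist(x_flat, embedding):
    # pairwise squared Euclidean distances, shape (N, M)
    return torch.cdist(x_flat, embedding, p=2) ** 2

def sq_dist_addmm(x_flat, embedding):
    # keepdim=True turns the x term into an (N, 1) column, so the bias
    # broadcasts to (N, M) with entry (i, j) = ||x_i||^2 + ||e_j||^2
    return torch.addmm(torch.sum(embedding ** 2, dim=1) +
                       torch.sum(x_flat ** 2, dim=1, keepdim=True),
                       x_flat, embedding.t(),
                       alpha=-2.0, beta=1.0)

x_demo = torch.randn(7, 3)      # N latents need not match M codebook entries
emb_demo = torch.randn(5, 3)
print(torch.allclose(sq_dist_cdist(x_demo, emb_demo),
                     sq_dist_addmm(x_demo, emb_demo),
                     atol=1e-5))  # True

Assuming keepdim=True is the only change needed, both forms produce the same squared-distance matrix, so either one yields the same nearest-codebook assignments.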

Hello, @KinWaiCheuk.

Sorry for the late response, and thank you for the kind explanation.
Within a few days, I will fix the tutorial code as you suggested.

Again, thank you for your kindness.