Question about the distance calculation inside VQEmbeddingEMA
KinWaiCheuk opened this issue · 1 comment
First of all, thank you for making this tutorial. I learned a lot from it. There is one thing I would like to double-check.
In cell 5 of your Jupyter notebook, you tried to calculate all the possible distance pairs between the continuous latent space x_flat and the codebook embedding as
torch.addmm(torch.sum(embedding ** 2, dim=1) +
torch.sum(x_flat ** 2, dim=1),
x_flat, embedding.t(),
alpha=-2.0, beta=1)
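(For reference, torch.addmm(input, mat1, mat2, alpha=-2.0, beta=1) computes beta * input + alpha * (mat1 @ mat2). Here the 1-D bias input is broadcast across the rows of the (N, K) product x_flat @ embedding.t(), so entry (i, j) works out to ||e_j||^2 + ||x_j||^2 - 2 * x_i . e_j instead of the intended ||x_i||^2 + ||e_j||^2 - 2 * x_i . e_j; the two only agree on the diagonal i == j.)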
However, I don't think this is the correct way to calculate the distances between x_flat and embedding. Let's look at the following example. For simplicity, let's assume there are only two codebook entries, e1 and e2, and two latent vectors, z1 and z2. Let's set embedding and x_flat as follows:
import torch
embedding = torch.tensor(
[[-0.8567, 1.1006, -1.0712],
[ 0.1227, -0.5663, 0.3731]]
)
x_flat = torch.tensor(
[[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820]]
)
dist_kang = torch.addmm(torch.sum(embedding ** 2, dim=1) +
torch.sum(x_flat ** 2, dim=1),
x_flat, embedding.t(),
alpha=-2.0, beta=1)
dist_lucidrain = (-torch.cdist(x_flat, embedding, p=2)) ** 2
Your implementation (dist_kang) returns
tensor([[1.7804, 2.4136],
[5.4872, 0.3141]])
While the correct implementation (dist_lucidrain, from lucidrain) returns
tensor([[1.7804, 3.2441],
[4.6566, 0.3141]])
As you can see, even though your implementation returns the correct results for the diagonal pairs (e1-z1 and e2-z2), the off-diagonal pairs (e2-z1 and e1-z2) are wrong. Here is an example of how to verify the correct values for the two cases:
# e1-z1 and e2-z2 distance pair
torch.sum(embedding**2 - 2*embedding*x_flat + x_flat**2, dim=1)
>>> tensor([1.7804, 0.3141])
# e2-z1 and e1-z2 distance pair
torch.sum(embedding.flip(0)**2 - 2*embedding.flip(0)*x_flat + x_flat**2, dim=1)
>>> tensor([3.2441, 4.6566])
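As a further sanity check (this brute-force loop is mine, not from the tutorial), computing every pair explicitly reproduces the full matrix that dist_lucidrain gives:
# Brute-force check: compute every pairwise squared distance one at a time
expected = torch.empty(len(x_flat), len(embedding))
for i in range(len(x_flat)):
    for j in range(len(embedding)):
        expected[i, j] = torch.sum((x_flat[i] - embedding[j]) ** 2)
expected
>>> tensor([[1.7804, 3.2441],
    [4.6566, 0.3141]])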
So the correct value for the e2-z1 pair is 3.2441, and for the e1-z2 pair it is 4.6566.
And therefore, torch.cdist should be used to calculate the distances instead of torch.addmm.
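(That said, if you would like to keep the fused addmm call, the bias term just needs to broadcast into the full (N, K) grid instead of a 1-D vector. A minimal sketch of one possible fix, not code from the tutorial:
x_sq = torch.sum(x_flat ** 2, dim=1, keepdim=True)  # shape (N, 1)
e_sq = torch.sum(embedding ** 2, dim=1)             # shape (K,)
# (N, 1) + (K,) broadcasts to (N, K), matching x_flat @ embedding.t(),
# so entry (i, j) is ||x_i||^2 + ||e_j||^2 - 2 * x_i . e_j
dist_fixed = torch.addmm(x_sq + e_sq, x_flat, embedding.t(), alpha=-2.0, beta=1.0)
On the example above this gives the same matrix as squaring torch.cdist, up to floating-point error.)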
Hello, @KinWaiCheuk.
Sorry for the late response, and thank you for the kind explanation.
Within a few days, I will fix the tutorial code as you suggested.
Again, thank you for your kindness.