zalandoresearch/pytorch-vq-vae

EMA update before quantization

stangelid opened this issue · 5 comments

Hi and thanks for providing a nice and clean implementation of VQ-VAEs :)

While playing around with your code, I noticed that in VectorQuantizerEMA you first perform the EMA update of the codebook counts and embeddings, and then use the updated codebook embeddings as the quantized vectors (and for computing e_latent_loss).

In particular, the order in which you perform operations is:

  1. Nearest neighbour search
  2. EMA updates
  3. Quantization
  4. e_latent_loss computation

Is there a reason why you do the EMA updates before steps 3 and 4? My intuition says that the order should instead be (see the sketch after the list):

  1. Nearest neighbour search
  2. Quantization
  3. e_latent_loss computation
  4. EMA updates
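
To make the suggestion concrete, here is a minimal self-contained sketch of a forward pass with the reordered steps. Apart from _ema_cluster_size and _ema_w, the class and variable names (e.g. _embedding, _decay, _epsilon) are my own guesses rather than your exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizerEMASketch(nn.Module):
    """Condensed sketch of an EMA vector quantizer with the reordered steps.
    Not the repo's exact code; names other than _ema_cluster_size / _ema_w are assumed."""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25,
                 decay=0.99, epsilon=1e-5):
        super().__init__()
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._commitment_cost = commitment_cost
        self._decay = decay
        self._epsilon = epsilon
        embed = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer('_embedding', embed)
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('_ema_w', embed.clone())

    def forward(self, inputs):  # inputs: (..., embedding_dim)
        flat = inputs.reshape(-1, self._embedding_dim)

        # 1. Nearest-neighbour search over the codebook
        dists = (flat.pow(2).sum(1, keepdim=True)
                 - 2 * flat @ self._embedding.t()
                 + self._embedding.pow(2).sum(1))
        indices = dists.argmin(1)
        encodings = F.one_hot(indices, self._num_embeddings).type(flat.dtype)

        # 2. Quantization using the *current* (not yet updated) codebook
        quantized = (encodings @ self._embedding).view_as(inputs)

        # 3. e_latent_loss computed against the same codebook used for quantization
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # 4. EMA updates of cluster sizes and codebook, done last
        if self.training:
            with torch.no_grad():
                self._ema_cluster_size.mul_(self._decay).add_(
                    encodings.sum(0), alpha=1 - self._decay)
                n = self._ema_cluster_size.sum()
                # Laplace smoothing keeps every cluster size strictly positive
                cluster_size = ((self._ema_cluster_size + self._epsilon)
                                / (n + self._num_embeddings * self._epsilon) * n)
                dw = encodings.t() @ flat
                self._ema_w.mul_(self._decay).add_(dw, alpha=1 - self._decay)
                self._embedding.copy_(self._ema_w / cluster_size.unsqueeze(1))

        # Straight-through estimator so gradients flow back to the encoder
        quantized = inputs + (quantized - inputs).detach()
        return quantized, loss
```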

Looking forward to hearing your thoughts!

Many thanks,
Stefanos

@stangelid thanks for the issue! I think your intuition is right. I don't remember my exact notes from when I implemented this, but thinking it through, your ordering looks correct. If you can send a PR I'll be happy to merge it; otherwise I'll add it to my TODO!

I have a question in the same context: can you please explain why you apply Laplace smoothing to the cluster sizes _ema_cluster_size? I am having a hard time understanding why (to my knowledge it was not mentioned in the paper). Thanks.

@yassouali According to my understanding, the Laplace smoothing makes sure that no element of _ema_cluster_size is ever exactly zero. If that happened, it would result in a division by zero when updating _ema_w.
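
Roughly, the smoothing step might look like the small numeric example below (the exact formula and epsilon value are my assumption of a typical implementation, not a quote from the notebook):

```python
import torch

num_embeddings, epsilon = 4, 1e-5
# Suppose one code was never selected in recent batches:
ema_cluster_size = torch.tensor([10.0, 0.0, 5.0, 5.0])
n = ema_cluster_size.sum()

# Laplace smoothing: every entry becomes strictly positive,
# while the total mass n is (approximately) preserved.
smoothed = (ema_cluster_size + epsilon) / (n + num_embeddings * epsilon) * n
print(smoothed)  # roughly [10.0, 1e-5, 5.0, 5.0] -- no exact zeros
# Dividing _ema_w by the smoothed cluster sizes is now safe.
```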

I know this is late, but hope it helps :)

Thanks @stangelid, perhaps I'll add this explanation to the notebook.

Thank you.