Rayhane-mamah/Efficient-VDVAE

Comparison between JAX and PyTorch


Thanks for the great work!

I am curious about the performance comparison between the JAX and PyTorch implementations, specifically the training speed and NLL. Did you use JAX or PyTorch for the results in the paper?

Thank you for your interest in this work.

TL;DR: The PyTorch and JAX versions are essentially the same.

NLL-wise, we get the same results with either JAX or PyTorch. We confirmed across multiple experiments that the results match up to seed differences.

Speed-wise, newer versions of JAX are about 10% faster; with late-2021 versions there is no speed difference.

The results in the paper are essentially the same across libraries, and you should be able to reproduce the NLL numbers by running inference from any checkpoint in the table. You can also compare the TensorBoard logs and see that they are very similar.
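To make that check concrete, here is a minimal PyTorch sketch of the computation involved, not the repo's actual evaluation script. It assumes a hypothetical model whose forward pass returns the per-image negative ELBO in nats (the repo's real entry point and checkpoint format differ):

```python
import math
import torch

def nll_bits_per_dim(model: torch.nn.Module,
                     loader: torch.utils.data.DataLoader,
                     num_dims: int = 3 * 32 * 32) -> float:
    """Average negative ELBO over a dataset, in bits per dimension.

    Assumes `model(x)` returns the per-image negative ELBO in nats
    (an upper bound on the true NLL). `num_dims` defaults to CIFAR-10
    (3 x 32 x 32). Illustrative sketch only.
    """
    model.eval()
    total_nats, total_images = 0.0, 0
    with torch.no_grad():
        for x, _ in loader:
            total_nats += model(x).sum().item()  # per-image neg. ELBO, summed
            total_images += x.shape[0]
    # Convert nats per dimension to bits per dimension.
    return total_nats / (total_images * num_dims) / math.log(2)

# Usage sketch, with hypothetical checkpoint keys:
#   state = torch.load("checkpoint.pt", map_location="cpu")
#   model.load_state_dict(state["model"])
#   print(f"{nll_bits_per_dim(model, test_loader):.3f} bits/dim")
```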

P.S.: As of writing this comment, the CIFAR-10 PyTorch implementation does not include the L2 loss, which results in a worse NLL than the JAX implementation.
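For anyone patching this in locally, an L2 penalty on the model weights is straightforward to add to the training loss. A minimal sketch, where the coefficient and the set of penalized parameters are assumptions and may differ from what the repo actually uses:

```python
import torch

def l2_penalty(model: torch.nn.Module, coeff: float = 1e-2) -> torch.Tensor:
    """Sum of squared parameters, scaled by `coeff` (illustrative values)."""
    return coeff * sum(p.pow(2).sum() for p in model.parameters())

# Inside a training step (sketch):
#   loss = neg_elbo + l2_penalty(model)
#   loss.backward()

# For plain SGD, this can equivalently be folded into the optimizer:
# weight_decay=d corresponds to adding (d/2) * ||theta||^2 to the loss.
#   optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-2)
```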

Let me know if you have any other comments on this. Feel free to close this issue if you're happy with this answer.

Thank you!
Hazami Louay

Thanks for your reply and helpful information.