JAX version code
magicknight opened this issue · 1 comment
magicknight commented
I am interested in this repository and noticed that a large portion of the code is based on MAE and VQGAN, and that the original implementation was written in JAX for TPUs. Would it be possible to obtain the JAX version of the code that was used to create this repository? Thank you.
LTH14 commented
Hi, thanks for your interest! Currently, we do not plan to release the JAX code, as it depends on some of Google's internal codebase. In our experiments, this PyTorch version gives performance very similar to that of the JAX version.