yongchaoz/FRePo

Would you like to release a PyTorch version of the code?


This can help promote your research.

Hi Guangxiang,

Thanks for your interest in our work and for the suggestion.

I have added a torch branch that implements the essential parts of FRePo, built on the DC repo. On the CIFAR-100 experiment with 1 image per class, the PyTorch version achieves a test accuracy of around 25%, roughly 3% below the JAX version. Speed and memory consumption are also slightly worse than in JAX.

There are three major differences between the PyTorch and JAX versions (a small sketch addressing points 1 and 3 follows the list).

  1. Optimizer: the JAX version uses LAMB, while the PyTorch version uses Adam.
  2. Data augmentation: the implementations and augmentation strengths differ slightly between the two versions. The PyTorch version uses the DC implementation, while the JAX version builds on Augmax.
  3. Network initialization: JAX uses LeCun initialization by default, while PyTorch uses He (Kaiming).
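For reference, here is a minimal PyTorch sketch of how one might narrow differences 1 and 3: overriding PyTorch's default He (Kaiming) initialization with a LeCun-normal scheme (the flax/JAX default) and setting up the optimizer. The `lecun_normal_` helper, the toy model, and the learning rate are illustrative assumptions, not code from either repo.

```python
import math
import torch
import torch.nn as nn

def lecun_normal_(module):
    # Hypothetical helper: re-initialize conv/linear weights with LeCun normal
    # (the flax/JAX default), approximated by a truncated normal with
    # std = sqrt(1 / fan_in). PyTorch's own default is He (Kaiming) uniform.
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        fan_in = module.weight[0].numel()  # in_features, or in_ch * kh * kw
        std = math.sqrt(1.0 / fan_in)
        nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Toy model for illustration only; the real experiments use a ConvNet.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 100),
)
model.apply(lecun_normal_)

# The JAX version trains with LAMB (e.g. optax.lamb); plain torch.optim has
# no LAMB, hence the fallback to Adam. The learning rate is a placeholder.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
```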

I am not sure whether other differences between the two frameworks affect performance, but the hyperparameters that work for the JAX version do not work very well for PyTorch. Depending on your goal: if you want to improve FRePo on standard benchmarks, building on the JAX version may make life easier; if you plan to use FRePo for other applications, the torch version should be fine. If you implement a version that outperforms mine, feel free to open a pull request.

Best,
Yongchao