davda54/sam

Multi GPU

lthilnklover opened this issue · 3 comments

Hi,

I am currently running some experiments with SAM. Previously I used the Jax/Flax code provided by the authors of SAM, but for various reasons I also have to run experiments with a PyTorch version, which is how I found your PyTorch implementation of SAM/ASAM.

First of all, thank you for your work; it saved me a lot of time and effort. Listing the known issues in the README was especially helpful.

I realized that a multi-GPU version of SAM is yet to be implemented (and is also not planned for the near future, according to your comments in a closed issue). So I have the following questions regarding multi-GPU settings.

  1. Were the experiment results you presented in the README obtained using only one GPU?
  2. In the README, there are some remarks about Distributed Data Parallel by @evanatyourservice. However, to my knowledge, DDP in PyTorch automatically averages the gradients during `.backward()`. So rather than reducing all gradients after the second pass, shouldn't we instead NOT sync the gradients on the first pass? If not, I would be grateful if @evanatyourservice could provide the code for the `reduce_all_gradients` part.

Once again thank you for the great work!

Hi,

Thanks for using this repository!

  1. Yes, the results are based on the code in `./example/`.
  2. I think the correct solution is to use the `with model.no_sync():` block for the first pass, as sketched below. I'll try to make the README less confusing regarding multi-GPU training.
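
Something along these lines (a minimal sketch only, assuming the `SAM` wrapper from this repo with its `first_step`/`second_step` API, an already-initialized process group, and a placeholder `loader`):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from sam import SAM

# placeholder model and loss; assumes torch.distributed is already initialized
model = DDP(nn.Linear(10, 2).cuda())
criterion = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)

for inputs, targets in loader:  # `loader` is a placeholder DataLoader
    inputs, targets = inputs.cuda(), targets.cuda()

    # first forward-backward pass: skip DDP's gradient all-reduce, so the
    # perturbation e(w) is computed from each worker's local gradients
    with model.no_sync():
        criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass: DDP averages the gradients as usual, and
    # the shared update is taken at the perturbed weights w + e(w)
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
```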

Thank you for the reply!

Hi @lthilnklover! The "reduce all gradients" from my comments comes from using PyTorch XLA, where you can explicitly reduce gradients manually. I leave the reduction out of the first pass, so everything stays separate on each accelerator, and then add it back in after the second pass for the actual update from the "noised" model.
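
Roughly like this (just a sketch of the idea, assuming the SAM wrapper from this repo and using PyTorch XLA's `xm.reduce_gradients` as the explicit reduction; the surrounding setup is omitted):

```python
import torch_xla.core.xla_model as xm

def train_step(model, criterion, optimizer, inputs, targets):
    # first pass: no gradient reduction, so each core computes its own
    # perturbation e(w) from its purely local gradients
    criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # second pass: explicitly all-reduce (average) the gradients across cores,
    # then take the actual update from the "noised" weights
    criterion(model(inputs), targets).backward()
    xm.reduce_gradients(optimizer)
    optimizer.second_step(zero_grad=True)
    xm.mark_step()
```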