question about only updating the domain weights on process 0
SueJane opened this issue · 4 comments
Hi Michael,
Thanks for releasing this code base and all the amazing work you have done! I'm learning about DoReMi and have a question: I noticed that the domain weights are updated only on process 0, so how do the other processes get the new weights when computing the loss and updating the proxy model?
Thanks!
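The domain weight update is communicated to the other processes automatically because the weights are registered as a torch buffer, and DDP broadcasts buffers from rank 0 to the other ranks at the start of each forward pass (the default `broadcast_buffers=True`). This is the same mechanism that keeps batch norm running statistics in sync.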
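For anyone who wants to see the buffer behaviour in isolation, here is a minimal CPU-only sketch. It is not the DoReMi code: `ProxyModel`, `run_demo`, the three domain weights, and the two-process `gloo` setup are all assumptions made up for illustration. It only relies on DDP's default `broadcast_buffers=True`, and it uses the `fork` start method, so it assumes a Linux-like system.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ProxyModel(nn.Module):
    """Hypothetical stand-in for the proxy model, not the DoReMi implementation."""

    def __init__(self, num_domains: int = 3):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        # A buffer is part of the module state but is not an optimizer parameter.
        self.register_buffer(
            "domain_weights", torch.full((num_domains,), 1.0 / num_domains)
        )

    def forward(self, x):
        return self.linear(x)


def _worker(rank, world_size, tmp):
    init_file = os.path.join(tmp, "init")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    # The buffer is registered before wrapping, so DDP knows about it.
    model = DDP(ProxyModel())  # broadcast_buffers=True is the default

    # Only rank 0 changes the weights (illustrative values).
    if rank == 0:
        with torch.no_grad():
            model.module.domain_weights.copy_(torch.tensor([0.5, 0.3, 0.2]))

    # DDP broadcasts rank 0's buffers at the beginning of forward.
    model(torch.randn(2, 4))

    # Save what each rank sees so the parent process can compare.
    torch.save(
        model.module.domain_weights.detach().clone(),
        os.path.join(tmp, f"rank{rank}.pt"),
    )
    dist.destroy_process_group()


def run_demo(world_size: int = 2):
    """Run the workers and return the buffer seen by each rank after one forward."""
    with tempfile.TemporaryDirectory() as tmp:
        mp.start_processes(
            _worker, args=(world_size, tmp), nprocs=world_size, start_method="fork"
        )
        return [
            torch.load(os.path.join(tmp, f"rank{r}.pt")) for r in range(world_size)
        ]
```

Calling `run_demo()` returns one tensor per rank. With the default DDP settings, the tensor from rank 1 should match the values written on rank 0, even though rank 1 never modified its own copy.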
Hi,
Thanks for sharing the project code!
Are you sure buffers are automatically broadcast to all processes when you change the buffer's values only on rank 0? I implemented something similar a while ago, and when I changed the buffer on rank 0, only the copy on rank 0 was updated; the others remained the same. I had to broadcast the update explicitly.
Thanks,
Maurits
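For reference, the explicit broadcast described in the comment above could look roughly like the sketch below. This is an assumption about the general pattern, not Maurits's actual code: the tensor here is a plain tensor that is not managed by DDP, and `run_broadcast_demo`, the weight values, and the two-process `gloo` setup with the `fork` start method are hypothetical.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, tmp):
    init_file = os.path.join(tmp, "init")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    # Every rank starts from the same uniform weights (illustrative values).
    domain_weights = torch.full((3,), 1.0 / 3)

    # Only rank 0 applies the update, so the ranks now disagree.
    if rank == 0:
        domain_weights.copy_(torch.tensor([0.5, 0.3, 0.2]))
    before = domain_weights.clone()

    # Collective call: every rank must reach it.
    # Rank 0 sends its values; the other ranks overwrite theirs in place.
    dist.broadcast(domain_weights, src=0)

    torch.save(
        {"before": before, "after": domain_weights.clone()},
        os.path.join(tmp, f"rank{rank}.pt"),
    )
    dist.destroy_process_group()


def run_broadcast_demo(world_size: int = 2):
    """Return, per rank, the weights before and after the explicit broadcast."""
    with tempfile.TemporaryDirectory() as tmp:
        mp.start_processes(
            _worker, args=(world_size, tmp), nprocs=world_size, start_method="fork"
        )
        return [
            torch.load(os.path.join(tmp, f"rank{r}.pt")) for r in range(world_size)
        ]
```

Without the `dist.broadcast` call, nothing in this setup would propagate the rank 0 update, which is consistent with the behaviour described above for a tensor that DDP does not track.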
Yes, we've checked before (and I just checked again) and the buffers are automatically updated. I also ran a version with broadcasting for a while and the domain weight trajectories were almost identical. In your other implementation, did you create the buffer before wrapping the model in DDP?
Also see the diagram in the PyTorch DDP docs: https://pytorch.org/docs/stable/notes/ddp.html