kohjingyu/fromage

torch.distributed.all_gather does not have grads

Closed this issue · 2 comments

Thank you for your great work! While reading through your code, I noticed a significant bug in distributed training:

fromage/main.py, lines 463 to 464 (commit b36a188):

```python
dist.all_gather(all_visual_embs, visual_embs)
dist.all_gather(all_last_embedding, last_embedding)
```

See the comparison between:
https://github.com/salesforce/LAVIS/blob/7f00a0891b2890843f61c002a8e9532a40343648/lavis/models/base_model.py#L241
and
https://github.com/salesforce/LAVIS/blob/7f00a0891b2890843f61c002a8e9532a40343648/lavis/models/base_model.py#L223

In short, if we want gradients to flow across ranks through all_gather, we have to use the latter approach: wrapping the collective in a custom autograd function.
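For reference, here is a minimal sketch of that autograd-function approach. It is written in the spirit of the linked LAVIS helper rather than copied from it, and the names `GatherLayer` and `all_gather_with_grad` are illustrative. It assumes a process group has already been initialized (for example by `torchrun`). The forward pass is an ordinary `dist.all_gather`; the backward pass sums each slice's gradient across ranks with `dist.all_reduce` and returns the slice that belongs to the local rank.

```python
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """all_gather that lets gradients flow back to the rank that produced each slice (sketch)."""

    @staticmethod
    def forward(ctx, x):
        # Ordinary all_gather: one buffer per rank, filled without autograd history.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x.contiguous())
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank holds a gradient for every gathered slice. Summing them across
        # ranks gives the total gradient per slice; keep only this rank's slice.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)  # default reduce op is SUM
        return all_grads[dist.get_rank()]


def all_gather_with_grad(x):
    """Concatenate x from every rank along dim 0, keeping the graph to the local x."""
    if dist.get_world_size() == 1:
        return x
    return torch.cat(GatherLayer.apply(x), dim=0)


if __name__ == "__main__":
    # One-process smoke test on CPU; real training launches one process per GPU.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    x = torch.randn(4, 8, requires_grad=True)
    (gathered,) = GatherLayer.apply(x)
    (3.0 * gathered).sum().backward()
    print(torch.allclose(x.grad, torch.full_like(x, 3.0)))  # True
    dist.destroy_process_group()
```

With `world_size=1` the gather is effectively an identity, so the smoke test only checks the wiring of forward and backward, not real cross-rank behavior.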

I am wondering whether you have run into problems when training Fromage in a distributed setting.

Hi, thanks for bringing this up. It's an interesting question. Did you experience this yourself when running it?

To my understanding, these lines should fix the gradient issue:
https://github.com/kohjingyu/fromage/blob/b36a1889e16cb9486e83e1853dce68ab653068c9/main.py#L465C42-L467

This seems to be doing essentially the same thing as the code that you shared (L220).
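To illustrate why the two approaches end up essentially the same, here is a single-process simulation. It assumes the linked lines follow the common pattern of overwriting this rank's slot in the gathered list with the local, graph-connected tensor. No process group is needed: "ranks" are just list entries, `.detach()` stands in for what `dist.all_gather` returns, and `contrastive_loss` is a hypothetical symmetric InfoNCE loss, not Fromage's actual loss. All helper names are illustrative.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(img, txt, temperature=0.07):
    # Hypothetical symmetric InfoNCE loss over the full (gathered) batch.
    img = F.normalize(img, dim=-1)
    txt = F.normalize(txt, dim=-1)
    logits = img @ txt.t() / temperature
    labels = torch.arange(logits.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


def grads_overwrite_local(img, txt):
    """Per-rank gradients when gathered copies carry no graph except the local slot."""
    grads = []
    for rank in range(len(img)):
        all_img = [t.detach() for t in img]  # stand-in for dist.all_gather output
        all_txt = [t.detach() for t in txt]
        all_img[rank] = img[rank].detach().requires_grad_()  # overwrite local slot
        all_txt[rank] = txt[rank].detach().requires_grad_()
        loss = contrastive_loss(torch.cat(all_img), torch.cat(all_txt))
        grads.append(torch.autograd.grad(loss, (all_img[rank], all_txt[rank])))
    return grads


def grads_gather_with_grad(img, txt):
    """Per-rank gradients when every slot stays in the graph (GatherLayer-style)."""
    leaves_img = [t.detach().requires_grad_() for t in img]
    leaves_txt = [t.detach().requires_grad_() for t in txt]
    for _ in range(len(img)):
        # Every rank computes the same full-batch loss. Calling backward once per
        # rank accumulates into .grad, mimicking the all_reduce(SUM) in backward.
        loss = contrastive_loss(torch.cat(leaves_img), torch.cat(leaves_txt))
        loss.backward()
    return [(i.grad, t.grad) for i, t in zip(leaves_img, leaves_txt)]


if __name__ == "__main__":
    torch.manual_seed(0)
    world, batch, dim = 2, 4, 8
    img = [torch.randn(batch, dim) for _ in range(world)]
    txt = [torch.randn(batch, dim) for _ in range(world)]
    local = grads_overwrite_local(img, txt)
    summed = grads_gather_with_grad(img, txt)
    for (li, lt), (si, st) in zip(local, summed):
        # Same direction, scaled by the world size.
        print(torch.allclose(si, world * li), torch.allclose(st, world * lt))
```

When every rank computes the same loss over the full gathered batch, the overwrite trick gives each rank the gradient of that loss with respect to its own embeddings, while the autograd-function version sums `world` identical contributions. After DDP averages parameter gradients across ranks, the two therefore differ only by a constant factor of the world size (roughly a learning-rate rescaling).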

We didn't experiment much with distributed training in Fromage, but for GILL (which uses the same code) we trained on 2 GPUs and it worked.

Oh, that makes sense. I was only reading through several large-scale contrastive training codebases and noticed these differences. I did not run into this problem myself.

Thank you for your time!