juntang-zhuang/GSAM

Mixed Precision Support

Opened this issue · 12 comments

Great work! When I was looking at your code and your example, I saw no mention of mixed precision. Does the current implementation of SAM and GSAM support `torch.cuda.amp` training?

Hi, thanks for asking. Personally, I'm not very familiar with mixed precision. I think using autocast in PyTorch should be feasible (https://pytorch.org/docs/stable/amp.html#torch.autocast) if all we need to do is add `with autocast():` here. Please give it a try and let me know whether it works; if it does, I can add it as an option later.
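A minimal sketch of wrapping only the forward pass in autocast, assuming a toy `nn.Linear` model and random data in place of the real network (these are placeholders, not part of GSAM). It falls back to CPU with bfloat16 when no GPU is present; on CUDA, float16 is the usual choice.

```python
import torch
import torch.nn as nn

# Placeholder setup: a toy model and random data stand in for the real network.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16  # CPU autocast uses bfloat16

model = nn.Linear(8, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(4, 8, device=device)
targets = torch.randint(0, 2, (4,), device=device)

# Only the forward pass and the loss computation go inside autocast.
with torch.autocast(device_type=device, dtype=amp_dtype):
    outputs = model(inputs)  # matmul-heavy ops such as linear run in amp_dtype
    loss = loss_fn(outputs, targets)

# backward() runs outside the context; the float32 parameters get float32 gradients.
loss.backward()
```

With float16, small gradients can underflow to zero, which is why the PyTorch docs pair autocast with gradient scaling.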

BTW, do you have experience with making each worker get its own gradient and weight perturbation, as discussed in the README section "Potentially unresolved issues"? That's one of the key factors in making the SAM family work.

Thank you for your response, but even though mixed precision can be achieved just by adding `with autocast():`, that is not the recommended way in PyTorch. According to the PyTorch docs (https://pytorch.org/docs/stable/amp.html#gradient-scaling), gradient scaling should be used in mixed precision training. However, I cannot see how to use `scaler.step` with GSAM, because `amp.GradScaler` does not currently support optimizers that use a closure.
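For reference, this is the usual autocast + `GradScaler` recipe from the PyTorch AMP docs, sketched with a placeholder model and plain SGD. It shows why GSAM does not slot in directly: `scaler.step(optimizer)` expects the gradients to exist already, and when scaling is enabled it raises a RuntimeError if a `closure` is passed. On CPU-only machines the scaler is created with `enabled=False`, which makes every scaler call a pass-through so the sketch still runs.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

# Placeholder model, optimizer and data.
model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
# Loss scaling targets float16 on CUDA; when disabled, every scaler call is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

inputs = torch.randn(4, 8, device=device)
targets = torch.randint(0, 2, (4,), device=device)
initial_weight = model.weight.detach().clone()

for _ in range(5):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips optimizer.step() on inf/NaN
    scaler.update()                # adjusts the scale once per iteration
```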

As for synchronization, I am not sure whether PyTorch currently supports multi-GPU sync with DataParallel; DistributedDataParallel seems like the only option. If you think it would help, there are GitHub repositories that implement SyncBatchNorm for DataParallel, like this one (https://github.com/vacancy/Synchronized-BatchNorm-PyTorch).
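For comparison, PyTorch ships its own `torch.nn.SyncBatchNorm`, which its docs state is supported only with DistributedDataParallel (one GPU per process). A sketch of converting an existing model's BatchNorm layers, using a placeholder network; the actual statistics synchronization happens inside a DDP process group during training.

```python
import torch.nn as nn

# Placeholder network with an ordinary BatchNorm layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Recursively replaces every BatchNorm*D layer with SyncBatchNorm.
# Convert before wrapping the model in DistributedDataParallel.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
```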

Furthermore, I could not find your JAX code. I am also using JAX and would appreciate being able to use this in my JAX projects as well. Is it open source, and if not, do you plan to release it?

Hi, thanks for the info. `set_closure` ultimately creates a function `self.forward_backward_func`, so that calling `self.forward_backward_func()` computes the gradients of the parameters. I think you can put `scaler.scale()` and `scaler.unscale_()` here to get gradients with AMP. Note that `self.forward_backward_func()` is called twice, here and here, so I guess some caution is required.

For sync, sorry, what I actually want is "unsync". The perturbation $\Delta w_t = \rho_t \frac{\nabla f}{\lVert \nabla f \rVert}$ requires each worker to have its own value, so every operation should run under `no_sync` except the last step, which must synchronize so that all workers have the same weights after `optimizer.step`.
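A toy numeric illustration, assuming two workers whose local gradients are made-up two-element lists and a hypothetical `perturbation` helper. Each worker computes $\Delta w = \rho \nabla f / \lVert \nabla f \rVert$ from its own local gradient (the state that DistributedDataParallel's `no_sync()` context manager preserves by skipping the gradient all-reduce), and these per-worker perturbations generally differ from the single shared one you would get if gradients were averaged first.

```python
import math


def perturbation(grad, rho, eps=1e-12):
    """Hypothetical helper: Delta w = rho * grad / ||grad|| for one worker's local gradient."""
    norm = math.sqrt(sum(g * g for g in grad))
    return [rho * g / (norm + eps) for g in grad]


# Made-up local gradients on two workers that see different data shards.
local_grads = {"worker0": [3.0, 4.0], "worker1": [0.0, -2.0]}
rho = 0.05

# Unsynced: each worker perturbs its own copy of the weights along its own gradient.
per_worker = {name: perturbation(g, rho) for name, g in local_grads.items()}

# Synced: averaging the gradients first yields one shared perturbation for everyone.
mean_grad = [sum(vals) / len(vals) for vals in zip(*local_grads.values())]
shared = perturbation(mean_grad, rho)
```

Every perturbation has length `rho`; only the direction changes, and that direction is exactly what gets lost when the gradients are synchronized before perturbing.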

As for the JAX code, it was written during my internship at Google last year, so it really depends on Google's policy. It's actually much easier in JAX than in PyTorch, because "unsynced gradients" and "per-worker perturbation" are very natural to express in JAX. I'll talk with my supervisor about releasing the code and see whether it's possible to rewrite a public version from memory.

Thanks for the pointers on AMP integration. Would you be willing to review and merge a PR for this if I implement correct gradient scaling and finite-gradient checking?

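As for the desired unsync with `nn.DataParallel`, I do not believe it is possible, because `nn.DataParallel` distributes work across GPUs at the module level: the module is replicated on each device for every forward pass, and the replicas are discarded once each worker has produced its outputs. In the most basic sense, it is a wrapper that scatters the inputs, runs the replicas, gathers the outputs on the main device, and during the backward pass sums the gradients from all replicas into the parameters on the main device.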

From the PyTorch docs (https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html#torch.nn.DataParallel):

In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, [DataParallel](https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html#torch.nn.DataParallel) guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded.

Cool, thanks a lot for the help! I guess that to get unsynced gradients, we will need DistributedDataParallel rather than DataParallel. I think JAX is equivalent to DistributedDataParallel by default. I'll leave a message here if I'm allowed to share the original JAX code.

I made a few changes to the scheduler class and train.py (to make the example code match the README in terms of `rho_scheduler` and `lr_scheduler`), but they should not affect the GSAM class.

Yes, `jax.pmap` is equivalent to DistributedDataParallel. Shall I send a PR for mixed precision support, then?

Cool, thanks a lot. BTW, I wrote a PR in JAX with the core GSAM code, google-research/vision_transformer#169, from memory. I'm no longer at Google, so I can't re-run the experiments with this PR, but you can check it out as a reference. One reminder: google-research/vision_transformer does not provide the full code to train a ViT from scratch, which differs from our paper's setting of training from scratch.

Thanks a lot for providing a reference :)! Before sending a PR, I have a few questions. Normally, `GradScaler` skips the optimizer step when the gradients contain inf or NaN values. Since SAM variants run two forward-backward passes per step, is it safe to skip the whole step if there are NaN values after the first pass? Furthermore, `GradScaler` scales the loss by a factor that is adjusted every iteration, and the user decides when that adjustment happens by calling `scaler.update()` (usually once per iteration, right after `scaler.step()`). Should the scaling factor be updated after each of the two passes, or only after the second one?
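To make the question concrete, here is a pure-Python sketch of the dynamic loss-scaling rule that `GradScaler` documents: multiply the scale by `backoff_factor` when inf/NaN gradients are found, and by `growth_factor` after `growth_interval` consecutive clean iterations. The `DynamicLossScale` class is hypothetical and only mirrors that documented rule; here it is updated once per full GSAM iteration, with an overflow flag that is set if either pass overflowed. Small constants keep the trace short.

```python
class DynamicLossScale:
    """Hypothetical pure-Python mirror of GradScaler's documented update rule."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive iterations without inf/NaN

    def update(self, found_inf):
        if found_inf:
            # Overflow: shrink the scale and restart the clean-step count.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                # Enough clean iterations in a row: try a larger scale.
                self.scale *= self.growth_factor
                self._clean_steps = 0


# One update per full GSAM iteration: overflow in either pass counts.
loss_scale = DynamicLossScale(init_scale=8.0, growth_interval=2)
history = []
passes = [(False, False), (False, False), (True, False), (False, True), (False, False)]
for first_inf, second_inf in passes:
    loss_scale.update(found_inf=first_inf or second_inf)
    history.append(loss_scale.scale)
# history == [8.0, 16.0, 8.0, 4.0, 4.0]
```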

I think GSAM should skip the whole step if either the first or the second backward pass fails, because it requires a linear projection of one gradient onto the other, and a single NaN would make the final output gradient invalid.
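A pure-Python sketch, with plain lists standing in for flattened gradients, of the combination step as the GSAM paper describes it: split the original gradient `g` into components parallel and orthogonal to the perturbed-point gradient `g_adv`, then update along `g_adv - alpha * g_orth`. Because the projection coefficient is a single inner product over all entries, one NaN in either pass turns every entry of the result into NaN. The function names are hypothetical, not the repo's.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def gsam_direction(g, g_adv, alpha):
    """g: gradient at w; g_adv: gradient at w + Delta w.
    Subtracts alpha times the part of g orthogonal to g_adv."""
    coeff = dot(g, g_adv) / dot(g_adv, g_adv)  # one scalar shared by all entries
    g_orth = [gi - coeff * hi for gi, hi in zip(g, g_adv)]
    return [hi - alpha * oi for hi, oi in zip(g_adv, g_orth)]


def gsam_step_or_skip(g, g_adv, alpha):
    """Return the update direction, or None to skip the whole step
    if either backward pass produced a non-finite value."""
    if not all(math.isfinite(v) for v in g + g_adv):
        return None
    return gsam_direction(g, g_adv, alpha)


g = [1.0, 2.0]
g_adv = [2.0, 0.0]
update = gsam_step_or_skip(g, g_adv, alpha=0.5)                     # [2.0, -1.0]
poisoned = gsam_direction([float("nan"), 2.0], g_adv, alpha=0.5)    # every entry is NaN
skipped = gsam_step_or_skip([float("nan"), 2.0], g_adv, alpha=0.5)  # None
```

With `alpha=0` the orthogonal term vanishes and the result is just `g_adv`, i.e. the SAM gradient.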

I think the best way is to unscale the gradients in each of the two passes, so that the inner product and the projection of the two gradients are meaningful. Maybe put both the scale and the unscale inside the `get_grad` function, so that each call to `get_grad` returns the correct gradient. BTW, if the gradients have already been unscaled, does unscale need to be applied to the optimizer again?
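A sketch of the `get_grad` idea with a placeholder model (this `get_grad` is hypothetical, not the repo's function). Each call scales the loss, runs backward, and divides the gradients by `scaler.get_scale()` by hand. It does this because `GradScaler.unscale_(optimizer)` may only be called once per optimizer per iteration (a second call before `update()` raises a RuntimeError), and GSAM needs two passes. On the last question: `scaler.step(optimizer)` unscales the optimizer's gradients itself unless `unscale_` was already called for it. Gradients divided by hand should therefore go to the base optimizer's `step()` directly, not through `scaler.step()`, or they would be unscaled twice. Scale growth is left out, and the manual backoff is only one possible way to handle a skipped step.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

# Placeholder model and data.
model = nn.Linear(8, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)  # pass-through on CPU
inputs = torch.randn(4, 8, device=device)
targets = torch.randint(0, 2, (4,), device=device)


def get_grad():
    """Hypothetical helper: one scaled forward/backward pass that returns
    unscaled gradient copies and whether all of them are finite."""
    model.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()
    inv_scale = 1.0 / scaler.get_scale()  # get_scale() is 1.0 when scaling is disabled
    grads = [p.grad.detach() * inv_scale for p in model.parameters()]
    finite = all(bool(torch.isfinite(g).all()) for g in grads)
    return grads, finite


g, ok_first = get_grad()        # pass 1: gradient at w
# (perturb the weights along g here, as SAM/GSAM do)
g_adv, ok_second = get_grad()   # pass 2: gradient at w + Delta w
skip_step = not (ok_first and ok_second)

if skip_step:
    # Manual backoff: update() without an argument expects inf checks recorded
    # by unscale_()/step(), which this manual path bypasses.
    scaler.update(scaler.get_scale() * scaler.get_backoff_factor())
```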

Please take a look at #3. I could not test it extensively, and honestly I do not know how much it would affect performance; I believe we need to test it with a few models and datasets. Do we need to explicitly check that the model parameters handled by the scaler are the same as the ones given to SAM? What other edge cases need to be handled?

By the way, how can we also adapt this to SAM?

Hi, thanks for the contribution. I think we would need some basic tests to determine whether we handle `GradScaler` correctly. Maybe the check can be done in a test rather than in the function itself.

To reduce GSAM to SAM, you can set `rho_max=rho_min` (basically a constant rho scheduler), `gsam_alpha=0`, and `adaptive=False`.
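A small sketch of what "a constant rho scheduler" means. It assumes rho is mapped linearly from the current learning rate onto `[rho_min, rho_max]`, a proportional schedule in the spirit of the repo's rho scheduler but not its actual code; the `proportional_rho` helper and the learning-rate values are hypothetical. With `rho_max == rho_min` the schedule is flat. With `gsam_alpha=0` the orthogonal correction drops out, leaving SAM's gradient at the perturbed point.

```python
def proportional_rho(lr, lr_max, lr_min, rho_max, rho_min):
    """Hypothetical helper: map the current learning rate linearly onto [rho_min, rho_max]."""
    if lr_max == lr_min:
        return rho_max
    return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)


# Illustrative, linearly decaying learning-rate values.
lrs = [0.1, 0.075, 0.05, 0.025, 0.0]

# GSAM-style: rho decays from rho_max to rho_min together with the learning rate.
gsam_rhos = [proportional_rho(lr, 0.1, 0.0, rho_max=0.1, rho_min=0.01) for lr in lrs]

# SAM-style: rho_max == rho_min gives a constant rho at every step.
sam_rhos = [proportional_rho(lr, 0.1, 0.0, rho_max=0.05, rho_min=0.05) for lr in lrs]
```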

Okay, thanks :). I am currently unable to run large-scale training, so if you test it, please keep me updated.