Is this implementation of the KFAC algorithm compatible with an nn.Module that is not a distributed module?
Closed this issue · 4 comments
Describe the Request
Hi gpauloski,
I am trying to use KFAC to optimize a model based on torch.func (vmap), which is not currently supported by torch.nn.parallel.DistributedDataParallel. I was wondering whether this implementation of the KFAC algorithm is compatible with a plain nn.Module that is not wrapped in a distributed module?
Thank you in advance for your kind help!
Best,
Yaolong
Hi, Yaolong. Thanks for reaching out!
It is not a requirement of the KFACPreconditioner class to use a DistributedDataParallel module, nor to have torch distributed initialized. The KFACPreconditioner accepts any torch.nn.Module (kfac-pytorch/kfac/preconditioner.py, line 56 in 86bf926), but will only register Linear and Conv2d layers for KFAC updates.
All of the communication operations have shortcuts for when torch distributed is not available, defaulting to a world size of 1 (e.g., kfac-pytorch/kfac/distributed.py, line 216 in 86bf926). So the code should work for single-device training.
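For reference, a minimal single-device sketch (following the README-style usage of this repo; the model, data, and hyperparameters below are placeholders, and the KFACPreconditioner constructor is assumed to work with its defaults):

```python
import torch
from kfac.preconditioner import KFACPreconditioner

# Plain nn.Module: no DistributedDataParallel wrapper, no torch.distributed init.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Only the Linear (and Conv2d) modules get registered for KFAC updates.
preconditioner = KFACPreconditioner(model)

criterion = torch.nn.CrossEntropyLoss()
x = torch.randn(8, 32)
y = torch.randint(0, 10, (8,))

for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    preconditioner.step()  # preconditions the .grad buffers in place
    optimizer.step()       # the optimizer then updates all parameters as usual
```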
However, I'm not sure about torch.func compatibility. functorch/torch.func was released after I did most of this development, so I have not tried it myself. The KFACPreconditioner does do some in-place modification of gradients and definitely has side effects (i.e., KFACPreconditioner.step() is not a pure function), so I'd be a bit worried about potential problems there because torch.func/JAX generally use pure functions for transforms.
That's mostly speculation though, so if you give it a try and post any stack traces here we can discuss it further.
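To make the concern concrete: in the torch.func style, gradients are returned as the output of a pure function rather than written into parameter .grad buffers by a regular loss.backward() call, which is what the preconditioner's hooks and in-place edits operate on. A rough sketch of that functional pattern (plain torch.func, no kfac-pytorch involved; shapes and the loss are illustrative):

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(32, 10)
params = dict(model.named_parameters())

def loss_fn(params, x, y):
    # A pure function of the parameters: no .grad buffers are populated and
    # no module backward hooks fire the way they do with loss.backward().
    logits = functional_call(model, params, (x,))
    return torch.nn.functional.cross_entropy(logits, y)

x = torch.randn(8, 32)
y = torch.randint(0, 10, (8,))

# Gradients come back as a new dict of tensors instead of being written into
# each parameter's .grad, so an in-place gradient preconditioner has nothing
# to hook into here.
grads = grad(loss_fn)(params, x, y)

# Per-sample gradients via vmap, illustrating the functional/vectorized style.
def sample_loss(params, x_i, y_i):
    logits = functional_call(model, params, (x_i.unsqueeze(0),))
    return torch.nn.functional.cross_entropy(logits, y_i.unsqueeze(0))

per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)
```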
Hi,
Thank you for the quick reply.
" but will only register Linear
and Conv2d
layers for KFAC updates." Does this mean that only the parameters of these layers will be updated by KFAC? If I use KFAC and SGD, will the other parameters in the model still be updated by SGD or remain the same?
I will try KFAC with torch.func and let you know what happens.
Best,
Yaolong
Yes, that is correct. KFAC will update the gradient in-place for the Linear and Conv2d modules, and all other layers will just be ignored by KFAC. Then your optimizer, SGD in your case, is applied to update the weights for all layers, as it normally would, regardless of whether the gradients were preconditioned.
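As a concrete (hypothetical) example of that split, only the Linear layers' gradients are preconditioned below, but SGD still steps every parameter, LayerNorm included:

```python
import torch
from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),   # registered by KFAC
    torch.nn.LayerNorm(16),    # ignored by KFAC, still trained by SGD
    torch.nn.Linear(16, 4),    # registered by KFAC
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFACPreconditioner(model)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()

preconditioner.step()  # rewrites .grad for the Linear layers only
optimizer.step()       # applies SGD to every parameter, LayerNorm included
```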
Many thanks!