gpauloski/kfac-pytorch

Is this implementation of the KFAC algorithm compatible with an nn.Module that is not a distributed module?

Closed this issue · 4 comments

Describe the Request

Hi gpauloski,

I am trying to use KFAC to optimize a model based on torch.func (vmap), which is not supported by torch.nn.parallel.DistributedDataParallel right now. I was wondering whether the implementation of the KFAC algorithm is compatible with a plain nn.Module that is not a distributed module?

Thank you in advance for your kind help!

Best,
Yaolong

Hi, Yaolong. Thanks for reaching out!

It is not a requirement of the KFACPreconditioner class to use a DistributedDataParallel module, nor to have torch.distributed initialized. The KFACPreconditioner accepts any torch.nn.Module:

model: torch.nn.Module,

but it will only register Linear and Conv2d layers for KFAC updates.
All of the communication operations have shortcuts for when torch.distributed is not available, defaulting to a world size of 1. E.g.,
if get_world_size(group) == 1:

So the code should work for single device training.
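
For reference, single-device usage would look roughly like the sketch below. This is just an illustration: train_loader, the loss, and the hyperparameters are placeholders, and I'm assuming the kfac.preconditioner import path and the default constructor arguments.

import torch
from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFACPreconditioner(model)  # registers the two Linear layers

for data, target in train_loader:  # train_loader: your existing DataLoader
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    preconditioner.step()  # preconditions .grad in place for registered layers
    optimizer.step()       # SGD then updates the parameters as usual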

However, I'm not sure about torch.func compatibility. functorch/torch.func was released after I did most of this development, so I have not tried it myself.

The KFACPreconditioner does do some in-place modification of gradients and definitely has side effects (i.e., KFACPreconditioner.step() is not a pure function). So I'd be a bit worried about potential problems there because torch.func/JAX generally use pure functions for transforms.

That's mostly speculation though, so if you give it a try and post any stack traces here we can discuss it further.
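
For example, with the functional API the gradients come back as values rather than being written to param.grad, which is what KFACPreconditioner.step() modifies in place. A rough illustration (not something I've tested with the preconditioner):

import torch
from torch.func import functional_call, grad

model = torch.nn.Linear(4, 2)
params = dict(model.named_parameters())

def loss_fn(params, x, y):
    out = functional_call(model, params, (x,))
    return torch.nn.functional.mse_loss(out, y)

x, y = torch.randn(8, 4), torch.randn(8, 2)
grads = grad(loss_fn)(params, x, y)  # dict of gradients keyed by parameter name

# Nothing was written in place, so a stateful preconditioner that rewrites
# .grad would not see these gradients.
print(model.weight.grad)  # None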


Hi,

Thank you for the quick reply.
" but will only register Linear and Conv2d layers for KFAC updates." Does this mean that only the parameters of these layers will be updated by KFAC? If I use KFAC and SGD, will the other parameters in the model still be updated by SGD or remain the same?

I will try KFAC with torch.func and let you know what happens.

Best,
Yaolong

Yes, that is correct. KFAC will update the gradients in-place for the Linear and Conv2d modules, and all other layers will simply be ignored by KFAC. Then your optimizer, SGD in your case, is applied to update the weights for all layers as it normally would, regardless of whether the gradients were preconditioned.
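
To make the ordering concrete, here is a rough sketch (again assuming default KFACPreconditioner arguments); the BatchNorm layer is only there to show a module that KFAC ignores:

import torch
from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),          # not registered by KFAC
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFACPreconditioner(model)  # registers Conv2d and Linear only

x = torch.randn(4, 3, 32, 32)
loss = model(x).sum()  # placeholder loss
loss.backward()        # every layer, including BatchNorm, gets a .grad
preconditioner.step()  # rewrites only the Conv2d/Linear gradients in place
optimizer.step()       # SGD updates every parameter from its (possibly preconditioned) .grad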


Many thanks!