google/fedjax

Support for manually modifying client/server learning rate

marcociccone opened this issue · 1 comment

Hi,
I'm experimenting with the client learning rate, but I can't find a clean way to modify it.

Basically, I need to change the LR according to a schedule that depends on the current round.
Is that possible?

Thanks

Hi,

Thanks for trying out FedJAX and filing this issue!

I think this is mainly optax functionality, since FedJAX optimizers are thin wrappers around existing optax optimizers.

If you want a different client learning rate for each round of federated averaging, here are a few suggestions:

  1. Pass the round number as part of the server_state in the federated algorithm, then create a new client optimizer with a learning rate based on it at each round of federated training. (Caveat: this is potentially slow, since it effectively recompiles the client optimizer's apply function every round.)
  2. Optax has built-in support for learning rate decay, e.g., optax.exponential_decay. Some of these schedulers can be used directly or adapted to your use case.

Some potentially helpful links:

If you have more concrete examples of what you're trying to do, we'd be happy to look into it!