Support for manually modifying client/server learning rate
marcociccone opened this issue · 1 comment
marcociccone commented
Hi,
I'm playing around with the client learning rate, but I can't find a clean way of modifying it.
Basically, I need to change the LR following a schedule based on the current round.
Is that possible?
Thanks
jaehunro commented
Hi,
Thanks for trying out FedJAX and filing this issue!
I think this is mainly optax functionality, since FedJAX optimizers just wrap existing optax optimizers.
If you want a different client learning rate for each round of federated averaging, here are a few suggestions:
- Pass the round number as part of the server_state in the federated algorithm and create a new client optimizer with a learning rate based on it at each round of federated training (see the first sketch after this list). Caveat: this can be slow, since the client optimizer apply function is effectively recompiled each round.
- Optax has built-in support for learning rate decay, e.g., optax.exponential_decay. Some of these schedules can be used directly or tweaked to support your use case (see the second sketch below).
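A minimal sketch of the first suggestion. For simplicity it tracks the round number in the training loop rather than inside server_state (stashing it in server_state would need a custom algorithm, since fed_avg's ServerState only holds params and optimizer state). It assumes the fedjax.algorithms.fed_avg.federated_averaging API; grad_fn, client_batch_hparams, init_params, num_rounds, and clients_for_round are hypothetical names for pieces you'd define as usual:

```python
import fedjax

def client_lr_schedule(round_num):
  # Hypothetical schedule: halve the client learning rate every 100 rounds.
  return 0.1 * (0.5 ** (round_num // 100))

server_optimizer = fedjax.optimizers.sgd(learning_rate=1.0)
server_state = None
for round_num in range(num_rounds):
  # Rebuild the client optimizer (and the algorithm) with this round's LR.
  # This is the caveat above: the client update is effectively recompiled
  # every round.
  client_optimizer = fedjax.optimizers.sgd(
      learning_rate=client_lr_schedule(round_num))
  algorithm = fedjax.algorithms.fed_avg.federated_averaging(
      grad_fn, client_optimizer, server_optimizer, client_batch_hparams)
  if server_state is None:
    server_state = algorithm.init(init_params)
  server_state, _ = algorithm.apply(server_state, clients_for_round(round_num))
```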
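And a sketch of the second suggestion, assuming fedjax.optimizers.create_optimizer_from_optax is available. One thing to keep in mind: a schedule passed to an optax optimizer is indexed by the optimizer's own update count, and since the client optimizer state is typically initialized fresh for each client, the learning rate decays over each client's local steps rather than over rounds:

```python
import fedjax
import optax

# optax schedules map a step count to a learning rate and can be passed
# directly to an optax optimizer in place of a fixed float.
schedule = optax.exponential_decay(
    init_value=0.1,       # starting client learning rate (illustrative)
    transition_steps=10,  # decay every 10 local steps
    decay_rate=0.9)
client_optimizer = fedjax.optimizers.create_optimizer_from_optax(
    optax.sgd(learning_rate=schedule))
```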
Some potentially helpful links:
- https://github.com/deepmind/optax#schedules-schedulepy (intro to schedules in optax)
- https://github.com/google/fedjax/blob/main/fedjax/experimental/notebooks/emnist_compression.ipynb (example using a schedule for the server optimizer)
If you have more concrete examples of what you're trying to do, we'd be happy to look into it!