Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? `evosax` allows you to leverage JAX, XLA compilation and auto-vectorization/parallelization to scale ES to your favorite accelerators. The API is based on the classical `ask`, `evaluate`, `tell` cycle of ES. Both `ask` and `tell` calls are compatible with `jit`, `vmap`/`pmap` and `lax.scan`. It includes a vast set of both classic strategies (e.g. CMA-ES, Differential Evolution) and modern neuroevolution strategies (e.g. OpenAI-ES, Augmented RS). You can get started here 👉
```Python
import jax
from evosax import CMA_ES

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization!
num_generations = 100
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # Your population evaluation fct
    state = strategy.tell(x, fitness, state, es_params)

# Get best overall population member & its fitness
state.best_member, state.best_fitness
```
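To make the placeholder evaluation step concrete, here is a minimal sketch using a toy quadratic (sphere) objective; the `sphere` function below is an illustrative assumption, not part of the evosax API:

```Python
import jax.numpy as jnp

def sphere(x):
    """Toy objective: sum of squares per candidate (minimized at x = 0)."""
    return jnp.sum(x ** 2, axis=-1)

# Inside the loop above, x has shape (popsize, num_dims) = (20, 2)
fitness = sphere(x)  # shape (popsize,)
```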
The latest `evosax` release can be installed directly from PyPI:

```
pip install evosax
```

If you want to get the most recent commit, please install directly from the repository:

```
pip install git+https://github.com/RobertTLange/evosax.git@main
```

In order to use JAX on your accelerators, you can find more details in the JAX documentation.
- 📓 Classic ES Tasks: API introduction on the Rosenbrock function (CMA-ES, Simple GA, etc.).
- 📓 CartPole-Control: OpenES & PEPG on the `CartPole-v1` gym task (MLP/LSTM controller).
- 📓 MNIST-Classifier: OpenES on MNIST with a CNN network.
- 📓 LRateTune-PES: Persistent ES on a meta-learning problem as in Vicol et al. (2021).
- 📓 Quadratic-PBT: PBT on a toy quadratic problem as in Jaderberg et al. (2017).
- 📓 Restart-Wrappers: Custom restart wrappers as e.g. used in (B)IPOP-CMA-ES.
- 📓 Brax Control: Evolve Tanh MLPs on Brax tasks using the EvoJAX wrapper.
- **Strategy Diversity**: `evosax` implements more than 30 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`eval` API and come with tailored tools such as the ClipUp optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
- **Vectorization/Parallelization of `ask`/`tell` Calls**: Both `ask` and `tell` calls can leverage `jit`, `vmap`/`pmap`. This enables vectorized/parallel rollouts of different evolution strategies.
```Python
import jax
import jax.numpy as jnp
from evosax.strategies.ars import ARS, EvoParams

# E.g. vectorize over different initial perturbation stds
strategy = ARS(popsize=100, num_dims=20)
es_params = EvoParams(sigma_init=jnp.array([0.1, 0.01, 0.001]), sigma_decay=0.999, ...)

# Specify how to map over ES hyperparameters
map_dict = EvoParams(sigma_init=0, sigma_decay=None, ...)

# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
```
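A single vectorized ask-eval-tell step could then look as follows. This is a sketch under the assumption that the fitness function broadcasts over the extra hyperparameter axis (here a toy sphere objective); the shapes in the comments follow from the `in_axes` definitions above:

```Python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
# One ES state per sigma_init setting (leading axis of size 3)
state = batch_init(rng, es_params)
# Candidates of shape (3, 100, 20) - one population per hyperparameter
x, state = batch_ask(rng, state, es_params)
fitness = jnp.sum(x ** 2, axis=-1)  # toy sphere objective, shape (3, 100)
state = batch_tell(x, fitness, state, es_params)
```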
- **Scan Through Evolution Rollouts**: You can also `lax.scan` through entire `init`, `ask`, `eval`, `tell` loops for fast compilation of ES loops:
```Python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...  # Your population evaluation fct
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
```
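Calling the compiled loop then amounts to a single function call (a sketch; `strategy` has to be defined in the enclosing scope as above, and the returned value is the best fitness seen across all scanned generations):

```Python
rng = jax.random.PRNGKey(0)
best_fitness = run_es_loop(rng, 100)
```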
- **Population Parameter Reshaping**: We provide a `ParameterReshaper` wrapper to reshape flat parameter vectors into PyTrees. The wrapper is compatible with JAX neural network libraries such as Flax/Haiku and makes it easier to evaluate network populations afterwards.
```Python
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper

class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...

network = MLP(64)
net_params = network.init(rng, jnp.zeros(4,))

# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(net_params)

# Get population candidates & reshape into stacked PyTrees
x = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
```
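The stacked PyTree carries a leading population axis, so one natural follow-up is to `vmap` the network forward pass over that axis; a minimal sketch, assuming a placeholder batch of observations `obs`:

```Python
import jax
import jax.numpy as jnp

obs = jnp.zeros((10, 4))  # placeholder observation batch
# Map network.apply over the leading population axis of the params PyTree
pop_apply = jax.vmap(network.apply, in_axes=(0, None))
y = pop_apply(x_shaped, obs)  # one prediction batch per population member
```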
- **Flexible Fitness Shaping**: By default `evosax` assumes that the fitness objective is to be minimized. If you would like to maximize instead, perform rank centering, z-scoring or add weight regularization, you can use the `FitnessShaper`:
```Python
from evosax import FitnessShaper

# Instantiate jittable fitness shaper (e.g. for Open ES)
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=False,
                           weight_decay=0.01,
                           maximize=True)

# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
```
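The shaped scores then simply replace the raw fitness in the `tell` update; a sketch following the main ask-eval-tell loop above:

```Python
state = strategy.tell(x, fit_shaped, state, es_params)
```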
Additional Work-In-Progress

- **Strategy Restart Wrappers**: *Work-in-progress*. You can also choose from a set of different restart mechanisms that relaunch a strategy (with e.g. a new population size) based on termination criteria. Note: for all restart strategies that alter the population size, the `ask` and `tell` methods have to be re-compiled at the time of change. All strategies can also be executed without explicitly providing `es_params`; in this case the default parameters will be used.

```Python
from evosax import CMA_ES
from evosax.restarts import BIPOP_Restarter

# Define a termination criterion (kwargs - fitness, state, params)
def std_criterion(fitness, state, params):
    """Restart strategy if fitness std across population is small."""
    return fitness.std() < 0.001

# Instantiate base CMA-ES & wrap with BIPOP restarts
# Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name)
strategy = CMA_ES(num_dims, popsize, elite_ratio)
re_strategy = BIPOP_Restarter(
    strategy,
    stop_criteria=[std_criterion],
    strategy_kwargs={"elite_ratio": elite_ratio},
)
state = re_strategy.initialize(rng)

# ask/tell loop - restarts are automatically handled
rng, rng_gen, rng_eval = jax.random.split(rng, 3)
x, state = re_strategy.ask(rng_gen, state)
fitness = ...  # Your population evaluation fct
state = re_strategy.tell(x, fitness, state)
```
- **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols.
```Python
from evosax.experimental.subpops import BatchStrategy

# Instantiates 5 CMA-ES subpops of 20 members each
strategy = BatchStrategy(
    strategy_name="CMA_ES",
    num_dims=4096,
    popsize=100,
    num_subpops=5,
    strategy_kwargs={"elite_ratio": 0.5},
    communication="best_subpop",
)

state = strategy.initialize(rng)
# Ask for evaluation candidates of the different subpopulation ES
rng, rng_iter = jax.random.split(rng)
x, state = strategy.ask(rng_iter, state)
fitness = ...  # Your population evaluation fct
state = strategy.tell(x, fitness, state)
```
- **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge is to use indirect parameter encodings in a lower-dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`:
```Python
from evosax.experimental.decodings import RandomDecoder, HyperDecoder

# For arbitrary network architectures / search spaces
num_encoding_dims = 6
param_reshaper = RandomDecoder(num_encoding_dims, net_params)
x_shaped = param_reshaper.reshape(x)

# For MLP-based models we also support a HyperNetwork en/decoding
reshaper = HyperDecoder(
    net_params,
    hypernet_config={
        "num_latent_units": 3,  # Latent units per module kernel/bias
        "num_hidden_units": 2,  # Hidden dimensionality of a_i^j embedding
    },
)
x_shaped = reshaper.reshape(x)
```
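Since the decoder maps from the low-dimensional encoding back to the full parameter PyTree, the strategy itself searches over the encoding dimensions; a sketch, assuming the `RandomDecoder` setup above:

```Python
from evosax import CMA_ES

# Search happens in the 6-dim encoding space, not the full param space
strategy = CMA_ES(popsize=20, num_dims=num_encoding_dims)
```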
- 📺 Rob's MLC Research Jam Talk: Short motivation talk at the ML Collective Research Jam.
- 📝 Rob's 02/2021 Blog: Tutorial on CMA-ES & leveraging JAX's primitives.
- 💻 EvoJAX: JAX-ES library by Google Brain with great rollout wrappers.
- 💻 QDax: Quality-Diversity algorithms in JAX.
If you use `evosax` in your research, please cite the following paper:

```
@article{evosax2022github,
  author  = {Robert Tjarko Lange},
  title   = {evosax: JAX-based Evolution Strategies},
  journal = {arXiv preprint arXiv:2212.04180},
  year    = {2022},
}
```
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.