evosax: JAX-Based Evolution Strategies 🦎
Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? `evosax` allows you to leverage JAX, XLA compilation and auto-vectorization/parallelization to scale ES to your favorite accelerators. The API is based on the classical `ask`, `evaluate`, `tell` cycle of ES. Both `ask` and `tell` calls are compatible with `jit`, `vmap`/`pmap` and `lax.scan`. It includes a vast set of both classic (e.g. CMA-ES, Differential Evolution, etc.) and modern neuroevolution (e.g. OpenAI-ES, Augmented RS, etc.) strategies. You can get started here 👉
Basic `evosax` API Usage 🍲
```python
import jax
import jax.numpy as jnp
from evosax import CMA_ES

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
params = strategy.default_params
state = strategy.initialize(rng, params)

num_generations = 50  # illustrative budget

# Run ask-eval-tell loop - NOTE: By default minimization!
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, params)
    # Your population evaluation fct (rng_eval is available for stochastic evaluation);
    # a sphere function serves as a placeholder here
    fitness = jnp.sum(x ** 2, axis=-1)
    state = strategy.tell(x, fitness, state, params)

# Get best overall population member & its fitness
state["best_member"], state["best_fitness"]
```
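To make the ask-eval-tell cycle concrete, here is a minimal pure-JAX sketch of an isotropic Gaussian ES with truncation selection. It is not how `evosax` implements any of its strategies; the functions `init`, `ask`, `tell`, `sphere` and all constants are hypothetical, chosen only to mirror the call pattern shown above (dict-based state, minimization by default, jittable `ask`/`tell`).

```python
import jax
import jax.numpy as jnp

# Illustrative isotropic Gaussian ES with truncation selection.
# NOT evosax's implementation; all constants below are assumed values.
POPSIZE, NUM_DIMS, SIGMA, NUM_ELITES = 20, 2, 0.1, 10


def init(rng):
    mean = jax.random.normal(rng, (NUM_DIMS,))
    return {"mean": mean, "best_member": mean, "best_fitness": jnp.array(jnp.inf)}


@jax.jit
def ask(rng, state):
    # Sample a population of candidates around the current search mean
    x = state["mean"] + SIGMA * jax.random.normal(rng, (POPSIZE, NUM_DIMS))
    return x, state


@jax.jit
def tell(x, fitness, state):
    # Minimization: move the mean to the average of the lowest-fitness members
    elite_idx = jnp.argsort(fitness)[:NUM_ELITES]
    new_mean = jnp.mean(x[elite_idx], axis=0)
    # Keep track of the best member seen so far
    idx = jnp.argmin(fitness)
    improved = fitness[idx] < state["best_fitness"]
    return {
        "mean": new_mean,
        "best_member": jnp.where(improved, x[idx], state["best_member"]),
        "best_fitness": jnp.where(improved, fitness[idx], state["best_fitness"]),
    }


def sphere(x):
    # Population evaluation: one fitness value per row of x
    return jnp.sum(x ** 2, axis=-1)


rng = jax.random.PRNGKey(0)
rng, rng_init = jax.random.split(rng)
state = init(rng_init)
for t in range(100):
    rng, rng_gen = jax.random.split(rng)
    x, state = ask(rng_gen, state)
    state = tell(x, sphere(x), state)

print(state["best_member"], state["best_fitness"])
```

Because `ask` and `tell` are pure functions of `(rng, state)` and `(x, fitness, state)`, they can be jitted, vmapped or scanned without changes, which is the property the `evosax` API is built around.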
Implemented Evolution Strategies 🦎
| Strategy | Reference | Import |
|---|---|---|
| OpenAI-ES | Salimans et al. (2017) | `Open_ES` |
| PGPE | Sehnke et al. (2010) | `PGPE_ES` |
| ARS | Mania et al. (2018) | `Augmented_RS` |
| CMA-ES | Hansen (2016) | `CMA_ES` |
| Simple Gaussian | Rechenberg (1975) | `Simple_ES` |
| Simple Genetic | Such et al. (2017) | `Simple_GA` |
| x-NES | Wierstra et al. (2014) | `xNES` |
| Particle Swarm Optimization | Kennedy & Eberhart (1995) | `PSO_ES` |
| Differential ES | Storn & Price (1997) | `Differential_ES` |
| Persistent ES | Vicol et al. (2021) | `Persistent_ES` |
| Population-Based Training | Jaderberg et al. (2017) | `PBT_ES` |
Installation ⏳
The latest `evosax` release can be installed directly from PyPI:

```
pip install evosax
```

If you want the most recent commit, install directly from the repository:

```
pip install git+https://github.com/RobertTLange/evosax.git@main
```
To run JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
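A quick way to check whether JAX actually picked up an accelerator is to query its default backend and visible devices. This only uses standard JAX calls; if it reports `cpu` although a GPU/TPU is present, the accelerator-enabled JAX build is most likely not installed.

```python
import jax

# Which backend did JAX pick up? "cpu" means no accelerator is in use.
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
```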
Examples 📖
- 📓 Classic ES Tasks: API introduction on the Rosenbrock function (CMA-ES, Simple GA, etc.).
- 📓 CartPole-Control: OpenES & PEPG on the `CartPole-v1` gym task (MLP/LSTM controller).
- 📓 MNIST-Classifier: OpenES on MNIST with a CNN.
- 📓 LRateTune-PES: Persistent ES on a meta-learning problem as in Vicol et al. (2021).
- 📓 Quadratic-PBT: PBT on a toy quadratic problem as in Jaderberg et al. (2017).
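The notebooks themselves are not reproduced here. As an illustration of the population evaluation step for the Rosenbrock task, here is a vectorized fitness function in plain JAX. It assumes the standard parameterization `a=1`, `b=100` (global minimum 0 at the all-ones vector); the notebook may use a different variant.

```python
import jax
import jax.numpy as jnp


def rosenbrock(x, a=1.0, b=100.0):
    """Rosenbrock function for a single candidate x of shape (num_dims,)."""
    return jnp.sum(b * (x[1:] - x[:-1] ** 2) ** 2 + (a - x[:-1]) ** 2)


# Vectorize over the population axis: (popsize, num_dims) -> (popsize,)
batch_rosenbrock = jax.jit(jax.vmap(rosenbrock))

x = jnp.array([[1.0, 1.0], [0.0, 0.0], [-1.0, 1.0]])
values = batch_rosenbrock(x)  # values 0.0, 1.0, 4.0
print(values)
```

Writing the objective for a single candidate and lifting it with `jax.vmap` keeps the function readable while still evaluating the whole population in one compiled call.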
Key Selling Points 💵
- Strategy Diversity: `evosax` implements more than 10 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`eval`/`tell` API and come with tailored tools such as the ClipUp optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
- Vectorization/Parallelization of `ask`/`tell` Calls: Both `ask` and `tell` calls can leverage `jit`, `vmap`/`pmap`. This enables vectorized/parallel rollouts of different evolution strategies, e.g. running several hyperparameter settings of the same strategy in a single batched call.
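```python
import jax
import jax.numpy as jnp
from evosax import Augmented_RS

# E.g. vectorize over different lrate decays
strategy = Augmented_RS(popsize=100, num_dims=20)
# Start from the default hyperparameters (assumed to be a dict) and batch "lrate_decay"
es_params = dict(strategy.default_params)
es_params["lrate_decay"] = jnp.array([0.999, 0.99, 0.9])
# in_axes spec with the same keys: map "lrate_decay" over axis 0, broadcast the rest
map_dict = {k: None for k in es_params}
map_dict["lrate_decay"] = 0

# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
```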
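The `map_dict` passed as `in_axes` works because `jax.vmap` accepts, for a dict argument, a dict with the same keys whose values are `0` (map over the leading axis) or `None` (broadcast unchanged). The following self-contained sketch shows that mechanism on a hypothetical `toy_tell` update with made-up hyperparameter names; it does not use `evosax`.

```python
import jax
import jax.numpy as jnp


# Hypothetical update rule driven by a dict of hyperparameters
def toy_tell(mean, grad, params):
    lrate = params["lrate_init"] * params["lrate_decay"]
    return mean - lrate * grad


params = {
    "lrate_init": jnp.array(0.1),                  # shared across the batch
    "lrate_decay": jnp.array([0.999, 0.99, 0.9]),  # one value per run
}
# Same keys as params: 0 = map over leading axis, None = broadcast
map_dict = {"lrate_init": None, "lrate_decay": 0}

batch_tell = jax.vmap(toy_tell, in_axes=(None, None, map_dict))
mean = jnp.ones(4)
grad = jnp.ones(4)
new_means = batch_tell(mean, grad, params)
print(new_means.shape)  # (3, 4)
```

Each row of `new_means` corresponds to one `lrate_decay` setting, so three configurations are updated in a single vectorized call.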
- Scan Through Evolution Rollouts: You can also `lax.scan` through entire `init`, `ask`, `eval`, `tell` loops for fast compilation of ES loops:
```python
from functools import partial

import jax
import jax.numpy as jnp


# Assumes `strategy` has been instantiated as in the examples above
@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = jnp.sum(x ** 2, axis=-1)  # placeholder: replace with your evaluation
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
```
- Population Parameter Reshaping: We provide a `ParameterReshaper` wrapper to reshape flat parameter vectors into PyTrees. The wrapper is compatible with JAX neural network libraries such as Flax/Haiku and makes it easier to evaluate network populations afterwards.
```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper


class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...


rng = jax.random.PRNGKey(0)
network = MLP(64)
# Initialize with a placeholder observation (__call__ only takes obs)
policy_params = network.init(rng, jnp.zeros(4))
# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(policy_params["params"])
# Get population candidates & reshape into stacked pytrees
# (assumes strategy, state and params as in the basic usage example)
x, state = strategy.ask(rng, state, params)
x_shaped = param_reshaper.reshape(x)
```
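To show what reshaping flat candidates into stacked PyTrees means, here is a minimal sketch built on `jax.flatten_util.ravel_pytree` and `jax.vmap`. It is not necessarily how `ParameterReshaper` is implemented, and the two-layer parameter shapes are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Placeholder network parameters (hypothetical two-layer MLP shapes)
placeholder = {
    "dense_0": {"kernel": jnp.zeros((4, 64)), "bias": jnp.zeros(64)},
    "dense_1": {"kernel": jnp.zeros((64, 2)), "bias": jnp.zeros(2)},
}
# flat: 1D vector of all leaves; unravel: maps such a vector back to the pytree
flat, unravel = ravel_pytree(placeholder)
num_params = flat.shape[0]  # 4*64 + 64 + 64*2 + 2 = 450

# A population of flat candidates, e.g. as produced by an ES ask step
popsize = 8
x = jax.random.normal(jax.random.PRNGKey(0), (popsize, num_params))

# vmap the unravel function to get a stacked pytree with a leading popsize axis
x_shaped = jax.vmap(unravel)(x)
print(x_shaped["dense_0"]["kernel"].shape)  # (8, 4, 64)
```

The stacked layout lets you `jax.vmap` a network's `apply` function over the leading axis to evaluate all candidates at once.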
- Flexible Fitness Shaping: By default `evosax` assumes that the fitness objective is to be minimized. If you would like to maximize instead, or to apply rank centering, z-scoring or weight regularization, you can use the `FitnessShaper`:
```python
from evosax import FitnessShaper

# Instantiate jittable fitness shaper
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           weight_decay=0.01,
                           maximize=True)
# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
```
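The individual shaping operations are simple array transformations. The following sketch implements centered ranks, z-scoring, an L2 weight penalty and a maximization flip in plain JAX. The function names, the order in which the steps are applied and the exact penalty form are assumptions for illustration and may differ from what `FitnessShaper` does internally; the sketch keeps the convention that the shaped score is minimized.

```python
import jax.numpy as jnp


def centered_rank(fitness):
    """Map raw fitness to ranks scaled into [-0.5, 0.5] (lowest fitness -> -0.5)."""
    ranks = jnp.argsort(jnp.argsort(fitness))
    return ranks / (fitness.shape[0] - 1) - 0.5


def z_score(fitness, eps=1e-10):
    """Standardize fitness to zero mean and unit variance."""
    return (fitness - jnp.mean(fitness)) / (jnp.std(fitness) + eps)


def shape_fitness(x, fitness, centered=True, z=False, weight_decay=0.0, maximize=False):
    # Flip the sign so that the shaped score is always minimized
    if maximize:
        fitness = -fitness
    # L2 penalty on the candidate parameters, added to the minimized score
    fitness = fitness + weight_decay * jnp.mean(x ** 2, axis=1)
    if centered:
        fitness = centered_rank(fitness)
    if z:
        fitness = z_score(fitness)
    return fitness


fitness = jnp.array([3.0, 1.0, 2.0])
x = jnp.zeros((3, 5))
shaped = shape_fitness(x, fitness, centered=True, maximize=True)
print(shaped)  # shaped values: -0.5, 0.5, 0.0
```

Rank-based shaping makes the update invariant to monotone transformations of the raw fitness, which is one reason it is a common default in neuroevolution.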
References & Other Great JAX-ES Tools 📝
- 💻 EvoJAX: JAX-ES library by Google Brain with great rollout wrappers.
- 💻 QDax: Quality-Diversity algorithms in JAX.
- 💻 Rob's Blog: Tutorial on CMA-ES & leveraging JAX's primitives.
Development 👷
You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.