evosax: JAX-Based Evolution Strategies
Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? `evosax` allows you to leverage JAX, XLA compilation and auto-vectorization to scale ES to your favorite accelerators. The API follows the classical `ask`, `evaluate`, `tell` cycle of ES and only requires you to `vmap` and `pmap` over the fitness function axes of choice. It includes popular strategies such as Simple Gaussian, CMA-ES, and different NES variants.
Basic `evosax` API Usage 🍲
```python
import jax
from evosax import CMA_ES
from evosax.problems import batch_rosenbrock

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
params = strategy.default_params
state = strategy.initialize(rng, params)

# Run the ask-eval-tell loop
num_generations = 50  # number of ask-eval-tell iterations
for t in range(num_generations):
    rng, rng_gen = jax.random.split(rng)
    x, state = strategy.ask(rng_gen, state, params)
    fitness = batch_rosenbrock(x, 1, 100)
    state = strategy.tell(x, fitness, state, params)

# Get best overall population member & its fitness
state["best_member"], state["best_fitness"]
```
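The loop above relies on the fitness function accepting a whole population at once. If you start from a scalar objective of your own, `jax.vmap` turns it into a batched one, as the intro describes. A minimal sketch — the `rosenbrock` helper here is our own illustration, not part of `evosax`:

```python
import jax
import jax.numpy as jnp

# Scalar 2D Rosenbrock objective (illustrative helper, not from evosax)
def rosenbrock(x, a=1.0, b=100.0):
    return (a - x[0]) ** 2 + b * (x[1] - x[0] ** 2) ** 2

# Vectorize over the leading (population) axis of x only;
# a and b are broadcast unchanged to every candidate
batch_fitness = jax.vmap(rosenbrock, in_axes=(0, None, None))

population = jnp.zeros((20, 2))  # e.g. 20 candidates in 2 dimensions
fitness = batch_fitness(population, 1.0, 100.0)
print(fitness.shape)  # (20,)
```

The same pattern extends to `jax.pmap` when candidates should be evaluated across multiple devices.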
Implemented Evolution Strategies 🦎
Strategy | Reference | Import | Example
---|---|---|---
CMA-ES | Hansen (2016) | `CMA_ES` | Pendulum RL task
Differential ES | Storn & Price (1997) | `Differential_ES` | -
OpenAI-ES | Salimans et al. (2017) | `Open_NES` | Simple Quadratic
Particle Swarm Optimization | Kennedy & Eberhart (1995) | `PSO_ES` | -
PEPG | Sehnke et al. (2009) | `PEPG_ES` | -
Persistent ES | Vicol et al. (2021) | `Persistent_ES` | -
Population-Based Training | Jaderberg et al. (2017) | `PBT_ES` | -
Simple Gaussian | ❓ | `Simple_ES` | Low Dim. optimisation
Simple Genetic | Such et al. (2017) | `Simple_GA` | Low Dim. optimisation
x-NES | Wierstra et al. (2014) | `xNES` | -
To Be Completed
Strategy | Reference | Import | Example
---|---|---|---
IPOP/BIPOP/SEP | - | 🚉 | -
NSLC | Lehman & Stanley (2011) | 🚉 | -
MAP-Elites | Mouret & Clune (2015) | 🚉 | -
CMA-ME | Fontaine et al. (2020) | 🚉 | -
Installation ⏳
`evosax` can directly be installed from PyPI:

```
pip install evosax
```
In order to use JAX on your accelerators (GPU/TPU), please refer to the JAX documentation for installation details.
Examples 📖
- 📖 Blog post: Walkthrough of CMA-ES and how to leverage JAX's primitives
- 📓 Low-dim. Optimisation: Simple Gaussian strategy on the Rosenbrock function
- 📓 MLP-Pendulum-Control: CMA-ES on the `Pendulum-v0` gym task
- 📓 CNN-MNIST-Classifier: OpenAI-ES on a MNIST CNN
- 📓 RNN-Meta-Bandit: CMA-ES on an LSTM evolved to learn on bandit tasks
Contributing & Development 🧑🤝🧑
Feel free to ping me (@RobertTLange), open an issue or start contributing yourself.