
JAX-Based Evolution Strategies


evosax: JAX-Based Evolution Strategies 🦎


Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? evosax lets you use JAX, XLA compilation and automatic vectorization/parallelization to scale ES on your favorite accelerators. The API is based on the classical ask, evaluate, tell cycle of ES, and both ask and tell calls are compatible with jit, vmap/pmap and lax.scan. It includes a broad set of classic (e.g. CMA-ES, Differential Evolution) and modern neuroevolution (e.g. OpenAI-ES, Augmented RS) strategies. You can get started here 👉 Colab

Basic evosax API Usage 🍲

import jax
from evosax import CMA_ES

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
params = strategy.default_params
state = strategy.initialize(rng, params)

# Run ask-eval-tell loop - NOTE: By default minimization!
num_generations = 50
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, params)
    fitness = ...  # Your population evaluation function: one fitness value per row of x
    state = strategy.tell(x, fitness, state, params)

# Get best overall population member & its fitness
state["best_member"], state["best_fitness"]
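The loop above relies on evosax's CMA_ES. The following minimal pure-JAX sketch runs the same ask-evaluate-tell pattern end to end without any library internals. Several pieces are hypothetical stand-ins chosen for illustration and are not evosax code:

- the toy strategy, which samples around a mean and moves the mean to the average of the best half;
- the sphere objective;
- the names `initialize`, `ask`, `tell`, `sigma`, `lrate` and `init_mean`.

```python
import jax
import jax.numpy as jnp

POPSIZE, NUM_DIMS = 20, 2


def sphere(x):
    """Toy objective to minimize: squared distance from the origin."""
    return jnp.sum(x ** 2)


# Evaluate a whole (POPSIZE, NUM_DIMS) population in one vectorized call.
evaluate = jax.jit(jax.vmap(sphere))


def initialize(params):
    """Hypothetical state: search mean plus best member/fitness seen so far."""
    mean = jnp.full((NUM_DIMS,), params["init_mean"])
    return {"mean": mean, "best_member": mean, "best_fitness": jnp.array(jnp.inf)}


@jax.jit
def ask(rng, state, params):
    """Sample a Gaussian population around the current mean."""
    noise = jax.random.normal(rng, (POPSIZE, NUM_DIMS))
    return state["mean"] + params["sigma"] * noise, state


@jax.jit
def tell(x, fitness, state, params):
    """Move the mean toward the best half and track the best member so far."""
    elite = x[jnp.argsort(fitness)[: POPSIZE // 2]]
    mean = state["mean"] + params["lrate"] * (elite.mean(axis=0) - state["mean"])
    improved = fitness.min() < state["best_fitness"]
    best_member = jnp.where(improved, x[jnp.argmin(fitness)], state["best_member"])
    best_fitness = jnp.minimum(fitness.min(), state["best_fitness"])
    return {"mean": mean, "best_member": best_member, "best_fitness": best_fitness}


params = {"sigma": 0.5, "lrate": 1.0, "init_mean": 3.0}
state = initialize(params)
rng = jax.random.PRNGKey(0)
for t in range(50):
    rng, rng_gen = jax.random.split(rng)
    x, state = ask(rng_gen, state, params)
    fitness = evaluate(x)
    state = tell(x, fitness, state, params)

print(state["best_member"], state["best_fitness"])
```

The state is an explicit value passed into and out of `ask` and `tell`, so both functions stay pure. That purity is what allows them to be wrapped in `jax.jit` here, and in `jax.vmap` or `jax.lax.scan` elsewhere.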

Implemented Evolution Strategies 🦎

| Strategy | Reference | Import | Example |
| --- | --- | --- | --- |
| OpenAI-ES | Salimans et al. (2017) | Open_ES | Colab |
| PGPE | Sehnke et al. (2010) | PGPE_ES | Colab |
| ARS | Mania et al. (2018) | Augmented_RS | Colab |
| CMA-ES | Hansen (2016) | CMA_ES | Colab |
| Simple Gaussian | Rechenberg (1975) | Simple_ES | Colab |
| Simple Genetic | Such et al. (2017) | Simple_GA | Colab |
| x-NES | Wierstra et al. (2014) | xNES | Colab |
| Particle Swarm Optimization | Kennedy & Eberhart (1995) | PSO_ES | Colab |
| Differential ES | Storn & Price (1997) | Differential_ES | Colab |
| Persistent ES | Vicol et al. (2021) | Persistent_ES | Colab |
| Population-Based Training | Jaderberg et al. (2017) | PBT_ES | Colab |

Installation ⏳

The latest evosax release can directly be installed from PyPI:

pip install evosax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/evosax.git@main

To run JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
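A quick way to check which devices JAX will actually run on after installation uses two standard JAX calls. On a CPU-only install it lists CPU devices only.

```python
import jax

# All devices visible to JAX (CPU, GPU or TPU, depending on the install).
print(jax.devices())
# Name of the platform JAX uses by default, e.g. "cpu" on a CPU-only install.
print(jax.default_backend())
```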

Examples 📖

Key Selling Points 💵

  • Strategy Diversity: evosax implements more than 10 classical and modern neuroevolution strategies. All of them follow the same simple ask/tell API and come with tailored tools such as the ClipUp optimizer, parameter reshaping into PyTrees and fitness shaping (see below).

  • Vectorization/Parallelization of ask/tell Calls: Both ask and tell calls can leverage jit, vmap/pmap. This lets you run several strategy configurations (e.g. different hyperparameters) side by side in one vectorized or parallel call.

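import jax
import jax.numpy as jnp
from evosax import Augmented_RS

# E.g. vectorize over different lrate decays
strategy = Augmented_RS(popsize=100, num_dims=20)
# Start from the default hyperparameters (assumed to be a dict) and batch one of them
es_params = dict(strategy.default_params)
es_params["lrate_decay"] = jnp.array([0.999, 0.99, 0.9])
# Same keys as es_params: 0 = map over the leading axis, None = share across the batch
map_dict = {k: None for k in es_params}
map_dict["lrate_decay"] = 0

# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))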
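The `map_dict` trick relies on `jax.vmap` accepting an `in_axes` pytree that mirrors the structure of an argument. A leaf marked `0` is mapped over its leading axis, and a leaf marked `None` is shared by every batch element. The sketch below shows this mechanism in isolation. The `init` function and its parameter names are hypothetical and are not the Augmented_RS initialization.

```python
import jax
import jax.numpy as jnp


def init(rng, params):
    """Hypothetical strategy init whose state depends on the hyperparameters."""
    mean = jax.random.normal(rng, (3,)) * params["sigma_init"]
    return {"mean": mean, "lrate": params["lrate_init"] * params["lrate_decay"]}


es_params = {
    "lrate_init": 0.05,                            # shared across the batch
    "sigma_init": 0.1,                             # shared across the batch
    "lrate_decay": jnp.array([0.999, 0.99, 0.9]),  # one value per configuration
}
# Mirrors es_params: 0 = map over the leading axis, None = broadcast.
map_dict = {"lrate_init": None, "sigma_init": None, "lrate_decay": 0}

batch_init = jax.vmap(init, in_axes=(None, map_dict))
state = batch_init(jax.random.PRNGKey(0), es_params)
print(state["mean"].shape)  # one 3-dim mean per lrate_decay value
print(state["lrate"])
```

Every output leaf gains a leading batch axis of size 3. This is why the evosax example maps the state with `0` in `batch_ask` and `batch_tell`.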
  • Scan Through Evolution Rollouts: You can also lax.scan through the ask, eval, tell loop and jit the whole run, including initialization, for fast compilation of ES loops:
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...  # Your population evaluation function
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], jnp.min(fitness)

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
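The scan pattern above can also be written in a fully self-contained form. In this sketch, a hypothetical greedy random-search step and a sphere objective stand in for evosax's `strategy.ask`/`strategy.tell` and the elided fitness evaluation. Passing `xs=None` with `length=num_steps` to `jax.lax.scan` avoids the dummy `jnp.zeros(num_steps)` input.

```python
from functools import partial

import jax
import jax.numpy as jnp

POPSIZE, NUM_DIMS = 16, 4


def sphere(x):
    """Toy objective to minimize."""
    return jnp.sum(x ** 2)


@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Scan a toy ask-eval-tell loop; returns the best fitness per generation."""
    mean = jnp.ones(NUM_DIMS)  # hypothetical initial search mean

    def es_step(carry, _):
        rng, mean = carry
        rng, rng_iter = jax.random.split(rng)
        # "ask": sample a population around the mean
        x = mean + 0.3 * jax.random.normal(rng_iter, (POPSIZE, NUM_DIMS))
        # "eval": vectorized fitness of the whole population
        fitness = jax.vmap(sphere)(x)
        # "tell": greedily move the mean to the best member of this generation
        mean = x[jnp.argmin(fitness)]
        return (rng, mean), jnp.min(fitness)

    _, best_per_gen = jax.lax.scan(es_step, (rng, mean), None, length=num_steps)
    return best_per_gen


best = run_es_loop(jax.random.PRNGKey(0), 100)
print(best.shape, best.min())
```

`num_steps` is marked static because it fixes the scan length, so each new value triggers a fresh compilation.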
  • Population Parameter Reshaping: We provide a ParameterReshaper wrapper to reshape flat parameter vectors into PyTrees. The wrapper is compatible with JAX neural network libraries such as Flax/Haiku and makes it easier to evaluate populations of networks afterwards.
import jax
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper

class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...

rng = jax.random.PRNGKey(0)
network = MLP(64)
# Flax init takes the PRNG key followed by the arguments of __call__
policy_params = network.init(rng, jnp.zeros(4))

# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(policy_params["params"])

# Get population candidates (flat vectors) & reshape into stacked pytrees
x, state = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
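Conceptually, reshaping a flat population into stacked pytrees can be done with JAX's own `jax.flatten_util.ravel_pytree` plus `jax.vmap`. The following sketch illustrates the idea under that assumption. It is not how `ParameterReshaper` is implemented, and the layer names and sizes are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical two-layer parameter pytree (stand-in for policy_params["params"]).
placeholder = {
    "dense_0": {"kernel": jnp.zeros((4, 64)), "bias": jnp.zeros(64)},
    "dense_1": {"kernel": jnp.zeros((64, 2)), "bias": jnp.zeros(2)},
}
flat, unravel = ravel_pytree(placeholder)
num_dims = flat.shape[0]  # length of the flat vector the ES searches over

# Pretend these rows came from an ES ask call: one flat vector per member.
popsize = 8
x = jax.random.normal(jax.random.PRNGKey(0), (popsize, num_dims))

# vmap the unravel function: each leaf gains a leading popsize axis.
x_shaped = jax.vmap(unravel)(x)
print(num_dims)                             # 4*64 + 64 + 64*2 + 2 = 450
print(x_shaped["dense_0"]["kernel"].shape)  # (8, 4, 64)
```

The stacked pytree can then be fed to `jax.vmap(network.apply)` so that every population member is evaluated in one call.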
  • Flexible Fitness Shaping: By default, evosax assumes that the fitness objective is to be minimized. If you would like to maximize instead, or to apply rank centering, z-scoring or weight regularization, use the FitnessShaper:
from evosax import FitnessShaper

# Instantiate jittable fitness shaper
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           weight_decay=0.01,
                           maximize=True)

# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness) 
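This is a rough sketch of what such a shaper might compute. It is a hypothetical function, not evosax's FitnessShaper, and it makes several assumptions:

- the order of operations is chosen for illustration;
- `weight_decay` is assumed to be an L2 penalty on the candidate parameters, which would explain why `apply` receives `x`;
- the centered-rank transform maps ranks into [-0.5, 0.5], a common choice in OpenAI-ES-style methods.

```python
import jax.numpy as jnp


def centered_rank(fitness):
    """Map fitness to ranks scaled into [-0.5, 0.5] (smallest value -> -0.5)."""
    ranks = jnp.argsort(jnp.argsort(fitness))  # 0 = smallest fitness
    return ranks / (fitness.shape[0] - 1) - 0.5


def z_score(fitness, eps=1e-10):
    """Standardize to zero mean and unit variance."""
    return (fitness - fitness.mean()) / (fitness.std() + eps)


def shape_fitness(x, fitness, centered=True, standardize=True,
                  weight_decay=0.0, maximize=False):
    """Hypothetical shaper; returns values where LOWER is better."""
    if maximize:
        fitness = -fitness  # flip so that minimization matches the ES convention
    # Assumed L2 penalty: mean squared parameter value per population member.
    fitness = fitness + weight_decay * jnp.mean(x ** 2, axis=1)
    if centered:
        fitness = centered_rank(fitness)
    if standardize:
        fitness = z_score(fitness)
    return fitness


x = jnp.zeros((4, 3))
raw = jnp.array([3.0, 10.0, -1.0, 7.0])  # rewards we want to maximize
print(shape_fitness(x, raw, standardize=False, maximize=True))
# -> approx [0.1667, -0.5, 0.5, -0.1667]: the highest reward gets the lowest value
```

Rank-based shaping makes the update invariant to monotone rescalings of the raw fitness, which keeps outliers from dominating the search direction.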

References & Other Great JAX-ES Tools 📝

  • 💻 Evojax: JAX-ES library by Google Brain with great rollout wrappers.
  • 💻 QDax: Quality-Diversity algorithms in JAX.
  • 💻 Rob's Blog: Tutorial on CMA-ES & leveraging JAX's primitives.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.



Download files

Source Distribution: evosax-0.0.3.tar.gz (36.9 kB)

Built Distribution: evosax-0.0.3-py3-none-any.whl (47.5 kB, Python 3)
