
JaxMARL

Installation | Quick Start | Environments | Algorithms | Citation

[Animated previews: Overcooked, MPE, STORM, SMAX]

Multi-Agent Reinforcement Learning in JAX

JaxMARL combines ease of use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is to provide a single library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine.

For more details, take a look at our blog post, or see this notebook, which walks through basic usage. LINKS TODO

Environments 🌍

Environment | Reference | README | Summary
🔴 MPE | Paper | Source | Communication-oriented tasks in a multi-agent particle world
🍲 Overcooked | Paper | Source | Fully cooperative human-AI coordination tasks based on the video game of the same name
🦾 Multi-Agent Brax | Paper | Source | Continuous multi-agent robotic control based on Brax, analogous to Multi-Agent MuJoCo
🎆 Hanabi | Paper | Source | Fully cooperative, partially observable multiplayer card game
👾 SMAX | Novel | Source | Simplified cooperative StarCraft micro-management environment
🧮 STORM: Spatial-Temporal Representations of Matrix Games | Paper | Source | Matrix games represented as grid-world scenarios
🪙 Coin Game | Paper | Source | Two-player grid-world environment that emulates social dilemmas
💡 Switch Riddle | Paper | Source | Simple cooperative communication game included for debugging
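
Each of these environments is created by name through the make function shown in the Quick Start below. A minimal sketch (the identifier is the MPE one used later on this page; the registered names for the other environments are listed in their respective READMEs):

from jaxmarl import make

env = make('MPE_simple_world_comm_v3')  # identifier taken from the Quick Start below
print(env.num_agents, env.agents)       # agent count and per-agent names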

Baseline Algorithms 🦉

We follow CleanRL's philosophy of providing single-file implementations, which can be found within the baselines directory; a sketch of launching one follows the table below.

Algorithm | Reference | README
IPPO | Paper | Source
MAPPO | Paper | Source
IQL | Paper | Source
VDN | Paper | Source
QMIX | Paper | Source
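
Each baseline is run as a standalone script. As a sketch (the script path below is hypothetical; check the baselines directory for the actual filenames and their configuration options):

python baselines/IPPO/ippo_ff_mpe.py  # hypothetical path; see baselines/ for the real scripts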

Installation 🧗

Before installing, ensure you have the correct JAX version for your hardware accelerator; a hedged example is given below. JaxMARL can then be installed directly from PyPI:

pip install jaxmarl

Note: this does not work yet; for now, install from source with:

pip install -e .
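
If you need an accelerator-specific JAX build, the exact command depends on your setup. As a minimal sketch, assuming an NVIDIA GPU with CUDA 12 and current JAX wheels (consult the JAX installation guide for your hardware):

pip install -U "jax[cuda12]"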

We have tested JaxMARL on Python 3.8 and 3.9. Running our test scripts requires some additional dependencies (for comparisons against existing implementations); these can be installed with:

pip install jaxmarl[dev]

Quick Start 🚀

We take inspiration from the PettingZoo and Gymnax interfaces. You can try out training an agent on XX in this Colab TODO. Further introductory scripts can be found here.

Basic JaxMARL API Usage 🖥️

Actions, observations, rewards, and done values are passed as dictionaries keyed by agent name, allowing for differing action and observation spaces. The done dictionary contains an additional "__all__" key, specifying whether the episode has ended. We follow a parallel structure, with each agent passing an action at each timestep. For asynchronous games, such as Hanabi, a dummy action is passed for agents not acting at a given timestep. The snippet below demonstrates basic usage, followed by a short rollout sketch.

import jax
from jaxmarl import make

key = jax.random.PRNGKey(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Initialise environment.
env = make('MPE_simple_world_comm_v3')

# Reset the environment.
obs, state = env.reset(key_reset)

# Sample random actions.
key_act = jax.random.split(key_act, env.num_agents)
actions = {agent: env.action_space(agent).sample(key_act[i]) for i, agent in enumerate(env.agents)}

# Perform the step transition.
obs, state, reward, done, infos = env.step(key_step, state, actions)
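
Building on the snippet above, here is a minimal rollout sketch showing how the "__all__" key signals episode termination (the loop structure is illustrative, not a canonical JaxMARL training loop):

# Continue stepping the environment, resetting once all agents are done.
for _ in range(100):
    key, key_act, key_step = jax.random.split(key, 3)
    act_keys = jax.random.split(key_act, env.num_agents)
    actions = {agent: env.action_space(agent).sample(act_keys[i]) for i, agent in enumerate(env.agents)}
    obs, state, reward, done, infos = env.step(key_step, state, actions)
    if done["__all__"]:  # the extra "__all__" key marks the end of the episode
        key, key_reset = jax.random.split(key)
        obs, state = env.reset(key_reset)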

Contributing 🔨

Please contribute! Take a look at our contributing guide for how to add an environment or algorithm, or to submit a bug report.

Citing JaxMARL 📜

If you use JaxMARL in your work, please cite us as follows:
TODO

See Also 🙌

There are a number of other libraries that inspired this work; we encourage you to take a look!

JAX-native algorithms:

  • Mava: JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.

JAX-native environments:

  • Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
  • Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • Brax: A fully differentiable physics engine written in JAX, features continuous control tasks.
