
Accelerated gridworld navigation with JAX for deep reinforcement learning


NAVIX

Project Status: WIP – Initial development is in progress, but there has not yet been a stable, usable release suitable for the public.


What is NAVIX?

NAVIX is MiniGrid reimplemented in JAX: more than 10000x faster, with autograd and XLA support. A rough performance comparison is available in the documentation.

Installation

NAVIX currently supports the same operating systems as JAX. See the JAX installation guide for the list.

You may also want to follow the same guide to install JAX for your favourite accelerator (e.g., CPU, GPU, or TPU).

Then, install navix and its dependencies with:

pip install navix
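After installing, a quick sanity check is to ask JAX which backend and devices it can see, since NAVIX runs on whatever JAX picks up. This is a minimal sketch that uses only standard JAX calls and does not touch the NAVIX API.

```python
import jax

# Name of the platform JAX will use by default (for example "cpu" or "gpu").
backend = jax.default_backend()
# All devices visible to JAX on that platform.
devices = jax.devices()

print(backend)
print(devices)
```

If you expected a GPU or TPU but only a CPU device is listed, the installed `jaxlib` most likely lacks accelerator support.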

Examples

XLA compilation

One straightforward use case is to accelerate the environment with XLA compilation. For example, here we use `jax.vmap` to run many environments in parallel and `jax.jit` to compile the full rollout.

You can find a partial performance comparison with minigrid in the docs.

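import jax
import jax.numpy as jnp
import navix as nx

N_TIMESTEPS = 1000  # example value: number of random steps per environment


def run(seed):
  env = nx.environments.Room(16, 16, 8)
  key = jax.random.PRNGKey(seed)
  # use independent keys for the reset and for sampling the actions
  reset_key, action_key = jax.random.split(key)
  timestep = env.reset(reset_key)
  actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, 6)

  def body_fun(timestep, action):
      timestep = env.step(timestep, action)
      return timestep, ()

  # roll out the random actions and keep only the final timestep
  return jax.lax.scan(body_fun, timestep, actions)[0]

# vectorise over 1000 seeds and compile the whole batch of rollouts
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))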
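To make the `vmap`/`jit`/`scan` pattern above easier to inspect without depending on the NAVIX API, here is a minimal self-contained sketch of the same structure on a toy gridworld. Everything environment-specific in it (`toy_step`, the 4-action `MOVES` table, `GRID_SIZE`, `N_TIMESTEPS`) is hypothetical and not part of NAVIX; only the JAX transformations are the same.

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 16     # hypothetical square grid, cells indexed 0..GRID_SIZE-1
N_TIMESTEPS = 100  # hypothetical rollout length

# Hypothetical action table: 0=up, 1=right, 2=down, 3=left, as (row, col) offsets.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def toy_step(position, action):
    """Move the agent one cell and clip it so it stays inside the grid."""
    return jnp.clip(position + MOVES[action], 0, GRID_SIZE - 1)


def toy_run(seed):
    key = jax.random.PRNGKey(seed)
    start_key, action_key = jax.random.split(key)
    # random start cell and a random sequence of actions
    position = jax.random.randint(start_key, (2,), 0, GRID_SIZE)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, 4)

    def body_fun(position, action):
        # scan carries the position forward; no per-step outputs are stored
        return toy_step(position, action), ()

    final_position, _ = jax.lax.scan(body_fun, position, actions)
    return final_position


# one rollout per seed, batched with vmap and compiled once with jit
final_positions = jax.jit(jax.vmap(toy_run))(jnp.arange(1000))
print(final_positions.shape)  # (1000, 2)
```

Because `scan` keeps the whole rollout inside a single traced loop and `vmap` adds the batch dimension, `jit` compiles all 1000 rollouts into one XLA program rather than dispatching each step from Python.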

Backpropagation through the environment

Another use case is to backpropagate through the environment transition function, for example to learn a world model.

TODO(epignatelli): add example.
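Until an official example is added, the general mechanism can be sketched with plain JAX: `jax.grad` differentiates through every transition applied inside a `jax.lax.scan` rollout. The transition below (`toy_transition`, a 2D point moved by a velocity action) and the names `DT`, `rollout_loss`, and `goal` are hypothetical and continuous, unlike NAVIX's gridworlds; this sketch does not claim that any particular NAVIX transition is differentiable with respect to its inputs.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical integration step


def toy_transition(state, action):
    """Hypothetical differentiable transition: a 2D point moved by a velocity action."""
    return state + DT * action


def rollout_loss(actions, initial_state, goal):
    """Roll the transition forward and return the squared distance to the goal."""
    def body_fun(state, action):
        return toy_transition(state, action), ()

    final_state, _ = jax.lax.scan(body_fun, initial_state, actions)
    return jnp.sum((final_state - goal) ** 2)


initial_state = jnp.zeros(2)
goal = jnp.array([1.0, -1.0])
actions = jnp.zeros((10, 2))  # 10 steps of zero velocity

# gradient of the rollout loss with respect to every action in the sequence
grads = jax.grad(rollout_loss)(actions, initial_state, goal)
print(grads.shape)  # (10, 2)
```

The same `jax.grad` call works unchanged if `toy_transition` is replaced by any differentiable function of `(state, action)`, or if the loss compares a learned model's predictions against observed transitions, as a world model would.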

Cite

If you use NAVIX, please consider citing it as:

@misc{pignatelli2023navix,
  author = {Pignatelli, Eduardo},
  title = {Navix: Accelerated gridworld navigation with JAX},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/epignatelli/navix}}
  }

Download files


Source Distribution

Navix-0.2.0.tar.gz (20.3 kB)

Built Distribution

Navix-0.2.0-py2.py3-none-any.whl (24.6 kB, Python 2 and Python 3)
