# 🥷 Ninjax

Simple state handling for JAX
Ninjax brings the flexibility of PyTorch and TensorFlow 2 to JAX. Ninjax is a simple state manager for JAX that makes it easy to have nested components that update their own state. It's intended to be used together with a neural network library, such as Flax or Haiku.
## Installation

Ninjax is a single file, so you can just copy it into your project directory. Or you can install the package:

```sh
pip install ninjax
```
## Quickstart
```python
import haiku as hk
import jax
import jax.numpy as jnp
import ninjax as nj


class Model(nj.Module):

  def __init__(self, size, act=jax.nn.relu):
    self.size = size
    self.act = act
    self.h1 = nj.HaikuModule(hk.Linear, 128)
    self.h2 = nj.HaikuModule(hk.Linear, 128)
    self.h3 = nj.HaikuModule(hk.Linear, size)

  def __call__(self, x):
    x = self.act(self.h1(x))
    x = self.act(self.h2(x))
    x = self.h3(x)
    return x

  def train(self, x, y):
    self(x)  # Create weights needed for gradient.
    params = self.state()
    loss, grads = nj.grad(self.loss, params.keys())(x, y)
    params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    self.update(params)
    return loss

  def loss(self, x, y):
    return ((self(x) - y) ** 2).mean()


model = Model(8)
main = jax.random.PRNGKey(0)
state = {}
for x, y in dataset:
  rng, main = jax.random.split(main)
  state, loss = nj.run(model.train, state, rng, x, y)
  print('Loss:', float(loss))
```
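The `dataset` variable in the training loop above is left undefined. A minimal stand-in that yields random `(x, y)` batches might look like the following sketch; the batch size and input dimension are assumptions, with the target dimension chosen to match the model's output size of 8. JAX functions accept NumPy arrays directly:

```python
import numpy as np

def make_dataset(num_batches=100, batch_size=32, in_dim=16, out_dim=8, seed=0):
  # Stand-in for a real data loader: yields random (x, y) batches.
  rng = np.random.default_rng(seed)
  for _ in range(num_batches):
    x = rng.normal(size=(batch_size, in_dim)).astype(np.float32)
    y = rng.normal(size=(batch_size, out_dim)).astype(np.float32)
    yield x, y

dataset = make_dataset()
```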
## How To

### How can I use JIT compilation?
The `nj.run()` function makes the state your JAX code uses explicit, so it can be jitted and transformed freely:

```python
import functools

model = Model()
train = jax.jit(functools.partial(nj.run, model.train))
train(state, rng, ...)
```
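This works because `nj.run` turns the computation into a pure function of the state, and pure functions are exactly what `jax.jit` can transform. The pattern can be seen in plain JAX, without ninjax; the state layout and the update rule below are made up for illustration:

```python
import jax
import jax.numpy as jnp

def step(state, x):
  # Pure function: takes state in, returns updated state and an output.
  w = state['w']
  y = x @ w
  return {'w': w * 0.9}, y  # Illustrative decay update.

fast_step = jax.jit(step)
state = {'w': jnp.ones((3, 2))}
state, y = fast_step(state, jnp.ones((1, 3)))
```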
### How can I compute gradients?

You can use `jax.grad` as normal for computing gradients with respect to explicit inputs of your function. To compute gradients with respect to Ninjax state, use `nj.grad(fn, keys)`:
```python
class Module(nj.Module):

  def train(self, x, y):
    params = self.state()
    loss, grads = nj.grad(self.loss, params.keys())(x, y)
    params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    self.update(params)
```
The `self.state(filter)` method optionally accepts a regex pattern to select only a subset of the state dictionary. It also returns only state entries of the current module. To access the global state, use `nj.state()`.
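The regex selection can be illustrated on a plain dictionary. This is only a sketch of the filtering semantics, not ninjax's implementation, and whether the pattern is anchored or searched is an assumption:

```python
import re

def filter_state(state, pattern):
  # Keep only entries whose key matches the regex pattern.
  regex = re.compile(pattern)
  return {k: v for k, v in state.items() if regex.search(k)}

state = {'h1/w': 1.0, 'h1/b': 0.0, 'h2/w': 2.0, 'opt/step': 3}
print(filter_state(state, r'h\d/(w|b)'))  # Selects only the layer parameters.
```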
### How can I define modules compactly?

You can use `self.get(name, ctor, *args, **kwargs)` inside methods of your modules. When called for the first time, it creates a new state entry from the constructor `ctor(*args, **kwargs)`. Later calls return the existing entry:
```python
class Module(nj.Module):

  def __call__(self, x):
    x = jax.nn.relu(self.get('h1', Linear, 128)(x))
    x = jax.nn.relu(self.get('h2', Linear, 128)(x))
    x = self.get('h3', Linear, 32)(x)
    return x
```
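The create-on-first-use behavior behind this pattern can be sketched with a plain dictionary cache. The class below is a hypothetical illustration, not ninjax's internals:

```python
class CompactModule:
  # Hypothetical sketch of create-on-first-use; not ninjax's implementation.
  def __init__(self):
    self._entries = {}

  def get(self, name, ctor, *args, **kwargs):
    # First call creates the entry via ctor(*args, **kwargs);
    # later calls return the cached entry and ignore the arguments.
    if name not in self._entries:
      self._entries[name] = ctor(*args, **kwargs)
    return self._entries[name]

m = CompactModule()
first = m.get('h1', dict, size=128)
second = m.get('h1', dict, size=256)  # Returns the existing entry.
assert first is second and first == {'size': 128}
```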
### How can I use Haiku modules?

Haiku requires its modules to be passed through `hk.transform` and then initialized via `transformed.init(rng, batch)`. Ninjax provides `nj.HaikuModule` to do this for you:
```python
class Module(nj.Module):

  def __init__(self):
    self.mlp = nj.HaikuModule(hk.nets.MLP, [128, 128, 32])

  def __call__(self, x):
    return self.mlp(x)
```
You can also predefine aliases for Haiku modules that you want to use frequently:

```python
Linear = functools.partial(nj.HaikuModule, hk.Linear)
Conv2D = functools.partial(nj.HaikuModule, hk.Conv2D)
MLP = functools.partial(nj.HaikuModule, hk.nets.MLP)
# ...
```