Signax: Signature computation in JAX
Signax: Computing signatures in JAX
Goal
To provide a library that supports signature computation in JAX. See this paper for how signatures can be used in machine learning.
This implementation is inspired by patrick-kidger/signatory.
Examples
Basic usage
import jax
import jax.random as jrandom
from signax.signature import signature
key = jrandom.PRNGKey(0)
depth = 3
# compute signature for a single path
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signature(path, depth)
# output is a list of arrays representing the (truncated) tensor algebra
# compute signature for batches (multiple) of paths
# this is done via `jax.vmap`
batch_size = 20
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
output = jax.vmap(lambda x: signature(x, depth))(path)
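For readers new to signatures, here is a minimal pure-JAX sketch of what a depth-truncated signature computes. It assumes the path is interpolated piecewise-linearly between its sample points. A straight segment with increment Δ has signature exp(Δ) = (Δ, Δ⊗Δ/2!, Δ⊗Δ⊗Δ/3!, ...), and consecutive segments are combined with Chen's identity inside jax.lax.scan. The names signature_sketch, tensor_exp, chen and outer are hypothetical helpers for illustration only. They are not part of signax, whose own implementation may differ.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jrandom


def outer(x, y):
    """Tensor product of a level-i tensor x and a level-j tensor y (gives a level-(i+j) tensor)."""
    return jnp.reshape(x, x.shape + (1,) * y.ndim) * y


def tensor_exp(delta, depth):
    """Signature of one straight segment: levels delta^{(x)k} / k! for k = 1..depth."""
    levels = [delta]
    for k in range(2, depth + 1):
        levels.append(outer(levels[-1], delta) / k)
    return levels


def chen(a, b):
    """Chen's identity: signature of path a followed by path b (level 0 = 1 is implicit)."""
    out = []
    for k in range(1, len(a) + 1):
        term = a[k - 1] + b[k - 1]
        for i in range(1, k):
            term = term + outer(a[i - 1], b[k - i - 1])
        out.append(term)
    return out


@partial(jax.jit, static_argnums=1)  # depth fixes the output shapes, so it must be static
def signature_sketch(path, depth):
    """Depth-truncated signature of a piecewise-linear path of shape (length, dim)."""
    dim = path.shape[-1]
    # all-zero higher levels = signature of a constant path (the identity element)
    init = [jnp.zeros((dim,) * k, dtype=path.dtype) for k in range(1, depth + 1)]

    def step(sig, delta):
        # append one straight segment to the running signature
        return chen(sig, tensor_exp(delta, depth)), None

    sig, _ = jax.lax.scan(step, init, jnp.diff(path, axis=0))
    return sig


key = jrandom.PRNGKey(0)
path = jrandom.normal(key, (100, 3))
sig = signature_sketch(path, 3)
print([s.shape for s in sig])  # [(3,), (3, 3), (3, 3, 3)]

paths = jrandom.normal(key, (20, 100, 3))
batched = jax.vmap(lambda p: signature_sketch(p, 3))(paths)
print([s.shape for s in batched])  # [(20, 3), (20, 3, 3), (20, 3, 3, 3)]
```

Marking depth as static lets jax.jit compile code specialised to a fixed number of tensor levels. Batching uses jax.vmap, as in the usage example above.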
Integration with the equinox library
import equinox as eqx
import jax.random as jrandom
from signax.module import SignatureTransform
# random generator key
key = jrandom.PRNGKey(0)
mlp_key, data_key = jrandom.split(key)
depth = 3
length, dim = 100, 3
# signature transform layer
signature_layer = SignatureTransform(depth=depth)
# finally, a small MLP maps the signature features to the output;
# in_size is the flattened signature size dim + dim**2 + ... + dim**depth
last_layer = eqx.nn.MLP(depth=1,
in_size=sum(dim**k for k in range(1, depth + 1)),
width_size=4,
out_size=1,
key=mlp_key)
model = eqx.nn.Sequential(layers=[signature_layer, last_layer])
x = jrandom.normal(shape=(length, dim), key=data_key)
output = model(x)
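The MLP's in_size in the equinox example is the length of the flattened signature. A depth-N signature of a dim-dimensional path has dim + dim**2 + ... + dim**N entries, which is 39 for dim = 3 and depth = 3. Below is a small sketch of that arithmetic. It assumes the levels are flattened and concatenated in order. signature_size and flatten_signature are hypothetical helpers, not signax functions.

```python
import jax.numpy as jnp


def signature_size(dim: int, depth: int) -> int:
    """Number of entries in levels 1..depth of the signature of a dim-dimensional path."""
    return sum(dim**k for k in range(1, depth + 1))


def flatten_signature(levels):
    """Concatenate levels of shapes (dim,), (dim, dim), ... into one feature vector."""
    return jnp.concatenate([jnp.ravel(level) for level in levels])


dim, depth = 3, 3
levels = [jnp.zeros((dim,) * k) for k in range(1, depth + 1)]
print(signature_size(dim, depth), flatten_signature(levels).shape)  # 39 (39,)
```

Computing the size from dim and depth, rather than hardcoding 3 + 3**2 + 3**3, keeps the MLP input consistent with the data when either value changes.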
Also, see the notebooks in the examples folder for experiments from the Deep Signature Transforms paper.
Installation
git clone https://github.com/anh-tong/signax.git
cd signax
pip install .
Parallelism
This implementation uses jax.vmap to parallelize computation over the batch dimension.
Signatory, in addition, can divide a single path into chunks and process the chunks asynchronously with multiple threads.
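Chunking works because of Chen's identity. If a path is split at an intermediate point, the signature of the whole path is the tensor product of the signatures of the two pieces. Chunk signatures can therefore be computed independently and combined afterwards. The sketch below is truncated at depth 2 for brevity. It computes chunk signatures in parallel with jax.vmap and then folds them together. sig2, chen2 and chunked_sig2 are hypothetical names. This illustrates the idea and is not how Signatory or signax actually implement it.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def sig2(path):
    """Depth-2 signature (levels 1 and 2) of a piecewise-linear path of shape (length, dim)."""
    dx = jnp.diff(path, axis=0)                # segment increments, (length - 1, dim)
    before = jnp.cumsum(dx, axis=0) - dx       # displacement from the start before each segment
    level1 = dx.sum(axis=0)
    # level 2 = sum_j (before_j (x) dx_j + dx_j (x) dx_j / 2)
    level2 = before.T @ dx + 0.5 * (dx.T @ dx)
    return level1, level2


def chen2(a, b):
    """Chen's identity truncated at depth 2: signature of path a followed by path b."""
    a1, a2 = a
    b1, b2 = b
    return a1 + b1, a2 + b2 + jnp.outer(a1, b1)


def chunked_sig2(path, num_chunks):
    """Split the path into chunks, compute each chunk's signature in parallel, then combine."""
    n_seg = path.shape[0] - 1
    if n_seg % num_chunks != 0:
        raise ValueError("number of segments must be divisible by num_chunks")
    seg = n_seg // num_chunks
    # row i holds the point indices of chunk i; adjacent chunks share one endpoint
    idx = jnp.arange(num_chunks)[:, None] * seg + jnp.arange(seg + 1)
    l1, l2 = jax.vmap(sig2)(path[idx])         # (num_chunks, dim), (num_chunks, dim, dim)
    out = (l1[0], l2[0])
    for i in range(1, num_chunks):             # fold left to right; the product is not commutative
        out = chen2(out, (l1[i], l2[i]))
    return out


path = jrandom.normal(jrandom.PRNGKey(0), (101, 4))   # 100 segments
full = sig2(path)
chunked = chunked_sig2(path, num_chunks=4)
print(all(jnp.allclose(f, c, rtol=1e-4, atol=1e-3) for f, c in zip(full, chunked)))  # True
```

Chen's product is associative but not commutative. The sequential fold could therefore also be arranged as a tree reduction, as long as the chunk order is preserved. The Python loop is used here only for readability.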
Why is using pure JAX good enough?
Because JAX uses just-in-time (JIT) compilation with XLA, this implementation can be reasonably fast.
We observe that this implementation performs similarly to Signatory on CPU and slightly better on GPU. This may be due to XLA's optimized operators. As mentioned in the Signatory paper, Signatory is not fully optimized for CUDA and instead relies on LibTorch.
Acknowledgement
This repo is based on