
Signax: Signature computation in JAX


Signax: Computing signatures in JAX


Goal

The goal is a library that supports signature computation in JAX. See this paper for how signatures can be used in machine learning.

This implementation is inspired by patrick-kidger/signatory.

Examples

Basic usage

import jax
import jax.random as jrandom

from signax.signature import signature

key = jrandom.PRNGKey(0)
depth = 3

# compute signature for a single path
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signature(path, depth)
# output is a list of arrays representing the truncated tensor algebra

# compute signatures for a batch of paths
# this is done via `jax.vmap`
batch_size = 20
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
output = jax.vmap(lambda x: signature(x, depth))(path)

Integrate with equinox library

import equinox as eqx
import jax.random as jrandom

from signax.module import SignatureTransform

# random generator key
key = jrandom.PRNGKey(0)
mlp_key, data_key = jrandom.split(key)

depth = 3
length, dim = 100, 3

# signature transform layer
signature_layer = SignatureTransform(depth=depth)
# finally, a small MLP maps the flattened signature to the output;
# its input size is the number of signature terms, dim + dim**2 + dim**3
last_layer = eqx.nn.MLP(depth=1,
                        in_size=dim + dim**2 + dim**3,
                        width_size=4,
                        out_size=1,
                        key=mlp_key)

model = eqx.nn.Sequential(layers=[signature_layer, last_layer])
x = jrandom.normal(shape=(length, dim), key=data_key)
output = model(x)
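The MLP's `in_size` has to equal the number of entries in levels 1 to `depth` of the signature, that is dim + dim**2 + ... + dim**depth. This assumes, as the example's `in_size` implies, that `SignatureTransform` returns a flattened vector without the constant level-0 term. A small helper avoids hard-coding the value; `signature_size` is a hypothetical name, not part of Signax:

```python
def signature_size(dim: int, depth: int) -> int:
    """Number of entries in levels 1..depth of a truncated signature.

    Assumes the constant level-0 term is not included in the output.
    """
    return sum(dim**k for k in range(1, depth + 1))


# For the example above (dim = 3, depth = 3): 3 + 9 + 27
print(signature_size(3, 3))  # 39
```

The size grows as dim**depth, which is why the example keeps `dim` and `depth` small.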

Also, check the notebooks in the examples folder for some experiments from the deep signature transforms paper.

Installation

git clone https://github.com/anh-tong/signax.git
cd signax
pip install .

Parallelism

This implementation uses `jax.vmap` to parallelize over the batch dimension.

Signatory, in addition, can divide a path into chunks and process the chunks with asynchronous multithreading.
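Chunking works because of Chen's identity: the signature of two concatenated paths is the truncated tensor product of their signatures, so chunks can be processed independently and combined afterwards. Below is a minimal pure-JAX sketch of that idea, truncated at depth 2 so the product stays short. It is neither Signatory's nor Signax's code: it runs the chunks in parallel with `jax.vmap` instead of threads, assumes the number of segments divides evenly into chunks, and the helper names (`depth2_signature`, `chen_combine`, `chunked_signature`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def depth2_signature(path):
    """Levels 1 and 2 of the signature of a piecewise-linear path (length, dim)."""
    dx = jnp.diff(path, axis=0)              # segment increments, shape (length - 1, dim)
    start = jnp.cumsum(dx, axis=0) - dx      # X_{j-1} - X_0 at the start of each segment
    level1 = dx.sum(axis=0)
    # sum_j (X_{j-1} - X_0) ⊗ dx_j  +  1/2 sum_j dx_j ⊗ dx_j
    level2 = start.T @ dx + 0.5 * (dx.T @ dx)
    return level1, level2


def chen_combine(sig_a, sig_b):
    """Chen's identity truncated at depth 2: signature of path a followed by path b."""
    a1, a2 = sig_a
    b1, b2 = sig_b
    return a1 + b1, a2 + b2 + jnp.outer(a1, b1)


def chunked_signature(path, num_chunks):
    """Per-chunk signatures in parallel via vmap, then folded with Chen's identity."""
    length, _ = path.shape
    assert (length - 1) % num_chunks == 0, "sketch assumes equal-sized chunks"
    seg_per_chunk = (length - 1) // num_chunks
    # Chunk i spans points i*seg_per_chunk .. (i+1)*seg_per_chunk, so neighbours share an endpoint.
    idx = (jnp.arange(num_chunks)[:, None] * seg_per_chunk
           + jnp.arange(seg_per_chunk + 1)[None, :])
    chunks = path[idx]                       # shape (num_chunks, seg_per_chunk + 1, dim)
    level1s, level2s = jax.vmap(depth2_signature)(chunks)
    sig = (level1s[0], level2s[0])
    for i in range(1, num_chunks):
        sig = chen_combine(sig, (level1s[i], level2s[i]))
    return sig


key = jrandom.PRNGKey(0)
path = jrandom.normal(key, shape=(101, 3))
full = depth2_signature(path)
chunked = chunked_signature(path, num_chunks=4)
print(jnp.allclose(full[1], chunked[1], rtol=1e-4, atol=1e-3))
```

The two results agree up to floating-point rounding, since the identity is exact; the sequential fold over chunks is the only non-parallel step.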

Why is using pure JAX good enough?

Because JAX uses just-in-time (JIT) compilation through XLA, this implementation can be reasonably fast.
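To illustrate how JIT compilation applies to signatures, here is a self-contained sketch (not Signax's actual code) of a truncated signature written as a `jax.lax.scan` over the path increments. Each linear segment contributes dx^{⊗k}/k! at level k, and segments are merged with Chen's identity. `depth` is marked static because it fixes the number and shapes of the tensor levels, so XLA compiles one program per depth. The names `scan_signature` and `_outer` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jrandom


def _outer(x, y):
    """Tensor (outer) product of two arrays, with shape x.shape + y.shape."""
    return x.reshape(x.shape + (1,) * y.ndim) * y


@partial(jax.jit, static_argnames="depth")
def scan_signature(path, depth):
    """Truncated signature of a piecewise-linear path of shape (length, dim).

    Returns a list whose k-th entry (k = 1..depth) has shape (dim,) * k.
    """
    dim = path.shape[-1]
    increments = jnp.diff(path, axis=0)

    def step(sig, dx):
        # Signature of one linear segment: level k is dx^{⊗k} / k!.
        seg = [dx]
        for k in range(2, depth + 1):
            seg.append(_outer(seg[-1], dx) / k)
        # Chen's identity: level k of sig ⊗ seg sums over i + j = k
        # (the implicit level-0 terms equal 1).
        new = []
        for k in range(1, depth + 1):
            term = sig[k - 1] + seg[k - 1]
            for i in range(1, k):
                term = term + _outer(sig[i - 1], seg[k - i - 1])
            new.append(term)
        return new, None

    # The signature of a constant path is 1 at level 0 and 0 at every other level.
    init = [jnp.zeros((dim,) * k, dtype=path.dtype) for k in range(1, depth + 1)]
    sig, _ = jax.lax.scan(step, init, increments)
    return sig


key = jrandom.PRNGKey(0)
paths = jrandom.normal(key, shape=(8, 50, 4))  # (batch, length, dim)
# Compile once for depth=3; vmap adds the batch dimension as in the basic example.
batched = jax.jit(jax.vmap(lambda p: scan_signature(p, depth=3)))
sig = batched(paths)
print([level.shape for level in sig])  # [(8, 4), (8, 4, 4), (8, 4, 4, 4)]
```

The first call to `batched` traces and compiles the program; later calls with the same input shapes reuse the compiled XLA executable.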

We observe that this implementation performs similarly to Signatory on CPU and slightly better on GPU. This may be due to the optimized XLA operators that JAX uses. As mentioned in the paper, Signatory is not fully optimized for CUDA and instead relies on LibTorch.

Acknowledgement

This repo is based on



Download files

Source distribution: signax-0.1.0.tar.gz (10.5 kB)

Built distribution: signax-0.1.0-py3-none-any.whl (9.2 kB, Python 3)
