Nonlinear optimisation in JAX and Equinox.
Optimistix
Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares.
Features include:
- interoperable solvers: e.g. autoconvert root find problems to least squares problems, then solve using a minimisation algorithm.
- modular optimisers: e.g. use a BFGS quadratic bowl with a dogleg descent path with a trust region update.
- using a PyTree as the state.
- fast compilation and runtimes.
- interoperability with Optax.
- all the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.
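To make the "interoperable solvers" idea concrete, here is a minimal sketch in plain Python (standard library only, deliberately not using the Optimistix API). It recasts a root-finding problem as a least-squares problem and solves it with a simple minimiser. The example problem `cos(y) - y = 0`, the helper names, and the fixed-step gradient descent are all assumptions chosen for illustration; they are not how Optimistix implements the conversion.

```python
import math


def residual(y):
    # Root-finding problem (illustrative): find y such that cos(y) - y = 0.
    return math.cos(y) - y


def loss(y):
    # Least-squares reformulation: any root of `residual` minimises this loss.
    return 0.5 * residual(y) ** 2


def loss_grad(y):
    # d/dy [0.5 * r(y)^2] = r(y) * r'(y), where r'(y) = -sin(y) - 1.
    return residual(y) * (-math.sin(y) - 1.0)


def gradient_descent(grad, y, step=0.2, tol=1e-10, max_steps=10_000):
    # Plain fixed-step gradient descent; stops once the gradient is tiny.
    for _ in range(max_steps):
        g = grad(y)
        if abs(g) < tol:
            break
        y = y - step * g
    return y


# Solve the root-finding problem by minimising the least-squares loss.
root = gradient_descent(loss_grad, 1.0)
print(root, residual(root))
```

The root-finding problem never sees a dedicated root finder here: it is handed to a minimiser through the squared residual, which is the kind of reuse the feature list describes.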
Installation
pip install optimistix
Requires Python 3.9+, JAX 0.4.14+, and Equinox 0.11.0+.
Documentation
Available at https://docs.kidger.site/optimistix.
Quick example
import jax.numpy as jnp
import optimistix as optx
# Let's solve the ODE dy/dt=tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1)dt.
y0 = jnp.array(1.)
dt = jnp.array(0.1)
def fn(y, args):
    # `args` is not needed for this problem.
    return y0 + jnp.tanh(y) * dt
solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value # satisfies y1 == fn(y1)
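For intuition about what this solve computes, the following is a hand-rolled sketch in plain Python (standard library only). It treats the fixed-point problem `y = fn(y)` as the root-finding problem `g(y) = y - fn(y) = 0` and applies Newton's method with a hand-derived derivative. The function `newton_fixed_point`, its tolerance, and its stopping rule are illustrative assumptions, not Optimistix internals.

```python
import math

y0 = 1.0
dt = 0.1


def fn(y):
    # Same implicit Euler map as in the example: y1 = y0 + tanh(y1) * dt.
    return y0 + math.tanh(y) * dt


def newton_fixed_point(y, tol=1e-12, max_steps=50):
    # Newton's method on g(y) = y - fn(y), whose roots are fixed points of fn.
    for _ in range(max_steps):
        g = y - fn(y)
        # g'(y) = 1 - dt * sech(y)^2 = 1 - dt * (1 - tanh(y)^2)
        dg = 1.0 - dt * (1.0 - math.tanh(y) ** 2)
        step = g / dg
        y = y - step
        if abs(step) < tol:
            break
    return y


y1 = newton_fixed_point(y0)
print(y1, fn(y1))  # the two values should agree closely
```

In the Optimistix version, JAX's autodiff supplies the derivative, so it does not have to be written out by hand as it is here.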
Finally
JAX ecosystem
jaxtyping: type annotations for shape/dtype of arrays.
Equinox: neural networks.
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Diffrax: numerical differential equation solvers.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
Orbax: checkpointing (async/multi-host/multi-device).
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
Eqxvision: computer vision models.
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
PySR: symbolic regression. (Non-JAX honourable mention!)
Disclaimer
This is not an official Google product.
Credit
Optimistix was primarily built by Jason Rader (@packquickly): Twitter; GitHub; Website.
Hashes for optimistix-0.0.5-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | 8d8378898ba9039f0e47fc8b789029b249b3bc3d1ed4b23929661d8d940ff08c |
MD5 | 371b1f6cd558b5c567104221c3b7c6ba |
BLAKE2b-256 | b53b7126190367af452581d26121fb30c6faa17c9dda4bf889f378b2b362d7c2 |