Jit-able runtime assertions for JAX in NumPy style.
testax provides runtime assertions for JAX through the testing interface familiar to NumPy users.
>>> import jax
>>> from jax import numpy as jnp
>>> import testax
>>>
>>> def safe_log(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(jnp.arange(2))
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)
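For contrast, here is a minimal sketch of why the plain numpy.testing assertions cannot simply be reused inside jitted code. It uses only NumPy and JAX (not testax), and `numpy_safe_log` is a hypothetical name. Called eagerly, the function receives a concrete array, so the NumPy assertion runs and fails as expected. Under `jax.jit`, it receives an abstract tracer, and numpy.testing's conversion of its inputs to NumPy arrays fails before any comparison is made.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_safe_log(x):
    # numpy.testing converts its inputs to NumPy arrays, which requires
    # concrete values.
    np.testing.assert_array_less(0, x)
    return jnp.log(x)


# Eager call: x is a concrete jax.Array, so the NumPy check runs and fails
# because x[0] == 0 is not strictly greater than 0.
try:
    numpy_safe_log(jnp.arange(2))
    eager_error = None
except AssertionError as exc:
    eager_error = exc

# Jitted call: x is an abstract tracer while jit traces the function, so
# converting it to a NumPy array raises before any comparison happens.
try:
    jax.jit(numpy_safe_log)(jnp.arange(2))
    jit_error = None
except jax.errors.TracerArrayConversionError as exc:
    jit_error = exc

print(type(eager_error).__name__)
print(type(jit_error).__name__)
```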
testax assertions are jit-able, although errors need to be functionalized to conform to JAX's requirement that functions are pure and free of side effects (see the checkify guide for details). In short, a checkified function returns a tuple (error, value): the first element is an error that may have occurred, and the second is the return value of the original function.
>>> jitted = jax.jit(safe_log)
>>> checkified = testax.checkify(jitted)
>>> error, y = checkified(jnp.arange(2))
>>> error.throw()
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)
>>> y
Array([-inf, 0.], dtype=float32)
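The same functionalized-error pattern can be sketched with JAX's own `jax.experimental.checkify` module, without testax. This is an illustration of the underlying mechanism, not testax's implementation. The hand-written `checkify.check` call stands in for a testax assertion. `err.get()` returns the error message as a string, or `None` if no check failed, and `err.throw()` raises it. In either case the function's value is still computed and returned.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # checkify.check records an error value instead of raising, so it can be
    # traced under jax.jit.
    checkify.check(jnp.all(x > 0), "expected all elements of x to be positive")
    return jnp.log(x)


# checkify.checkify functionalizes the error: the wrapped function returns
# an (error, value) tuple.
checked = checkify.checkify(jax.jit(safe_log))

err, y = checked(jnp.arange(2.0))
print(err.get())  # a message string, because x[0] == 0
print(y)          # the value is still computed

err_ok, y_ok = checked(jnp.array([1.0, 2.0]))
print(err_ok.get())  # None: no check failed
```

Inspecting `err.get()` lets the caller decide how to handle a failure. `err.throw()` gives the raising behaviour shown in the doctest above.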
Installation
testax is pip-installable and can be installed by running
pip install testax
Interface
testax mirrors the numpy.testing interface familiar to NumPy users, providing assertions such as assert_allclose and assert_array_less.
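To illustrate how a NumPy-style assertion can be made jit-compatible, here is a minimal sketch of an `assert_allclose` analogue built directly on `jax.experimental.checkify`. The name `assert_allclose_jit` is hypothetical, and testax's actual implementation may differ. The sketch applies NumPy's documented tolerance rule `|actual - desired| <= atol + rtol * |desired|` with NumPy's default `rtol=1e-7, atol=0`, and treats NaNs in matching positions as equal. Unlike numpy.testing, it does not check shapes (inputs simply broadcast) and reports only a count of mismatched elements.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def assert_allclose_jit(actual, desired, rtol=1e-7, atol=0.0):
    """Hypothetical jit-compatible analogue of numpy.testing.assert_allclose."""
    actual = jnp.asarray(actual)
    desired = jnp.asarray(desired)
    diff = jnp.abs(actual - desired)
    close = diff <= atol + rtol * jnp.abs(desired)
    # Exact equality covers matching infinities, whose difference is NaN.
    close = close | (actual == desired)
    # Treat NaNs in the same position as equal (NumPy's equal_nan=True default).
    close = close | (jnp.isnan(actual) & jnp.isnan(desired))
    mismatched = jnp.sum(~close)
    # The traced count is formatted into the message at run time.
    checkify.check(
        mismatched == 0,
        "Not equal to tolerance: {} mismatched element(s)",
        mismatched,
    )


def compare(a, b):
    assert_allclose_jit(a, b, rtol=1e-5)
    return a - b


checked = checkify.checkify(jax.jit(compare))
err_ok, _ = checked(jnp.array([1.0, jnp.nan]), jnp.array([1.0, jnp.nan]))
err_bad, _ = checked(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.5]))
print(err_ok.get())   # None
print(err_bad.get())  # message reporting the mismatched element count
```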