Simulating quantum circuits with JAX
qujax
Represent a (parameterised) quantum circuit as a pure JAX function that takes as input any parameters of the circuit and outputs a statetensor. The statetensor encodes all $2^N$ amplitudes of the quantum state and can then be used downstream for exact expectations, gradients or sampling.
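To make the statetensor idea concrete, here is a minimal sketch in plain JAX. It does not use the qujax API: the helper names `zero_statetensor` and `apply_single_qubit_gate` are hypothetical and exist only for illustration. It assumes one tensor axis of size 2 per qubit, so an N-qubit state has shape `(2,) * N`.

```python
import jax.numpy as jnp


def zero_statetensor(n_qubits):
    """All-zero computational basis state as a tensor of shape (2,) * n_qubits."""
    st = jnp.zeros((2,) * n_qubits, dtype=jnp.complex64)
    return st.at[(0,) * n_qubits].set(1.0)


def apply_single_qubit_gate(statetensor, gate, qubit):
    """Contract a 2x2 gate with the tensor axis that belongs to `qubit`."""
    # tensordot places the gate's output axis first, so move it back into place.
    out = jnp.tensordot(gate, statetensor, axes=([1], [qubit]))
    return jnp.moveaxis(out, 0, qubit)


# Flip qubit 1 of a three-qubit register with a Pauli-X gate.
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
st = apply_single_qubit_gate(zero_statetensor(3), X, qubit=1)
```

After the gate, the only non-zero amplitude sits at index `(0, 1, 0)`, and the tensor still holds all $2^3 = 8$ amplitudes.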
qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix and allows for mixed states and generic Kraus operators.
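As an illustration of what a densitytensor holds, the sketch below applies a bit-flip channel to one qubit of a two-qubit register using plain JAX rather than the qujax API. The flip probability `p`, the variable names and the axis order (the first N axes index rows of the density matrix, the last N index columns, which is simply what `reshape` produces) are assumptions of this example.

```python
import jax.numpy as jnp

p = 0.25  # assumed bit-flip probability, chosen for illustration
I2 = jnp.eye(2, dtype=jnp.complex64)
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)

# Kraus operators of a single-qubit bit-flip channel.
kraus_ops = [jnp.sqrt(1 - p) * I2, jnp.sqrt(p) * X]

n_qubits = 2
# Density matrix of the pure state |00><00|.
rho = jnp.zeros((4, 4), dtype=jnp.complex64).at[0, 0].set(1.0)

# Apply the channel to qubit 0: rho -> sum_k K_k rho K_k^dagger.
rho_out = sum(
    jnp.kron(k, I2) @ rho @ jnp.kron(k, I2).conj().T for k in kraus_ops
)

# Tensor form: one axis of size 2 per row index and per column index.
densitytensor = rho_out.reshape((2,) * (2 * n_qubits))

# A purity below 1 signals a mixed state.
purity = jnp.trace(rho_out @ rho_out).real
```

The result is a mixture of |00> and |10> that no statetensor can represent, which is the case the densitytensor simulations are for.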
A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support for GPUs/TPUs.
Install
pip install qujax
Parameterised quantum circuits with qujax
from jax import numpy as jnp
import qujax
circuit_gates = ['H', 'Ry', 'CZ']
circuit_qubit_inds = [[0], [0], [0, 1]]
circuit_params_inds = [[], [0], []]
qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds)
# q0: -----H-----Ry[0]-----◯---
#                          |
# q1: ---------------------CZ--
param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
                                                   circuit_qubit_inds,
                                                   circuit_params_inds)
We now have a pure JAX function that generates the statetensor for the given parameters:
param_to_st(jnp.array([0.1]))
# DeviceArray([[0.58778524+0.j, 0.        +0.j],
#              [0.80901706+0.j, 0.        +0.j]], dtype=complex64)
The statevector can be obtained from the statetensor via .flatten():
param_to_st(jnp.array([0.1])).flatten()
# DeviceArray([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
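The following sketch shows how statetensor indices relate to statevector indices, using the amplitudes printed above. It relies only on `jax.numpy`; the variable names are illustrative.

```python
import jax.numpy as jnp

# Statetensor copied from the output above: axis 0 is qubit 0, axis 1 is qubit 1.
st = jnp.array([[0.58778524, 0.0],
                [0.80901706, 0.0]], dtype=jnp.complex64)

# Row-major flattening makes qubit 0 the most significant bit:
# st[i, j] ends up at position 2 * i + j of the statevector.
sv = st.flatten()

# The reshape is lossless, so the statetensor can be recovered.
n_qubits = st.ndim
st_again = sv.reshape((2,) * n_qubits)

# Measurement probabilities in the computational basis.
probs = jnp.abs(sv) ** 2
```

Here `sv[2]` equals `st[1, 0]`, the amplitude of the basis state with qubit 0 in 1 and qubit 1 in 0.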
We can also use qujax to map the statetensor to an expected value:
st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
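For intuition about what such an expectation function computes, here is a hand-written sketch for a single Pauli-Z observable in plain JAX. The helper `z_expectation` is hypothetical and is not how qujax implements it; it uses the fact that the expectation of Z on a qubit is the probability of measuring 0 minus the probability of measuring 1.

```python
import jax.numpy as jnp


def z_expectation(statetensor, qubit):
    """Expectation of Pauli-Z on `qubit`: P(qubit = 0) - P(qubit = 1)."""
    probs = jnp.abs(statetensor) ** 2
    # Sum out every axis except the one belonging to `qubit`.
    other_axes = tuple(ax for ax in range(statetensor.ndim) if ax != qubit)
    marginal = probs.sum(axis=other_axes)
    return marginal[0] - marginal[1]


# Statetensor of the example circuit at parameter 0.1.
st = jnp.array([[0.58778524, 0.0],
                [0.80901706, 0.0]], dtype=jnp.complex64)

z0 = z_expectation(st, qubit=0)
z1 = z_expectation(st, qubit=1)
```

Qubit 1 never leaves the 0 state in this circuit, so `z1` is 1 up to rounding, while `z0` depends on the rotation parameter.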
Combining the two gives us a parameter-to-expectation function that can be differentiated seamlessly and exactly with JAX:
from jax import value_and_grad
param_to_expectation = lambda param: st_to_expectation(param_to_st(param))
expectation_and_grad = value_and_grad(param_to_expectation)
expectation_and_grad(jnp.array([0.1]))
# (DeviceArray(-0.3090171, dtype=float32),
#  DeviceArray([-2.987832], dtype=float32))
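These numbers can be cross-checked by hand. The sketch below rebuilds the qubit-0 part of the circuit in plain JAX, assuming the Ry matrix convention shown in the code and an angle of π times the parameter. The CZ gate is left out because qubit 1 stays in the 0 state, where CZ acts trivially. Under these assumptions the expectation is $-\sin(\pi p)$ and its derivative is $-\pi\cos(\pi p)$.

```python
import jax.numpy as jnp
from jax import value_and_grad


def ry(param):
    """Ry rotation, with the angle given in units of pi (assumed convention)."""
    theta = jnp.pi * param
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


H = jnp.array([[1.0, 1.0], [1.0, -1.0]]) / jnp.sqrt(2.0)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


def expectation_z0(params):
    # State of qubit 0 after H then Ry; real-valued, so no conjugation is needed.
    psi = ry(params[0]) @ H @ jnp.array([1.0, 0.0])
    return psi @ Z @ psi


value, grad = value_and_grad(expectation_z0)(jnp.array([0.1]))

# Closed-form results for comparison.
analytic_value = -jnp.sin(jnp.pi * 0.1)
analytic_grad = -jnp.pi * jnp.cos(jnp.pi * 0.1)
```

Both the automatic and the closed-form results agree with the values printed above to within float32 precision.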
Notes
- We use the convention where parameters are given in units of π (i.e. in [0,2] rather than [0, 2π]).
- By default the parameter to statetensor function starts in the all-zero state; an optional statetensor_in argument lets it start from an arbitrary state instead.
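Both notes can be illustrated with a one-qubit stand-in written in plain JAX. The functions `radians_to_param` and `ry_statetensor` are hypothetical and only mimic the behaviour described above; they are not part of qujax, and the Ry matrix convention is an assumption of this sketch.

```python
import jax.numpy as jnp


def radians_to_param(theta):
    """Convert an angle in radians to a parameter in units of pi."""
    return theta / jnp.pi


def ry_statetensor(params, statetensor_in=None):
    """Hypothetical one-qubit circuit: a single Ry gate on an optional input state."""
    if statetensor_in is None:
        # Default to the all-zero state.
        statetensor_in = jnp.array([1.0, 0.0])
    theta = jnp.pi * params[0]  # parameters are in units of pi
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    gate = jnp.array([[c, -s], [s, c]])
    return gate @ statetensor_in


# A parameter of 1.0 is a rotation by pi, which maps |0> to |1>.
from_zero = ry_statetensor(jnp.array([1.0]))

# Starting from |1> instead, the same rotation gives -|0>.
from_one = ry_statetensor(jnp.array([1.0]), statetensor_in=jnp.array([0.0, 1.0]))

half = radians_to_param(jnp.pi / 2)
```

An angle of π/2 radians therefore corresponds to a parameter of 0.5 in this convention.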
pytket-qujax
You can also generate the parameter to statetensor function from a pytket circuit using the pytket-qujax extension. In particular, see the tk_to_qujax and tk_to_qujax_symbolic functions.
An example notebook can be found at pytket-qujax_heisenberg_vqe.ipynb.
Contributing
Bugs and feature requests are managed using GitHub issues.
Pull requests are welcome!
- First fork the repo and create your branch from develop.
- Add your code.
- Add your tests.
- Update the documentation if required.
- Issue a pull request into develop.
New commits on develop will then be merged into main on the next release.