
Simulating quantum circuits with JAX


qujax

Represent a (parameterised) quantum circuit as a pure JAX function that takes as input any parameters of the circuit and outputs a statetensor. For an $N$-qubit circuit, the statetensor is an array with one axis of size 2 per qubit; it encodes all $2^N$ amplitudes of the quantum state and can then be used downstream for exact expectations, gradients or sampling.
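As a rough illustration in plain JAX (not qujax's internals), a 2-qubit statetensor is a complex array of shape (2, 2) whose entry st[i0, i1] is the amplitude of the basis state |i0 i1⟩. The sketch below builds a Bell state by hand and draws measurement samples from the squared amplitudes; the helper `sample_basis_states` is a hypothetical name introduced only for this example.

```python
import jax
import jax.numpy as jnp

# A 2-qubit statetensor has shape (2, 2); st[i0, i1] is the amplitude of |i0 i1>.
n_qubits = 2
bell = jnp.zeros((2,) * n_qubits, dtype=jnp.complex64)
bell = bell.at[0, 0].set(1 / jnp.sqrt(2)).at[1, 1].set(1 / jnp.sqrt(2))

def sample_basis_states(key, statetensor, n_samples):
    """Hypothetical helper: draw computational-basis samples with probabilities |amplitude|^2."""
    probs = jnp.abs(statetensor.flatten()) ** 2  # Born-rule probabilities
    idx = jax.random.choice(key, probs.size, shape=(n_samples,), p=probs)
    # Convert flat indices back to per-qubit bits (qubit 0 is the leading axis).
    return jnp.stack(jnp.unravel_index(idx, statetensor.shape), axis=-1)

samples = sample_basis_states(jax.random.PRNGKey(0), bell, 8)
print(samples)  # each row is [b0, b1]; for the Bell state b0 == b1
```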

A JAX implementation of a quantum circuit is useful for runtime speedups (via JIT compilation), automatic differentiation and support for GPUs/TPUs.
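A minimal sketch of those benefits using standard JAX transformations on a toy function. `toy_expectation` is a hypothetical, hand-written one-qubit circuit (⟨Z⟩ after an Ry rotation on |0⟩), not a qujax call: `jax.jit` compiles it with XLA, `jax.vmap` batches it over many parameter values, and `jax.grad` differentiates it exactly. The same code runs on whichever CPU, GPU or TPU backend JAX is configured to use.

```python
import jax
import jax.numpy as jnp

def toy_expectation(theta):
    """Hypothetical toy circuit: <Z> after Ry(pi * theta) on |0>, which equals cos(pi * theta)."""
    amp0 = jnp.cos(jnp.pi * theta / 2)
    amp1 = jnp.sin(jnp.pi * theta / 2)
    return amp0 ** 2 - amp1 ** 2

fast = jax.jit(toy_expectation)               # XLA-compiled on the default device
batched = jax.jit(jax.vmap(toy_expectation))  # evaluate many parameters in one call
grad_fn = jax.jit(jax.grad(toy_expectation))  # exact derivative via autodiff

thetas = jnp.linspace(0.0, 2.0, 5)
print(fast(0.5))        # close to 0.0
print(batched(thetas))  # close to [1, 0, -1, 0, 1]
print(grad_fn(0.5))     # close to -pi
```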


Install

pip install qujax

Parameterised quantum circuits with qujax

from jax import numpy as jnp
import qujax

circuit_gates = ['H', 'Ry', 'CZ']
circuit_qubit_inds = [[0], [0], [0, 1]]
circuit_params_inds = [[], [0], []]

qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds);
# q0: -----H-----Ry[0]-----◯---
#                          |   
# q1: ---------------------CZ--
param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
                                                   circuit_qubit_inds,
                                                   circuit_params_inds)

We now have a pure JAX function that generates the statetensor for given parameters:

param_to_st(jnp.array([0.1]))
# DeviceArray([[0.58778524+0.j, 0.        +0.j],
#              [0.80901706+0.j, 0.        +0.j]], dtype=complex64)
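To give a feel for what such a function computes, here is a from-scratch sketch of the general statetensor technique: each gate is contracted into the axes of the qubits it acts on. This is an illustration under stated assumptions, not necessarily how qujax implements it; the gate matrices are written by hand (with Ry taking its angle in units of π, as described in the Notes), and `apply_gate` and `toy_params_to_statetensor` are hypothetical names. It reproduces the statetensor printed above.

```python
import jax.numpy as jnp

# Hand-written gate matrices (assumption: Ry's parameter is in units of pi).
H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2)
CZ = jnp.diag(jnp.array([1, 1, 1, -1], dtype=jnp.complex64))

def Ry(theta):
    c, s = jnp.cos(jnp.pi * theta / 2), jnp.sin(jnp.pi * theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)

def apply_gate(st, gate, qubits):
    """Contract a 2^k x 2^k gate into the statetensor axes listed in `qubits`."""
    k = len(qubits)
    gate = gate.reshape((2,) * (2 * k))  # axes: k outputs, then k inputs
    # Contract the gate's input axes with the target qubit axes of the state.
    st = jnp.tensordot(gate, st, axes=(list(range(k, 2 * k)), list(qubits)))
    # tensordot puts the gate's output axes first; move them back into place.
    return jnp.moveaxis(st, list(range(k)), list(qubits))

def toy_params_to_statetensor(params):
    st = jnp.zeros((2, 2), dtype=jnp.complex64).at[0, 0].set(1.0)  # |00>
    st = apply_gate(st, H, [0])
    st = apply_gate(st, Ry(params[0]), [0])
    st = apply_gate(st, CZ, [0, 1])
    return st

print(toy_params_to_statetensor(jnp.array([0.1])))
# approximately [[0.5878, 0], [0.8090, 0]]
```

Because the sketch uses only `jax.numpy` operations, it can be passed to `jax.jit` or `jax.grad` like any other pure JAX function.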

The statevector can be obtained from the statetensor via .flatten().

param_to_st(jnp.array([0.1])).flatten()
# DeviceArray([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
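`.flatten()` uses row-major (C) order, so with qubit 0 on the leading axis, qubit 0 is the most significant bit of the statevector index: entry (i0, i1) lands at position 2*i0 + i1. A quick check in plain JAX, with the amplitudes printed above hard-coded for illustration:

```python
import jax.numpy as jnp

# Statetensor printed above, hard-coded for illustration.
st = jnp.array([[0.58778524, 0.0], [0.80901706, 0.0]], dtype=jnp.complex64)
sv = st.flatten()

# Row-major flattening: st[i0, i1] == sv[2 * i0 + i1], so qubit 0 is the leading bit.
assert sv[2 * 1 + 0] == st[1, 0]
# Reshaping to (2,) * n_qubits recovers the statetensor from the statevector.
assert jnp.array_equal(sv.reshape((2,) * 2), st)
```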

We can also use qujax to map the statetensor to an expectation value, here of the Pauli Z operator on qubit 0 with coefficient 1:

st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
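For a single Z term on one qubit, ⟨ψ|Z|ψ⟩ reduces to the probability of that qubit being 0 minus the probability of it being 1. The plain-JAX sketch below illustrates that calculation (it is the underlying maths, not qujax's implementation; `z_expectation` is a hypothetical name), using the statetensor printed earlier as hard-coded input.

```python
import jax.numpy as jnp

def z_expectation(statetensor, qubit, coefficient=1.0):
    """Hypothetical helper: coefficient * <psi| Z_qubit |psi> for a statetensor."""
    probs = jnp.abs(statetensor) ** 2
    # Sum out every axis except the target qubit's to get its marginal distribution.
    other_axes = tuple(a for a in range(statetensor.ndim) if a != qubit)
    p0, p1 = jnp.sum(probs, axis=other_axes)
    return coefficient * (p0 - p1)

st = jnp.array([[0.58778524, 0.0], [0.80901706, 0.0]], dtype=jnp.complex64)
print(z_expectation(st, qubit=0))  # approximately -0.309
```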

Combining the two gives us a parameter-to-expectation function that can be differentiated seamlessly and exactly with JAX:

from jax import value_and_grad

param_to_expectation = lambda param: st_to_expectation(param_to_st(param))
expectation_and_grad = value_and_grad(param_to_expectation)
expectation_and_grad(jnp.array([0.1]))
# (DeviceArray(-0.3090171, dtype=float32),
#    DeviceArray([-2.987832], dtype=float32))
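As a sanity check on those numbers, the expectation can be worked out by hand. With the parameter θ in units of π, H followed by Ry(πθ) leaves qubit 0 with amplitudes cos(πθ/2 + π/4) and sin(πθ/2 + π/4), while qubit 1 stays in |0⟩ so CZ acts trivially. Hence:

```latex
\begin{aligned}
\langle Z_0 \rangle(\theta)
  &= \cos^2\!\left(\tfrac{\pi\theta}{2} + \tfrac{\pi}{4}\right)
   - \sin^2\!\left(\tfrac{\pi\theta}{2} + \tfrac{\pi}{4}\right)
   = \cos\!\left(\pi\theta + \tfrac{\pi}{2}\right)
   = -\sin(\pi\theta), \\
\frac{d\langle Z_0 \rangle}{d\theta} &= -\pi\cos(\pi\theta).
\end{aligned}
```

At θ = 0.1 these give -sin(0.1π) ≈ -0.30902 and -π cos(0.1π) ≈ -2.98783, in agreement with the value and gradient returned by `value_and_grad`.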

Notes

  • We use the convention where parameters are given in units of π (i.e. in [0, 2] rather than [0, 2π]).
  • By default, the parameter-to-statetensor function initialises in the all-zero state; an optional statetensor_in argument allows initialising in an arbitrary state.
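To make both notes concrete, the plain-JAX sketch below uses a hand-written Ry (not a qujax call) whose parameter is in units of π: a parameter of 1.0 is a π rotation that takes |0⟩ to |1⟩, and starting from a custom input state only changes the initial array. The names `ry_units_of_pi` and `run` are hypothetical, and the `statetensor_in` keyword here merely mirrors the optional argument described above.

```python
import jax.numpy as jnp

def ry_units_of_pi(theta):
    """Ry rotation whose parameter is in units of pi: theta=1.0 means an angle of pi."""
    c, s = jnp.cos(jnp.pi * theta / 2), jnp.sin(jnp.pi * theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)

def run(theta, statetensor_in=None):
    """Hypothetical one-qubit circuit; starts in |0> unless an input state is given."""
    if statetensor_in is None:
        statetensor_in = jnp.array([1.0, 0.0], dtype=jnp.complex64)  # all-zero state
    return ry_units_of_pi(theta) @ statetensor_in

radians = jnp.pi / 2
print(run(radians / jnp.pi))  # convert radians to units of pi first: ~[0.707, 0.707]
print(run(1.0))               # theta=1.0 takes |0> to |1> (up to float error)
print(run(1.0, jnp.array([0.0, 1.0], dtype=jnp.complex64)))  # |1> goes to -|0>
```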

pytket-qujax

You can also generate the parameter-to-statetensor function from a pytket circuit using the tk_to_qujax and tk_to_qujax_symbolic functions of the pytket-qujax extension. An example notebook can be found at pytket-qujax_heisenberg_vqe.ipynb.

Contributing

Bugs and feature requests are managed using GitHub issues.

Pull requests are welcome!

  1. First fork the repo and create your branch from develop.
  2. Add your code.
  3. Add your tests.
  4. Update the documentation if required.
  5. Issue a pull request into develop.

New commits on develop will then be merged into main on the next release.


Download files



Built Distribution

qujax-0.2.9-py3-none-any.whl (14.7 kB, Python 3)
