
Quasigeostrophic model in JAX (port of PyQG)


PyQG JAX Port


This is a partial port of PyQG to JAX, enabling GPU acceleration, batching, automatic differentiation, and more.
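As a rough illustration of what batching and automatic differentiation mean for a time-stepped model, here is a toy sketch in plain JAX. The linear-damping model, `toy_step`, and `final_energy` are hypothetical stand-ins, not part of pyqg_jax; the same `jax.grad` and `jax.vmap` transformations apply in principle to any pure JAX function.

```python
import jax
import jax.numpy as jnp


def toy_step(q, damping, dt=0.1):
    # Hypothetical toy model (not pyqg_jax): forward Euler for dq/dt = -damping * q
    return q - dt * damping * q


def final_energy(damping, q0, num_steps=10):
    # Integrate the toy model with lax.scan and reduce to a scalar diagnostic
    def body(q, _):
        return toy_step(q, damping), None

    q_final, _ = jax.lax.scan(body, q0, xs=None, length=num_steps)
    return 0.5 * jnp.sum(q_final**2)


q0 = jnp.ones(4)
# Automatic differentiation: d(final energy) / d(damping) at damping = 0.5
d_energy = jax.grad(final_energy)(0.5, q0)
# Batching: evaluate the diagnostic for several damping rates in one call
dampings = jnp.array([0.1, 0.5, 1.0])
energies = jax.vmap(final_energy, in_axes=(0, None))(dampings, q0)
print(d_energy, energies)
```

Stronger damping leaves less energy at the end of the run, so the gradient is negative and the batched energies decrease along `dampings`.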

⚠️ Warning: this is a partial, early-stage port. There may be bugs and other numerical issues, and the API may change as work continues.

Installation

Install from PyPI using pip:

$ python -m pip install pyqg-jax

or from conda-forge:

$ conda install -c conda-forge pyqg-jax

This should install required dependencies, but JAX itself may require special attention, particularly for GPU support. Follow the JAX installation instructions.
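A quick way to check whether your JAX installation can see a GPU, using only standard JAX calls (`jax.default_backend` and `jax.devices`). This is a generic JAX check, not a pyqg_jax feature:

```python
import jax

# Which backend did JAX pick? Typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("Default backend:", backend)

# Devices visible to JAX; a working GPU install should list GPU devices here.
devices = jax.devices()
for device in devices:
    print(device.platform, device)
```

If this reports only CPU devices on a machine with a GPU, revisit the JAX installation instructions for your CUDA or ROCm setup.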

Usage

Documentation is a work in progress. The parameters of the QGModel implemented here are the same as those of the model in the original PyQG, so consult the PyQG documentation for details.

However, there are a few overarching changes used to make the models JAX-compatible:

  1. The model state is now a separate, immutable object rather than a set of attributes on the QGModel class.

  2. Time-stepping is now separated from the models. Use steppers.AB3Stepper for the same time stepping as in the original QGModel.

  3. Random initialization requires an explicit key variable as with all JAX random number generation.
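The three changes above can be sketched with a hypothetical toy model written in plain JAX. `ToyState`, `ToyModel`, and `ToyEulerStepper` are illustrative names only and do not mirror pyqg_jax's actual classes or method names; the sketch just shows an immutable state pytree, a stepper kept separate from the model, and randomness drawn from an explicit key (which also makes `jax.vmap` over keys a natural way to build an ensemble).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Hypothetical toy analogue of the three changes (NOT pyqg_jax code).


class ToyState(NamedTuple):
    # 1. State is a separate immutable pytree, not attributes on a model object
    q: jax.Array
    t: jax.Array


class ToyModel(NamedTuple):
    damping: float = 0.5

    def create_initial_state(self, key):
        # 3. Randomness comes only from the explicit key
        q = jax.random.normal(key, (8,))
        return ToyState(q=q, t=jnp.zeros(()))

    def tendency(self, state):
        # Time derivative dq/dt = -damping * q
        return -self.damping * state.q


class ToyEulerStepper(NamedTuple):
    # 2. Time stepping lives in a separate object from the model
    dt: float = 0.1

    def step(self, model, state):
        # Return a NEW state; the input state is left unchanged
        return state._replace(
            q=state.q + self.dt * model.tendency(state),
            t=state.t + self.dt,
        )


model = ToyModel()
stepper = ToyEulerStepper()
state0 = model.create_initial_state(jax.random.key(0))
state1 = stepper.step(model, state0)

# Same key -> same initial state; vmap over split keys -> a batched ensemble
keys = jax.random.split(jax.random.key(0), 4)
ensemble = jax.vmap(model.create_initial_state)(keys)
print(state1.t, ensemble.q.shape)
```

Because `step` returns a new `ToyState`, `state0` is still available after stepping, which is what lets JAX transformations treat each update as a pure function.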

The QGModel performs part of its computation in double precision (float64) regardless of the precision setting, so 64-bit mode must be enabled in JAX. See the documentation for details. One option is to set the following environment variable:

export JAX_ENABLE_X64=True

or use the %env magic in a Jupyter notebook.
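If you prefer to enable 64-bit mode from Python rather than through the environment, JAX's `jax.config.update` can do it. A minimal sketch, assuming it runs at the start of your program before any arrays are created or models constructed:

```python
import jax
import jax.numpy as jnp

# Turn on 64-bit mode programmatically (same effect as JAX_ENABLE_X64=True);
# call this at startup, before creating arrays or constructing models
jax.config.update("jax_enable_x64", True)

# Sanity check: without x64, JAX would silently produce float32 here
x = jnp.zeros(3, dtype=jnp.float64)
print(x.dtype)  # float64
```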

Short Example

A short example initializing a QGModel, adding a parameterization, and taking a single step (for more, see the examples in the documentation).

>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
...     model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
...         pyqg_jax.qg_model.QGModel(),
...         constant=0.08,
...     ),
...     stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
...     jax.random.key(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q

For repeated time-stepping, combine step_model with jax.lax.scan.
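A minimal sketch of that pattern, assuming only that the step function maps a state pytree to a new state pytree of the same structure, as `step_model` does in the example above. `run_steps` is a hypothetical helper name, and the toy `halve` function stands in for a real model so the snippet runs on its own:

```python
import jax
import jax.numpy as jnp


def run_steps(step_fn, init_state, num_steps):
    """Apply step_fn num_steps times with jax.lax.scan and return the final state.

    step_fn must map a state pytree to a new pytree with the same structure,
    shapes, and dtypes (for example stepped_model.step_model).
    """

    def body(carry, _):
        # scan expects (new_carry, per_step_output); no per-step output is kept
        return step_fn(carry), None

    final_state, _ = jax.lax.scan(body, init_state, xs=None, length=num_steps)
    return final_state


def halve(x):
    # Toy stand-in for a model step: halve the state each step
    return 0.5 * x


print(run_steps(halve, jnp.array(8.0), 3))  # 1.0
```

With the short example, the call would look like `final_stepper_state = run_steps(stepped_model.step_model, stepper_state, 100)`. To keep a trajectory instead of only the final state, have `body` compute `next_state = step_fn(carry)` and return `(next_state, next_state)`; `jax.lax.scan` then stacks one copy per step, which can use a lot of memory for long runs.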

License

This software is distributed under the MIT license. See LICENSE.txt for the license text.

Download files

Source distribution: pyqg_jax-0.8.1.tar.gz (35.1 kB)

Built distribution: pyqg_jax-0.8.1-py3-none-any.whl (35.0 kB, Python 3)
