Quasigeostrophic model in JAX (port of PyQG)
PyQG JAX Port
This is a partial port of PyQG to JAX, enabling GPU acceleration, batching, automatic differentiation, and more.
⚠️ Warning: this is a partial, early-stage port. There may be bugs and other numerical issues. Only part of the QGModel has been ported, and the API will very likely evolve as work continues.
Installation
Install from PyPI using pip:
$ python -m pip install pyqg-jax
This should install the required dependencies, but JAX itself may require special attention, particularly for GPU support. Follow the JAX installation instructions.
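After installing, a quick sanity check confirms that JAX imports and shows which backend it will run on. This is a minimal sketch using only standard JAX calls; on a correctly configured GPU install, the default backend should be reported as a GPU rather than the CPU.

```python
import jax

# List the devices JAX can see and the platform it will use by default.
devices = jax.devices()
backend = jax.default_backend()  # e.g. "cpu" or "gpu"
print(f"JAX {jax.__version__}: backend={backend}, devices={devices}")
```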
Usage
Documentation is a work in progress. The parameters of the QGModel implemented here are the same as those of the model in the original PyQG, so consult the PyQG documentation for details.
However, there are a few overarching changes used to make the models JAX-compatible:
- The model state is now a separate, immutable object rather than a set of attributes on the QGModel class.
- Time-stepping is now separated from the models. Use steppers.AB3Stepper for the same time stepping as in the original QGModel.
- Random initialization requires an explicit key variable, as with all JAX random number generation.
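The explicit-key requirement follows the standard JAX pattern: the same key always reproduces the same draw, and a key is split into independent subkeys when several draws are needed. In this sketch, `toy_initial_state` is a hypothetical stand-in for a model's `create_initial_state`; `jax.vmap` over a batch of subkeys then builds a batch of independent random states.

```python
import jax
import jax.numpy as jnp


def toy_initial_state(key):
    # Hypothetical stand-in for create_initial_state:
    # draws a small random two-layer field from the given key.
    return 1e-6 * jax.random.normal(key, shape=(2, 8, 8))


root_key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same state.
a = toy_initial_state(root_key)
b = toy_initial_state(root_key)

# Split the root key into independent subkeys, one per ensemble member,
# and vmap over them to build a batch of initial states.
keys = jax.random.split(root_key, 4)
batch = jax.vmap(toy_initial_state)(keys)
```

By analogy one might write `jax.vmap(stepped_model.create_initial_state)(keys)` to build an ensemble of model states, but treat that as an untested assumption rather than documented behavior.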
The QGModel uses double-precision (float64) values for part of its computation regardless of the precision setting, so make sure 64-bit support is enabled in JAX. See the JAX documentation for details. One option is to set the following environment variables:
export JAX_ENABLE_X64=True
export JAX_DEFAULT_DTYPE_BITS=32
or use the %env magic in a Jupyter notebook.
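The `JAX_ENABLE_X64` setting can also be switched on from Python with `jax.config.update`, which should run at startup before any arrays are created. This sketch only mirrors the first environment variable; it does not reproduce `JAX_DEFAULT_DTYPE_BITS`.

```python
import jax
import jax.numpy as jnp

# Programmatic equivalent of JAX_ENABLE_X64=True; call this before creating arrays.
jax.config.update("jax_enable_x64", True)

# With 64-bit mode enabled, float64 arrays are no longer truncated to float32.
x = jnp.ones(3, dtype=jnp.float64)
print(x.dtype)
```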
Short Example
A short example initializing a QGModel
, adding a parameterization,
and taking a single step.
>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
...     model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
...         pyqg_jax.qg_model.QGModel(),
...         constant=0.08,
...     ),
...     stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
...     jax.random.PRNGKey(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q
For repeated time-stepping, combine step_model with jax.lax.scan.
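Combining a step function with `jax.lax.scan` can be sketched as a small generic helper. `roll_out` and `toy_step` are hypothetical names; `toy_step` is a stand-in for `step_model` so the sketch runs without a full model.

```python
import jax
import jax.numpy as jnp


def roll_out(step_fn, init_state, num_steps):
    """Apply step_fn num_steps times; return (final_state, trajectory)."""

    def body(carry, _):
        next_state = step_fn(carry)
        # Carry the new state forward and also record it as this step's output.
        return next_state, next_state

    return jax.lax.scan(body, init_state, None, length=num_steps)


def toy_step(x):
    # Hypothetical stand-in for stepped_model.step_model: halve the state.
    return 0.5 * x


final, traj = roll_out(toy_step, jnp.array(8.0), 3)
```

With the short example above, `roll_out(stepped_model.step_model, stepper_state, 10)` should give the state after ten steps plus the stacked intermediate states, assuming the stepper state is a JAX pytree. Each recorded state holds full model fields, so for long runs returning `None` instead of `next_state` as the per-step output keeps only the final state and saves memory.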
Useful Methods and Attributes
A subset of the methods and attributes available on common objects:
- For pyqg_jax.qg_model.QGModel:
  - create_initial_state(jax.random.PRNGKey) -> PseudoSpectralState: Randomly initializes the model
  - get_full_state(PseudoSpectralState) -> FullPseudoSpectralState: Expands the state, computing other attributes from q
  - get_updates(PseudoSpectralState) -> PseudoSpectralState: Computes time updates for qh. Combine with a time-stepper
- For pyqg_jax.steppers.AB3Stepper(dt=float):
  - initialize_stepper_state(PseudoSpectralState) -> AB3State[PseudoSpectralState]: Initializes a time-stepper state around a model state
  - apply_updates(AB3State[PseudoSpectralState], updates=PseudoSpectralState) -> AB3State[PseudoSpectralState]: Applies model updates to a time-stepper state
- For pyqg_jax.steppers.AB3State:
  - state: The PseudoSpectralState at the current time
  - t: The current time
  - tc: The current step counter
- For pyqg_jax.steppers.SteppedModel(model, stepper):
  - create_initial_state(key=jax.random.PRNGKey) -> StepperState[PseudoSpectralState]: Creates a new, random state ready to step
  - initialize_stepper_state(PseudoSpectralState) -> StepperState[PseudoSpectralState]: Wraps an existing model state to prepare it for time stepping
  - step_model(StepperState[PseudoSpectralState]) -> StepperState[PseudoSpectralState]: Steps the model forward and handles filtering
  - get_full_state(StepperState[PseudoSpectralState]) -> FullPseudoSpectralState: Extracts the state and expands it, computing all attributes from q
- For pyqg_jax.state.PseudoSpectralState:
  - q: The potential vorticity
  - qh: The spectral form of the potential vorticity
  - update(q=, qh=) -> PseudoSpectralState: Returns a new PseudoSpectralState with the given value replacements
- For pyqg_jax.state.FullPseudoSpectralState:
  - dqhdt: Spectral updates for qh
  - state: The inner PseudoSpectralState
  - update(q=, qh=, dqhdt=, ...) -> FullPseudoSpectralState: Returns a new FullPseudoSpectralState with the given value replacements
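The division of labor in the list above, where a model's `get_updates` computes tendencies and a stepper's `apply_updates` advances an immutable state while tracking `t` and `tc`, can be illustrated without PyQG. This is a minimal sketch of a generic third-order Adams-Bashforth scheme for the toy equation dq/dt = -q, bootstrapped with forward Euler and second-order Adams-Bashforth on the first two steps. `ToyAB3State`, `toy_get_updates`, `toy_initialize`, `toy_apply_updates`, and `toy_step` are hypothetical names; this is not pyqg_jax's implementation, and it omits the filtering that `step_model` handles.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyAB3State(NamedTuple):
    # Hypothetical analogue of AB3State: current state, time, step counter,
    # plus the two previous updates that Adams-Bashforth needs.
    state: jax.Array
    t: jax.Array
    tc: jax.Array
    prev1: jax.Array
    prev2: jax.Array


def toy_get_updates(q):
    # Stand-in for QGModel.get_updates: the tendency of dq/dt = -q.
    return -q


def toy_initialize(q):
    zeros = jnp.zeros_like(q)
    return ToyAB3State(q, jnp.asarray(0.0), jnp.asarray(0), zeros, zeros)


def toy_apply_updates(s, updates, dt):
    ab2 = (3.0 * updates - s.prev1) / 2.0
    ab3 = (23.0 * updates - 16.0 * s.prev1 + 5.0 * s.prev2) / 12.0
    # Forward Euler on step 0, AB2 on step 1, AB3 afterwards.
    incr = jnp.where(s.tc == 0, updates, jnp.where(s.tc == 1, ab2, ab3))
    # Return a new state; like update(...), the old state is left untouched.
    return s._replace(
        state=s.state + dt * incr,
        t=s.t + dt,
        tc=s.tc + 1,
        prev1=updates,
        prev2=s.prev1,
    )


def toy_step(s, dt=0.01):
    # Model computes updates, stepper applies them.
    return toy_apply_updates(s, toy_get_updates(s.state), dt)


s0 = toy_initialize(jnp.array([1.0, 2.0]))
s = s0
for _ in range(100):
    s = toy_step(s)
```

After 100 steps of size 0.01, `s.state` should be close to the exact solution `[1, 2] * exp(-1)`, while `s0` still holds the initial values.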
License
The code in this repository is distributed under the MIT license. See LICENSE.txt for the license text.