
Lattice kernel for scalable Gaussian processes in GPyTorch


Simplex-GPs

This repository hosts the code for SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes (Simplex-GPs) by Sanyam Kapoor, Marc Finzi, Ke Alexander Wang, Andrew Gordon Wilson.

The Idea

Fast matrix-vector multiplies (MVMs) are the cornerstone of modern scalable Gaussian processes. Building on the approximation proposed by Structured Kernel Interpolation (SKI) and leveraging advances in fast high-dimensional image filtering, Simplex-GPs approximate kernel matrix operations by tiling the space with a sparse permutohedral lattice instead of a rectangular grid.

The matrix-vector product implied by the kernel operations in SKI is now approximated in three stages: splat (projecting the inputs onto the permutohedral lattice), blur (applying a sparse blur over the lattice vertices as a matrix-vector product), and slice (re-projecting back into the original space).

This alleviates the curse of dimensionality associated with SKI operations, allowing them to scale beyond roughly five dimensions, and yields competitive runtime and memory costs with little loss in downstream performance. See our manuscript for complete details.
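As a rough mental model, the approximate MVM composes these three stages. The sketch below is purely illustrative: the dense tensors splat and blur and the function lattice_mvm are stand-ins for the package's sparse permutohedral-lattice operations, and it only conveys the shape of the computation, not the actual implementation.

import torch

def lattice_mvm(splat, blur, v):
    """Approximate K @ v as splat @ (blur @ (splat.T @ v)).

    splat: (n, m) interpolation weights from the n inputs to m lattice vertices
    blur:  (m, m) blur operator acting on values stored at lattice vertices
    v:     (n,)   vector to multiply
    """
    on_lattice = splat.t() @ v    # splat: project v onto the lattice vertices
    blurred = blur @ on_lattice   # blur: smooth the projected values
    return splat @ blurred        # slice: interpolate back to the original points

# Example with random stand-in operators
n, m = 8, 5
splat = torch.rand(n, m)
blur = torch.eye(m)
v = torch.randn(n)
approx_Kv = lattice_mvm(splat, blur, v)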

Usage

The lattice kernels are packaged as GPyTorch modules, and can be used as a fast approximation to either the RBFKernel or the MaternKernel. The corresponding replacement modules are RBFLattice and MaternLattice.

The RBFLattice kernel is a drop-in replacement that requires changing only a single line of code:

import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice

class SimplexGPModel(gp.models.ExactGP):
  def __init__(self, train_x, train_y):
    likelihood = gp.likelihoods.GaussianLikelihood()
    super().__init__(train_x, train_y, likelihood)

    self.mean_module = gp.means.ConstantMean()
    self.covar_module = gp.kernels.ScaleKernel(
-      gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
+      RBFLattice(ard_num_dims=train_x.size(-1), order=1)
    )

  def forward(self, x):
    mean_x = self.mean_module(x)
    covar_x = self.covar_module(x)
    return gp.distributions.MultivariateNormal(mean_x, covar_x)

The GPyTorch Regression Tutorial provides a simpler example on toy data, where this kernel can be used as a drop-in replacement.
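For completeness, here is a minimal training and prediction loop in the standard GPyTorch exact-GP style, using the SimplexGPModel defined above; train_x, train_y, and test_x are assumed to be tensors you already have, and the optimizer settings are arbitrary placeholders.

import torch

model = SimplexGPModel(train_x, train_y)
likelihood = model.likelihood

model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gp.mlls.ExactMarginalLogLikelihood(likelihood, model)

for _ in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)  # negative marginal log-likelihood
    loss.backward()
    optimizer.step()

model.eval()
likelihood.eval()
with torch.no_grad():
    preds = likelihood(model(test_x))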

Install

To use the kernel in your code, install the package as:

pip install gpytorch-lattice-kernel

NOTE: The kernel is compiled lazily from source using CMake. If compilation fails, you may need a more recent version of CMake. Additionally, ninja is required for compilation. One way to install both is:

conda install -c conda-forge cmake ninja

Local Setup

For a local development setup, create the conda environment:

$ conda env create -f environment.yml

Remember to add the project root to PYTHONPATH if it is not already there:

$ export PYTHONPATH="$(pwd):${PYTHONPATH}"

Test

To verify the code is working as expected, a simple test file is provided that compares the training marginal likelihood achieved by Simplex-GPs and Exact-GPs. Run it as:

python tests/train_snelson.py

The Snelson 1-D toy dataset is used. A copy is available in snelson.csv.

Results

The proposed kernel can be used with GPyTorch as usual. An example script to reproduce results is:

python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>

We use Fire to handle CLI arguments. All arguments of the main function are therefore valid arguments to the CLI.
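For instance, a hypothetical script whose main function has the signature below would accept --dataset, --data-dir (or --data_dir), and --lr flags; the parameter names here are illustrative, not the script's actual signature.

import fire

def main(dataset='elevators', data_dir='.', lr=0.1):
    # Fire turns each keyword argument into a CLI flag.
    print(dataset, data_dir, lr)

if __name__ == '__main__':
    fire.Fire(main)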

All figures in the paper can be reproduced via notebooks.

NOTE: The UCI dataset mat files are available here.

License

Apache 2.0
