
Neural network sampling utilities.




Lightweight PyTorch functions for neural network featuremap sampling.

WARNING: API is not yet stable. API subject to change!

Introduction

Sampling neural network featuremaps at explicit coordinates has become increasingly common with popular developments like:

PyTorch provides the tools necessary to sample coordinates, but using them directly results in a large amount of error-prone code. TorchSample intends to make this simple so you can focus on other parts of the model.
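For context, here is roughly what sampling at explicit coordinates looks like with plain PyTorch's `torch.nn.functional.grid_sample`. This is a minimal sketch, not torchsample's implementation: it assumes `coords` has shape `(b, n, 2)` holding normalized `(x, y)` values in [-1, 1], and the helper name `sample_points` is hypothetical. Note the reshaping needed on both sides of the call, which is the kind of boilerplate torchsample aims to hide.

```python
import torch
import torch.nn.functional as F


def sample_points(coords: torch.Tensor, featmap: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample ``featmap`` at ``coords`` (hypothetical helper).

    coords:  (b, n, 2) normalized (x, y) coordinates in [-1, 1].
    featmap: (b, c, h, w) feature map.
    returns: (b, n, c) sampled features.
    """
    # grid_sample expects a grid of shape (b, h_out, w_out, 2); treat the
    # n points as an output "image" of height n and width 1.
    grid = coords.unsqueeze(2)  # (b, n, 1, 2)
    out = F.grid_sample(featmap, grid, mode="bilinear", align_corners=False)  # (b, c, n, 1)
    # Drop the dummy width dimension and move channels last.
    return out.squeeze(-1).permute(0, 2, 1)  # (b, n, c)


featmap = torch.randn(2, 8, 16, 16)
coords = torch.rand(2, 100, 2) * 2 - 1  # uniform in [-1, 1)
print(sample_points(coords, featmap).shape)  # torch.Size([2, 100, 8])
```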

Usage

Installation

Requires Python >=3.8. Install torchsample via pip:

pip install torchsample

Or, if you want to install the nightly version:

pip install git+https://github.com/BrianPugh/torchsample.git@main

Training

A common scenario is to randomly sample points from a featuremap and, at the same coordinates, from the ground truth.

import torchsample as ts

# `batch` and `feature_extractor` are your own dataloader batch and encoder network.
b, c, h, w = batch["image"].shape
coords = ts.coord.rand(b, 4096, 2)  # (b, 4096, 2) where the last dim is (x, y)

featmap = feature_extractor(batch["image"])  # (b, feat, h, w)
sampled = ts.sample(coords, featmap)  # (b, 4096, feat)
gt_sample = ts.sample(coords, batch["gt"])  # (b, 4096, gt_channels)
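The same training pattern can be mirrored in plain PyTorch to make the shapes explicit. This sketch assumes `ts.coord.rand` draws coordinates uniformly over [-1, 1] (an assumption; check the torchsample docs). A single convolution stands in for `feature_extractor`, a linear layer is a hypothetical per-point prediction head, and `sample_points` is a hypothetical helper built on `grid_sample`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sample_points(coords, featmap):
    # (b, n, 2) normalized (x, y) -> (b, n, c); bilinear, align_corners=False.
    grid = coords.unsqueeze(2)  # (b, n, 1, 2)
    return F.grid_sample(featmap, grid, align_corners=False).squeeze(-1).permute(0, 2, 1)


torch.manual_seed(0)
b, c, h, w, n = 2, 3, 32, 32, 4096
batch = {"image": torch.randn(b, c, h, w), "gt": torch.randn(b, 1, h, w)}

feature_extractor = nn.Conv2d(c, 16, kernel_size=3, padding=1)  # stand-in encoder
head = nn.Linear(16, 1)  # hypothetical per-point prediction head

coords = torch.rand(b, n, 2) * 2 - 1  # random (x, y) in [-1, 1)
featmap = feature_extractor(batch["image"])  # (b, 16, h, w)
sampled = sample_points(coords, featmap)  # (b, n, 16)
gt_sample = sample_points(coords, batch["gt"])  # (b, n, 1)

# Prediction and target line up point by point, so a plain regression loss works.
loss = F.mse_loss(head(sampled), gt_sample)
loss.backward()
```

Bilinear sampling blends neighboring ground-truth values; for discrete labels such as class indices, `mode="nearest"` in `grid_sample` is usually the safer choice.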

Inference

During inference, it is common to query the network at every pixel to form a complete image.

import torch
import torchsample as ts

b, c, h, w = batch["image"].shape
coords = ts.coord.full_like(batch["image"])  # one (x, y) coordinate per pixel
featmap = encoder(batch["image"])  # (b, feat, h, w)
feat_sampled = ts.sample(coords, featmap)  # (b, h, w, feat)
output = model(feat_sampled)  # (b, h, w, pred)
output = output.permute(0, 3, 1, 2)  # (b, pred, h, w)
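A sketch of what a full coordinate grid plausibly contains: one normalized `(x, y)` coordinate per pixel center, consistent with `align_corners=False`. Whether `ts.coord.full_like` produces exactly this layout is an assumption, and `full_grid` is a hypothetical helper. Sampling an image at its own pixel centers should return the image unchanged, which makes a handy sanity check.

```python
import torch
import torch.nn.functional as F


def full_grid(b: int, h: int, w: int) -> torch.Tensor:
    """Normalized pixel-center coordinates of shape (b, h, w, 2), last dim (x, y).

    With align_corners=False, pixel i of n along an axis maps to (2 * i + 1) / n - 1.
    """
    xs = (2 * torch.arange(w, dtype=torch.float32) + 1) / w - 1  # (w,)
    ys = (2 * torch.arange(h, dtype=torch.float32) + 1) / h - 1  # (h,)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")  # each (h, w)
    grid = torch.stack([gx, gy], dim=-1)  # (h, w, 2) in (x, y) order
    return grid.unsqueeze(0).repeat(b, 1, 1, 1)  # (b, h, w, 2)


torch.manual_seed(0)
image = torch.randn(2, 3, 4, 8)
coords = full_grid(2, 4, 8)  # (2, 4, 8, 2)
# grid_sample returns (b, c, h, w): channels-first, i.e. already image layout.
resampled = F.grid_sample(image, coords, align_corners=False)
print(torch.allclose(resampled, image, atol=1e-5))  # True
```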

Positional Encoding

Common positional encoding schemes are available.

import torchsample as ts

b = 4  # batch size
coords = ts.coord.rand(b, 4096, 2)
pos_enc = ts.encoding.gamma(coords)
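The name `gamma` presumably refers to the γ(p) encoding popularized by NeRF, which maps each coordinate component p to sin(2^k π p) and cos(2^k π p) for k = 0, ..., L-1. The sketch below implements that formula directly with a hypothetical `gamma_encoding` function. The number of frequencies, the output ordering, and whether the raw coordinates are also included are assumptions, so don't expect values identical to `ts.encoding.gamma`.

```python
import math

import torch


def gamma_encoding(coords: torch.Tensor, n_freqs: int = 10) -> torch.Tensor:
    """NeRF-style positional encoding (hypothetical helper).

    coords:  (..., d) normalized coordinates.
    returns: (..., d * 2 * n_freqs), grouped per component p as
             [sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(L-1) pi p), cos(2^(L-1) pi p)].
    """
    freqs = (2.0 ** torch.arange(n_freqs, dtype=coords.dtype)) * math.pi  # (L,)
    angles = coords.unsqueeze(-1) * freqs  # (..., d, L)
    enc = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., d, L, 2)
    return enc.flatten(start_dim=-3)  # (..., d * L * 2)


coords = torch.rand(4, 4096, 2) * 2 - 1
pos_enc = gamma_encoding(coords)  # (4, 4096, 40) with the default 10 frequencies
```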

A common task is concatenating the positional encoding to the sampled values. You can do this by passing a callable into ts.sample via the encoder argument:

import torchsample as ts

encoder = ts.encoding.Gamma()
sampled = ts.sample(coords, featmap, encoder=encoder)
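Conceptually, passing an encoder amounts to appending the encoding of each coordinate to the features sampled there. The plain-PyTorch sketch below shows that behavior; the concatenation order (features first, encoding second) and the helper name `sample_with_encoding` are assumptions. Any callable mapping `(b, n, 2)` to `(b, n, k)` can serve as the encoder; `sincos_encoder` is a trivial hypothetical stand-in.

```python
import torch
import torch.nn.functional as F


def sample_with_encoding(coords, featmap, encoder=None):
    """Sample a (b, c, h, w) featmap at (b, n, 2) coords; optionally append encoder(coords)."""
    grid = coords.unsqueeze(2)  # (b, n, 1, 2)
    sampled = F.grid_sample(featmap, grid, align_corners=False)  # (b, c, n, 1)
    sampled = sampled.squeeze(-1).permute(0, 2, 1)  # (b, n, c)
    if encoder is not None:
        sampled = torch.cat([sampled, encoder(coords)], dim=-1)  # (b, n, c + k)
    return sampled


def sincos_encoder(p):
    # Trivial stand-in encoder: (b, n, 2) -> (b, n, 4).
    return torch.cat([torch.sin(p), torch.cos(p)], dim=-1)


featmap = torch.randn(2, 8, 16, 16)
coords = torch.rand(2, 100, 2) * 2 - 1
out = sample_with_encoding(coords, featmap, encoder=sincos_encoder)
print(out.shape)  # torch.Size([2, 100, 12])
```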

Models

torchsample provides some common built-in models:

import torchsample as ts

# Properly handles (..., feat) tensors.
model = ts.models.MLP(256, 256, 512, 512, 1024, 1024, 1)
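A sketch of an MLP that accepts `(..., feat)` inputs, assuming the positional arguments to `ts.models.MLP` are layer widths from input to output (256 in, 1 out). That reading of the signature, the ReLU activations, and the class name `SimpleMLP` are assumptions. Arbitrary leading dimensions work because `nn.Linear` only operates on the last dimension.

```python
import torch
import torch.nn as nn


class SimpleMLP(nn.Module):
    """Hypothetical MLP: Linear layers with ReLU between them, none after the last."""

    def __init__(self, *widths: int):
        super().__init__()
        layers = []
        for i, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(w_in, w_out))
            if i < len(widths) - 2:  # skip the activation after the output layer
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dim, so (b, n, feat) and (b, h, w, feat) both work.
        return self.net(x)


model = SimpleMLP(256, 256, 512, 512, 1024, 1024, 1)
print(model(torch.randn(2, 4096, 256)).shape)  # torch.Size([2, 4096, 1])
print(model(torch.randn(2, 8, 8, 256)).shape)  # torch.Size([2, 8, 8, 1])
```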

Design Decisions

  • align_corners=False by default (same as PyTorch). You should probably not change it.

  • Everything is in normalized coordinates [-1, 1] by default.

  • Coordinates are always in order (x, y, ...).

  • Whenever a size is given, it is in (w, h) order, i.e. it matches the coordinate order. This makes the implementation simpler, and a consistent rule helps prevent bugs.

  • When coords is a function argument, it comes first.

  • Simple wrapper functions (like ts.coord.rand) are provided to make the intentions of calling code more clear.

  • Try to mimic native PyTorch and torchvision interfaces as much as possible.

  • Try to make the common use case as simple and intuitive as possible.
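To make the normalized-coordinate and align_corners=False conventions concrete: with align_corners=False, -1 and 1 refer to the outer edges of the border pixels, so pixel index i (out of n pixels along an axis) has its center at (2i + 1) / n - 1. The hypothetical helpers below convert between pixel and normalized coordinates, taking the size in (w, h) order to match the (x, y) coordinate order. They sketch the convention and are not torchsample functions.

```python
from typing import Sequence, Tuple


def pixel_to_norm(coord: Sequence[float], size: Sequence[int]) -> Tuple[float, ...]:
    """Pixel coords (x, y, ...) -> normalized [-1, 1] coords, align_corners=False.

    ``size`` is in (w, h, ...) order, so size[i] is the extent of the axis of coord[i].
    Integer pixel coords refer to pixel centers.
    """
    return tuple((2 * c + 1) / s - 1 for c, s in zip(coord, size))


def norm_to_pixel(coord: Sequence[float], size: Sequence[int]) -> Tuple[float, ...]:
    """Inverse of pixel_to_norm."""
    return tuple(((c + 1) * s - 1) / 2 for c, s in zip(coord, size))


size = (8, 4)  # (w, h): 8 pixels wide, 4 pixels tall
print(pixel_to_norm((0, 0), size))  # (-0.875, -0.75): center of the top-left pixel
print(pixel_to_norm((7, 3), size))  # (0.875, 0.75): center of the bottom-right pixel
print(norm_to_pixel((-1.0, -1.0), size))  # (-0.5, -0.5): outer corner of the top-left pixel
```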


Source Distribution

torchsample-0.1.0.tar.gz (393.6 kB)
