
TensorGuard helps guard against bad tensor shapes


Tensor Guard


TensorGuard helps guard against bad tensor shapes in any tensor-based library (e.g. NumPy, PyTorch, TensorFlow) using an intuitive, symbolic syntax.

Installation

pip install tensorguard

Basic Usage

import numpy as np  # could be tensorflow or torch as well
import tensorguard as tg

# tensorguard = tg.TensorGuard()  # could also be done in an OOP fashion
img = np.ones([64, 32, 32, 3])
flat_img = np.ones([64, 1024])
labels = np.ones([64])

# check shape consistency
tg.guard(img, "B, H, W, C")
tg.guard(labels, "B, 1")  # raises an error: rank mismatch (1 vs. 2 dimensions)
tg.guard(flat_img, "B, H*W*C")  # raises an error: 1024 != 32*32*3 = 3072

# guard also returns the tensor, so it can be inlined
mean_img = tg.guard(np.mean(img, axis=0), "H, W, C")

# more readable reshapes
flat_img = tg.reshape(img, 'B, H*W*C')

# evaluate templates
assert tg.get_dims('H, W*C+1') == [32, 97]
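
The binding behaviour in the example above can be illustrated with a minimal, self-contained sketch. In this model, the first time a name such as B appears it is bound to the observed size. Later occurrences, and expressions such as H*W*C, are checked against the bound values. `ToyGuard` and `_eval` are hypothetical names used only for illustration; this is not tensorguard's API or implementation. The sketch works on plain shape tuples, and the rule that `/` must divide exactly is an assumption.

```python
import ast
import operator

# Hypothetical toy model of symbolic shape checking; NOT tensorguard's implementation.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


def _eval(node, dims):
    """Evaluate a parsed template expression using already-bound dimension sizes."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name):
        return dims[node.id]  # raises KeyError if the name has not been bound yet
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left, dims), _eval(node.right, dims))
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Div):
        num, den = _eval(node.left, dims), _eval(node.right, dims)
        if num % den:  # assumption: division in a template must be exact
            raise ValueError(f"{num} is not divisible by {den}")
        return num // den
    raise ValueError(f"unsupported template syntax: {ast.dump(node)}")


class ToyGuard:
    def __init__(self):
        self.dims = {}  # name -> size, remembered across guard() calls

    def guard(self, shape, template):
        parts = [p.strip() for p in template.split(",")]
        if len(parts) != len(shape):
            raise ValueError(f"rank mismatch: {len(shape)} vs {len(parts)}")
        for size, part in zip(shape, parts):
            node = ast.parse(part, mode="eval").body
            if isinstance(node, ast.Name) and node.id not in self.dims:
                self.dims[node.id] = size  # first occurrence binds the name
                continue
            expected = _eval(node, self.dims)
            if expected != size:
                raise ValueError(f"{part} = {expected} != {size}")
        return shape  # returned unchanged so calls can be inlined

    def get_dims(self, template):
        return [_eval(ast.parse(p.strip(), mode="eval").body, self.dims)
                for p in template.split(",")]


g = ToyGuard()
g.guard((64, 32, 32, 3), "B, H, W, C")  # binds B=64, H=32, W=32, C=3
print(g.get_dims("H, W*C+1"))  # [32, 97]
try:
    g.guard((64, 1024), "B, H*W*C")
except ValueError as err:
    print(err)  # H*W*C = 3072 != 1024
```

Each template entry is parsed with Python's ast module rather than passed to eval. This keeps the toy mini-language restricted to integers, names and the four arithmetic operators.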

Shape Template Syntax

The shape template mini-DSL supports the following ways of specifying shapes:

  • numbers: "64, 32, 32, 3"
  • named dimensions: "B, width, height2, channels"
  • wildcards: "B, *, *, *"
  • ellipsis: "B, ..., 3"
  • addition, subtraction, multiplication, division: "B*N, W/2, H*(C+1)"
  • dynamic dimensions: "?, H, W, C" (only matches [None, H, W, C])
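
The wildcard, ellipsis and `?` rules in the list above can be sketched as a small matcher. This is a hypothetical illustration: `match` is not a tensorguard function. It supports only integers, names, `*`, `?` and a single `...`, and it leaves out arithmetic. It also assumes that `...` may match zero dimensions.

```python
# Hypothetical sketch of wildcard, ellipsis and "?" matching; NOT tensorguard's code.
def match(shape, template, dims=None):
    """Return name -> size bindings if `shape` fits `template`, else raise ValueError."""
    dims = {} if dims is None else dims
    parts = [p.strip() for p in template.split(",")]
    if parts.count("...") > 1:
        raise ValueError("at most one ellipsis is allowed")
    if "..." in parts:
        i = parts.index("...")
        head, tail = parts[:i], parts[i + 1:]
        if len(shape) < len(head) + len(tail):
            raise ValueError("tensor rank too small for template")
        # Pair only the leading and trailing dims; "..." absorbs everything between.
        pairs = list(zip(shape[:len(head)], head))
        pairs += list(zip(shape[len(shape) - len(tail):], tail))
    else:
        if len(parts) != len(shape):
            raise ValueError(f"rank mismatch: {len(shape)} vs {len(parts)}")
        pairs = list(zip(shape, parts))
    for size, part in pairs:
        if part == "*":
            continue  # wildcard: any size is accepted
        if part == "?":
            if size is not None:  # "?" only matches an unknown (None) dimension
                raise ValueError(f"'?' only matches None, got {size}")
        elif part.isdigit():
            if size != int(part):
                raise ValueError(f"expected {part}, got {size}")
        elif part.isidentifier():
            if part in dims and dims[part] != size:
                raise ValueError(f"{part}={dims[part]} but got {size}")
            dims.setdefault(part, size)  # first occurrence binds the name
        else:
            raise ValueError(f"not supported in this sketch: {part!r}")
    return dims


print(match((64, 32, 32, 3), "B, ..., 3"))  # {'B': 64}
print(match((None, 32, 32, 3), "?, H, W, C"))  # {'H': 32, 'W': 32, 'C': 3}
```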

Original repository: https://github.com/Qwlouse/shapeguard



Download files


Source Distribution

tensorguard-1.0.0.tar.gz (27.4 kB)

