torch-featurelayer

🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

Tip: This library is intended to be a simplified and well-documented implementation for extracting a PyTorch model's intermediate layer output(s). For a more sophisticated and complete implementation, consider torchvision.models.feature_extraction or the official torch.fx documentation.

Install

pip install torch-featurelayer

Usage

Imports:

import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer

Load a pretrained VGG-11 model:

model = vgg11(weights='DEFAULT').eval()

Hook onto layer features.15 of the model:

layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

Forward pass an input tensor through the model:

x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

feature_output is the output of layer features.15. Print the output shape:

print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}')  # [1, 1000]

Check the examples directory for more.

API

torch_featurelayer.FeatureLayer

The FeatureLayer class wraps a model and provides a hook to access the output of a specific feature layer.

  • __init__(self, model: torch.nn.Module, feature_layer_path: str)

    Initializes the FeatureLayer instance.

    • model: The model containing the feature layer.
    • feature_layer_path: The path to the feature layer in the model.
  • __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]

    Performs a forward pass through the model and captures the output of the hooked feature layer.

    • *args: Variable length argument list.
    • **kwargs: Arbitrary keyword arguments.

    Returns a tuple containing the feature layer output and the model output.
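The wrapper pattern described above can be sketched with PyTorch's own forward hooks. This is only an illustration under stated assumptions, not the library's actual implementation: `SingleLayerHook` and `TinyNet` are hypothetical names invented for this example, and the reset-to-`None` step is one plausible reading of the `torch.Tensor | None` return type.

```python
import torch
from torch import nn


class SingleLayerHook:
    """Hypothetical stand-in for the wrapper pattern; not the library's code."""

    def __init__(self, model: nn.Module, layer_path: str) -> None:
        self.model = model
        self.feature = None
        # get_submodule resolves dotted paths such as 'features.2'.
        layer = model.get_submodule(layer_path)
        self.handle = layer.register_forward_hook(self._save_output)

    def _save_output(self, module, inputs, output):
        # PyTorch calls this right after the hooked layer's forward().
        self.feature = output

    def __call__(self, *args, **kwargs):
        self.feature = None  # assumption: a layer that never runs yields None
        output = self.model(*args, **kwargs)
        return self.feature, output


class TinyNet(nn.Module):
    """Small toy model so the example runs without downloading weights."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.head = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x):
        return self.head(torch.flatten(self.features(x), 1))


model = TinyNet().eval()
hooked = SingleLayerHook(model, 'features.2')
with torch.no_grad():
    feature, output = hooked(torch.randn(2, 3, 8, 8))
print(feature.shape, output.shape)
```

The hook stays registered for the lifetime of the wrapper; calling `hooked.handle.remove()` detaches it from the model.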

torch_featurelayer.FeatureLayers

The FeatureLayers class wraps a model and provides hooks to access the output of multiple feature layers.

  • __init__(self, model: torch.nn.Module, feature_layer_paths: list[str])

    Initializes the FeatureLayers instance.

    • model: The model containing the feature layers.
    • feature_layer_paths: A list of paths to the feature layers in the model.
  • __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]

    Performs a forward pass through the model and captures the outputs of the hooked feature layers.

    • *args: Variable length argument list.
    • **kwargs: Arbitrary keyword arguments.

    Returns a tuple containing a dictionary of feature layer outputs, keyed by layer path, and the model output.
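To make the `dict[str, torch.Tensor | None]` return shape concrete, here is a minimal sketch of multi-layer capture built directly on `register_forward_hook`. The helper `run_with_features` and the toy `Branchy` model are hypothetical, and the behavior shown for a layer that never runs (its entry stays `None`) is an assumption about what the `None` in the signature means, not a statement about the library's internals.

```python
import torch
from torch import nn


def run_with_features(model: nn.Module, layer_paths: list[str], *args, **kwargs):
    """Hypothetical helper: returns ({path: output or None}, model_output)."""
    features = dict.fromkeys(layer_paths)  # every path starts as None
    handles = []
    for path in layer_paths:
        # Bind `path` as a default argument so each hook writes to its own key.
        def save(module, inputs, output, path=path):
            features[path] = output

        handles.append(model.get_submodule(path).register_forward_hook(save))
    try:
        output = model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unhooked afterwards
    return features, output


class Branchy(nn.Module):
    """Toy model with one submodule that forward() never calls."""

    def __init__(self) -> None:
        super().__init__()
        self.stem = nn.Linear(4, 6)
        self.used = nn.Linear(6, 2)
        self.unused = nn.Linear(6, 2)

    def forward(self, x):
        return self.used(torch.relu(self.stem(x)))


features, output = run_with_features(
    Branchy().eval(), ['stem', 'used', 'unused'], torch.randn(3, 4)
)
for path, value in features.items():
    print(path, None if value is None else tuple(value.shape))
```

Unlike a persistent wrapper, this sketch registers and removes its hooks on every call, which trades a little overhead for leaving the model untouched between calls.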

torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1) -> Generator[str, None, None]

The get_layer_candidates function returns a generator of layer paths for a given model up to a specified depth.

  • module: The module to get layer paths from.
  • max_depth: The maximum depth to traverse in the model's layers.

Returns a generator of layer paths.
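As a rough illustration of depth-limited path enumeration, the sketch below filters the names produced by `nn.Module.named_modules()` by how many dots they contain. `iter_layer_paths` is a hypothetical helper written for this example; the ordering and exact depth semantics of `get_layer_candidates` itself may differ.

```python
from collections.abc import Iterator

from torch import nn


def iter_layer_paths(module: nn.Module, max_depth: int = 1) -> Iterator[str]:
    """Hypothetical helper: yield dotted submodule paths up to max_depth levels deep."""
    for name, _ in module.named_modules():
        # named_modules() yields '' for the root module, which is skipped here.
        # A path with k dots sits k + 1 levels below the root.
        if name and name.count('.') < max_depth:
            yield name


model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Linear(8, 2),
)
print(list(iter_layer_paths(model, max_depth=1)))  # ['0', '1']
print(list(iter_layer_paths(model, max_depth=2)))  # ['0', '0.0', '0.1', '1']
```

Any path yielded this way can be passed to `model.get_submodule(path)` to inspect the layer before hooking onto it.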

License

MIT
