
Momentum Residual Neural Networks



This repository hosts Python code for Momentum ResNets.

See the documentation, our ICML 2021 paper, and a 5-minute presentation.

Model

Official library for using Momentum Residual Neural Networks [1]. These models extend any residual architecture (including Transformers) to a larger class of deep learning models that consume less memory. They can be initialized with the same weights as a pretrained ResNet and are promising in fine-tuning applications.

Installation

pip

To install momentumnet, you first need to install its dependencies:

$ pip install numpy matplotlib torch

Then install momentumnet with pip:

$ pip install momentumnet

or to get the latest version of the code:

$ pip install git+https://github.com/michaelsdr/momentumnet.git#egg=momentumnet

If you do not have admin privileges on the computer, use the --user flag with pip. To upgrade, use the --upgrade flag provided by pip.

Check

To check that the installation worked, run:

$ python -c 'import momentumnet'

and it should not give any error message.

Quickstart

The main class is MomentumNet. It creates a Momentum ResNet whose forward equations can be reversed in closed form. Activations can therefore be recomputed during the backward pass instead of being stored, as standard backpropagation requires. This trades memory for computation.
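For reference, the paper [1] formulates each layer as follows. Layer n updates the activation x_n and a velocity v_n using the residual function f with parameters θ_n and the momentum term γ (left). Solving these equations for x_n and v_n gives the closed-form inverse (right), which requires γ > 0:

```latex
\begin{aligned}
v_{n+1} &= \gamma\, v_n + (1-\gamma)\, f(x_n, \theta_n), \\
x_{n+1} &= x_n + v_{n+1},
\end{aligned}
\qquad
\begin{aligned}
x_n &= x_{n+1} - v_{n+1}, \\
v_n &= \frac{1}{\gamma}\Bigl(v_{n+1} - (1-\gamma)\, f(x_n, \theta_n)\Bigr).
\end{aligned}
```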

To get started, you can create a toy Momentum ResNet by specifying the functions f for the forward pass and the value of the momentum term, gamma.

>>> from torch import nn
>>> from momentumnet import MomentumNet
>>> hidden = 8
>>> d = 500
>>> function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
>>> mresnet = MomentumNet([function,] * 10, gamma=0.9)
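The closed-form reversal can be checked numerically on a toy model. The sketch below is a plain-Python illustration, not momentumnet's implementation. Several choices are assumptions made for illustration only:

* the residual function f(x, w) = tanh(w * x), with one scalar weight per layer,
* a depth of 10, chosen to mirror the example above,
* a zero initial velocity.

The sketch runs the momentum forward pass, then recovers the input from the output alone.

```python
import math
import random


def f(x, w):
    # Hypothetical residual function: elementwise tanh(w * x_i)
    return [math.tanh(w * xi) for xi in x]


def momentum_forward(x, v, weights, gamma):
    # v_{n+1} = gamma * v_n + (1 - gamma) * f(x_n);  x_{n+1} = x_n + v_{n+1}
    for w in weights:
        v = [gamma * vi + (1 - gamma) * fi for vi, fi in zip(v, f(x, w))]
        x = [xi + vi for xi, vi in zip(x, v)]
    return x, v


def momentum_inverse(x, v, weights, gamma):
    # Undo the layers from last to first using only the outputs (x_N, v_N):
    # x_n = x_{n+1} - v_{n+1};  v_n = (v_{n+1} - (1 - gamma) * f(x_n)) / gamma
    for w in reversed(weights):
        x = [xi - vi for xi, vi in zip(x, v)]
        v = [(vi - (1 - gamma) * fi) / gamma for vi, fi in zip(v, f(x, w))]
    return x, v


random.seed(0)
d, depth, gamma = 5, 10, 0.9
weights = [random.uniform(-1.0, 1.0) for _ in range(depth)]  # one scalar per layer
x0 = [random.gauss(0.0, 1.0) for _ in range(d)]
v0 = [0.0] * d  # assumed zero initial velocity

xN, vN = momentum_forward(x0, v0, weights, gamma)
x_rec, v_rec = momentum_inverse(xN, vN, weights, gamma)  # matches (x0, v0) up to round-off
```

Note also that `[function,] * 10` in the example above holds ten references to the same nn.Sequential object. Unless MomentumNet copies its inputs, the ten layers therefore share one set of weights. A list comprehension such as `[nn.Sequential(...) for _ in range(10)]` would give each layer its own parameters.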

Momentum ResNets are a drop-in replacement for ResNets

We can transform a ResNet into a MomentumNet with the same parameters in two lines of code. For instance, the following code instantiates a Momentum ResNet with the weights of a ResNet-101 pretrained on ImageNet. We set use_backprop to False so that activations are not saved during the forward pass, which reduces memory consumption.

>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> from torchvision.models import resnet101
>>> resnet = resnet101(pretrained=True)
>>> mresnet101 = transform_to_momentumnet(resnet, gamma=0.9, use_backprop=False)
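One way to see why pretrained ResNet weights are a sensible starting point: with γ = 0, the momentum update reduces to v_{n+1} = f(x_n) and x_{n+1} = x_n + f(x_n). That is exactly a residual block, although the inverse no longer exists because it divides by γ.

The plain-Python sketch below checks this on a toy model. The residual function, the weights, and the zero initial velocity are illustrative assumptions, not what transform_to_momentumnet does internally. With γ = 0.9 the outputs differ, so the transformed network computes a different function than the original ResNet and would typically be fine-tuned.

```python
import math


def f(x, w):
    # Hypothetical residual branch standing in for a pretrained block
    return [math.tanh(w * xi) for xi in x]


def resnet_forward(x, weights):
    # Standard residual update: x_{n+1} = x_n + f(x_n)
    for w in weights:
        x = [xi + fi for xi, fi in zip(x, f(x, w))]
    return x


def momentum_forward(x, weights, gamma):
    # v_{n+1} = gamma * v_n + (1 - gamma) * f(x_n);  x_{n+1} = x_n + v_{n+1}
    v = [0.0] * len(x)  # assumed zero initial velocity
    for w in weights:
        v = [gamma * vi + (1 - gamma) * fi for vi, fi in zip(v, f(x, w))]
        x = [xi + vi for xi, vi in zip(x, v)]
    return x


x0 = [0.5, -1.0, 2.0]
weights = [0.3, -0.7, 1.1, 0.2]  # hypothetical "pretrained" weights, one per layer

same = momentum_forward(x0, weights, gamma=0.0) == resnet_forward(x0, weights)
```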

Importantly, this function also works with PyTorch's Transformer module. The sub_layers argument specifies which residual layers to turn into their momentum version.

>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> transformer = torch.nn.Transformer(num_encoder_layers=6, num_decoder_layers=6)
>>> mtransformer = transform_to_momentumnet(transformer, sub_layers=["encoder.layers", "decoder.layers"], gamma=0.9,
...                                          use_backprop=False, keep_first_layer=False)

This instantiates a Momentum Transformer with the same weights as the original Transformer.
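The sub_layers strings are dotted attribute paths into the model. In torch.nn.Transformer, encoder.layers and decoder.layers are the stacks of encoder and decoder layers. The hypothetical helper resolve below shows how such a path maps to a submodule through chained getattr calls. It is an illustration on a stand-in object, not momentumnet's own code.

```python
from functools import reduce
from types import SimpleNamespace


def resolve(model, dotted_path):
    # Hypothetical helper: "encoder.layers" -> getattr(getattr(model, "encoder"), "layers")
    return reduce(getattr, dotted_path.split("."), model)


# Toy stand-in mimicking the attribute layout of torch.nn.Transformer (not a real model)
toy = SimpleNamespace(
    encoder=SimpleNamespace(layers=["enc0", "enc1"]),
    decoder=SimpleNamespace(layers=["dec0", "dec1"]),
)

encoder_stack = resolve(toy, "encoder.layers")
```

On a real torch.nn.Transformer, the same chained getattr would return the model's encoder.layers submodule.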

Memory savings when applying Momentum ResNets to Transformers

Here is a short tutorial showing the memory gains when using Momentum Transformers.
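A rough, back-of-the-envelope illustration of where the savings come from:

* Standard backpropagation keeps each layer's input for the backward pass, so stored activations grow linearly with depth.
* Reversing the momentum equations only requires the final pair (x, v).

The function below is a hypothetical estimate under simplifying assumptions: float32 values, one batch × width tensor per layer input, and activations inside f ignored. It does not measure real memory use.

```python
def activation_bytes(batch, width, depth, bytes_per_float=4, reversible=False):
    """Hypothetical estimate of stored layer inputs, in bytes.

    Simplifying assumptions: every layer input is a batch x width tensor of
    bytes_per_float-sized values, and activations inside f are ignored.
    """
    per_tensor = batch * width * bytes_per_float
    # Standard backprop keeps one input per layer; the reversible estimate keeps
    # only the final (x, v) pair and recomputes earlier ones layer by layer.
    kept_tensors = 2 if reversible else depth
    return kept_tensors * per_tensor


standard = activation_bytes(64, 500, 10)
reversible = activation_bytes(64, 500, 10, reversible=True)
```

For example, a batch of 64 vectors of width 500 through 10 layers gives 1,280,000 bytes of stored layer inputs in the standard case. The reversible estimate gives 256,000 bytes, and only the standard figure grows with depth.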

Dependencies

These are the dependencies to use momentumnet:

  • numpy (>=1.8)

  • matplotlib (>=1.3)

  • torch (>=1.9)

  • memory_profiler

  • vit_pytorch

Cite

If you use this code in your project, please cite:

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré
Momentum Residual Neural Networks
Proceedings of the 38th International Conference on Machine Learning, PMLR 139:9276-9287
https://arxiv.org/abs/2102.07870



Download files

Source Distribution

momentumnet-0.10.tar.gz (39.4 kB)

Built Distribution

momentumnet-0.10-py3-none-any.whl (19.2 kB, Python 3)
