GPipe for PyTorch
A GPipe implementation in PyTorch.
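    from torch import nn
    from torchgpipe import GPipe

    # a, b, c, d stand for any nn.Module layers.
    model = nn.Sequential(a, b, c, d)
    # One layer per partition; each mini-batch is split into 8 micro-batches.
    model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

    for input in data_loader:
        output = model(input)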
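To make the two arguments concrete: `balance` lists how many consecutive layers go to each partition, and `chunks` is the number of micro-batches each mini-batch is split into. The sketch below mimics that bookkeeping in plain Python. `split_layers` and `split_batch` are hypothetical helpers written for illustration only; they are not part of torchgpipe and may differ from how the library splits tensors.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_layers(layers: Sequence[T], balance: Sequence[int]) -> List[List[T]]:
    """Group consecutive layers into partitions whose sizes are given by balance."""
    if sum(balance) != len(layers):
        raise ValueError("balance must sum to the number of layers")
    partitions, start = [], 0
    for size in balance:
        partitions.append(list(layers[start:start + size]))
        start += size
    return partitions


def split_batch(batch: Sequence[T], chunks: int) -> List[List[T]]:
    """Split a mini-batch into at most `chunks` micro-batches of near-equal size."""
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    if not batch:
        return []
    size = -(-len(batch) // chunks)  # ceiling division
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


# Four placeholder layers, one per partition (mirrors balance=[1, 1, 1, 1]).
partitions = split_layers(["a", "b", "c", "d"], balance=[1, 1, 1, 1])
# A mini-batch of 32 samples split into 8 micro-batches of 4 samples each.
micro_batches = split_batch(list(range(32)), chunks=8)

print(partitions)          # [['a'], ['b'], ['c'], ['d']]
print(len(micro_batches))  # 8
```

With `balance=[2, 2]` the same four layers would be grouped into two partitions of two layers each.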
What is GPipe?
GPipe is a scalable pipeline-parallelism library published by Google Brain that enables efficient training of large, memory-intensive models. According to the paper, GPipe can train a 25x larger model using 8x as many devices (TPUs), and can train a model 3.5x faster using 4x as many devices.
GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
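One way to see why splitting a mini-batch into micro-batches helps: in an idealized pipeline where each of K partitions takes one time step per micro-batch, M micro-batches finish the forward pass in M + K - 1 steps, so each device idles for K - 1 of them. The paper gives the resulting bubble overhead as O((K - 1) / (M + K - 1)). The sketch below computes that fraction and prints the forward schedule. The function names and the uniform-cost assumption are illustrative and do not describe how torchgpipe schedules work internally.

```python
from typing import List, Optional


def bubble_fraction(partitions: int, micro_batches: int) -> float:
    """Idle share of an idealized pipeline: (K - 1) / (M + K - 1)."""
    return (partitions - 1) / (micro_batches + partitions - 1)


def forward_schedule(partitions: int, micro_batches: int) -> List[List[Optional[int]]]:
    """For each time step, the micro-batch index each partition works on (None = idle)."""
    steps = micro_batches + partitions - 1
    return [
        [t - k if 0 <= t - k < micro_batches else None for k in range(partitions)]
        for t in range(steps)
    ]


# Rows are time steps, columns are partitions, "-" marks an idle device.
for row in forward_schedule(partitions=4, micro_batches=8):
    print(" ".join("-" if m is None else str(m) for m in row))

print(round(bubble_fraction(4, 8), 3))   # 0.273
print(round(bubble_fraction(4, 32), 3))  # 0.086
```

Under these assumptions, raising the number of micro-batches from 8 to 32 on four partitions shrinks the idle share from about 27% to about 9%.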
Google used GPipe to train AmoebaNet-B with 557M parameters. The model achieved 84.3% top-1 and 97.0% top-5 accuracy on the ImageNet classification benchmark, the state of the art as of May 2019.
Links
Source Code: https://github.com/kakaobrain/torchgpipe
Documentation: https://torchgpipe.readthedocs.io/
Original Paper: https://arxiv.org/abs/1811.06965
Built Distribution
Hashes for torchgpipe-0.0.7-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 052c704f8c03b7695110e81853e4166c1fc1db3ad1e4cc0fa4eceb8a63e32627
MD5 | 75a817456ed1a0fe59b76a5abc6bfc45
BLAKE2b-256 | f48497f3c3b27b666de92477dda6425dd7cb56c7bbaeb115c3a8a1ec7dbe8e05