
PyTorch implementation of VQ-VAE

Project description

A PyTorch implementation of the VQ-VAE (vector-quantized variational autoencoder).
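For reference, the VQ-VAE paper ("Neural Discrete Representation Learning", van den Oord et al., 2017) trains the model with the three-term objective below, written here as a loss to minimize. $z_e(x)$ is the encoder output, $e$ the nearest codebook vector, $z_q(x)$ the quantized latent passed to the decoder, and $\operatorname{sg}[\cdot]$ the stop-gradient operator. The paper uses $\beta = 0.25$; that the `beta` argument in the example below plays this role is an assumption based on its name and default value, not something this page documents.

$$
\mathcal{L} = -\log p\big(x \mid z_q(x)\big) + \big\lVert \operatorname{sg}[z_e(x)] - e \big\rVert_2^2 + \beta \, \big\lVert z_e(x) - \operatorname{sg}[e] \big\rVert_2^2
$$

The second term moves the codebook toward the encoder outputs; the third (the commitment loss) keeps the encoder outputs close to the codebook.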

Example

```python
from functools import partial

from torch.optim import Adam
from vqvae import VQVAE, sequential_encoder, sequential_decoder

input_channels = 3       # e.g. RGB images
output_channels = 3
embedding_length = 256
hidden_channels = 64
beta = 0.25              # commitment-loss weight (0.25 in the VQ-VAE paper)
embedding_size = 512
opt = partial(Adam, lr=2e-4)  # Adam with a fixed learning rate; parameters are supplied later

encoder = sequential_encoder(input_channels, embedding_size, hidden_channels)  # Encoder from the paper
decoder = sequential_decoder(embedding_size, output_channels, hidden_channels)  # Decoder from the paper
# VQVAE is a PyTorch Lightning module, so it can be trained directly with a Lightning Trainer.
vqvae = VQVAE(encoder, decoder, opt, beta, embedding_length, embedding_size)
```
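This page does not describe the package's internals, so the following is only a minimal sketch, assuming a standard VQ-VAE quantizer as described in the paper. The class `VectorQuantizer` and its arguments are hypothetical and not part of the `vqvae` API. It shows the nearest-codebook lookup, the codebook loss plus the `beta`-weighted commitment loss, and the straight-through gradient estimator. It uses mean squared error instead of the squared L2 norm, which only rescales the loss terms.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """Hypothetical VQ layer (not this package's class): replaces each
    encoder vector with its nearest codebook entry."""

    def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25):
        super().__init__()
        self.beta = beta
        # Codebook of num_embeddings vectors, each of size embedding_dim.
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.codebook.weight, -1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, z_e: torch.Tensor):
        # z_e: encoder output of shape (B, D, H, W); flatten to (B*H*W, D).
        b, d, h, w = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, d)

        # Squared Euclidean distance to every codebook vector: (B*H*W, K).
        weight = self.codebook.weight
        distances = (
            flat.pow(2).sum(dim=1, keepdim=True)
            - 2 * flat @ weight.t()
            + weight.pow(2).sum(dim=1)
        )
        indices = distances.argmin(dim=1)
        z_q = self.codebook(indices).view(b, h, w, d).permute(0, 3, 1, 2)

        # Codebook loss pulls embeddings toward the (frozen) encoder outputs;
        # the commitment loss, weighted by beta, pulls encoder outputs toward the codebook.
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        commitment_loss = F.mse_loss(z_e, z_q.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the forward value is z_q, while the
        # decoder's gradient is copied unchanged to z_e in the backward pass.
        z_q_st = z_e + (z_q - z_e).detach()
        return z_q_st, vq_loss, indices.view(b, h, w)
```

In a full model, `vq_loss` would be added to the reconstruction loss, and the decoder would receive the returned quantized tensor.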



Download files

Download the file for your platform.

Source distribution: vqvae-1.0.2.tar.gz (11.6 kB)

Built distribution: vqvae-1.0.2-py3-none-any.whl (1.6 kB, Python 3)
