PyTorch implementation of VQ-VAE
Project description
A PyTorch implementation of the VQ-VAE model from "Neural Discrete Representation Learning" (van den Oord et al., 2017).
Example
from functools import partial

from torch.optim import Adam

from vqvae import VQVAE, sequential_encoder, sequential_decoder

input_channels = 3
output_channels = 3
hidden_channels = 64
embedding_length = 256
embedding_size = 512
beta = 0.25

opt = partial(Adam, lr=2e-4)
encoder = sequential_encoder(input_channels, embedding_size, hidden_channels)  # Encoder from the paper
decoder = sequential_decoder(embedding_size, output_channels, hidden_channels)  # Decoder from the paper

# PyTorch-Lightning module, hence usable to train the model
vqvae = VQVAE(encoder, decoder, opt, beta, embedding_length, embedding_size)
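For intuition about what the model does with the codebook and the `beta` parameter, here is a minimal, framework-free sketch of the vector-quantisation step a VQ-VAE performs internally. The helper names `quantize` and `vq_losses` are hypothetical and not part of the vqvae API; stop-gradient behaviour only matters under autodiff, so this plain-Python version just shows the distances involved.

```python
def quantize(z, codebook):
    """Return the codebook vector nearest to z (squared Euclidean distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(codebook, key=lambda e: sq_dist(z, e))

def vq_losses(z, e, beta=0.25):
    """Codebook loss ||sg[z] - e||^2 and commitment loss beta * ||z - sg[e]||^2.

    Numerically both are the same squared distance; they differ only in
    which side the gradient flows through and in the beta scaling.
    """
    d = sum((x - y) ** 2 for x, y in zip(z, e))
    return d, beta * d

# Toy codebook of three 2-d vectors (embedding_length=3, embedding_size=2).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
z = [0.9, 1.2]                       # an encoder output
e = quantize(z, codebook)            # nearest codebook entry
codebook_loss, commit_loss = vq_losses(z, e, beta=0.25)
```

In the actual module, these quantities are computed batch-wise on tensors and summed with the reconstruction loss; `beta` weights how strongly the encoder is pulled toward its chosen codebook vector.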
Download files
Source distribution: vqvae-1.0.2.tar.gz (11.6 kB)
Built distribution: vqvae-1.0.2-py3-none-any.whl (1.6 kB)