Skip to main content

Pytorch implementation of Multi-Object Network(MONet)

Project description

Pytorch MONet implementation

Python build package and test PyPI - Python Version PyPI - Downloads PyPI - Status

Pytorch implementation of Multi-Object Network(MONet)

monet_architecture

How to install

You can install through pip with the following command

pip install monet-pytorch

or clone this repository locally and install with poetry

git clone https://github.com/Michedev/MONet-pytorch
cd MONet-pytorch
poetry install

How to use

The package comes with a set of predefined configurations based on paper specifications, namely monet and monet-iodine (MONet as defined in IODINE paper).

from monet_pytorch import Monet

monet = Monet.from_config(model='monet')

There is also another custom architecture monet-lightweight which has less parameters than the original ones.

Furthermore, the model architecture slightly changes based on the dataset (e.g. U-Net blocks) picked between the ones cited in MONet paper (CLEVR 6, Multidsprites colored on colored, Multidsprited colored on grayscale, Tetrominoes).

from monet_pytorch import Monet

monet = Monet.from_config(model='monet', dataset='tetrominoes')

In alternative, you can set custom dataset parameters through the function arguments

from monet_pytorch import Monet

monet = Monet.from_config(model='monet', dataset_width=48, dataset_height=48, scene_max_objects=6)

Lastly, you can make your custom MONet by input your custom configuration as OmegaConf DictConfig

from monet_pytorch import Monet

custom_monet_config = OmegaConf.create("""
dataset:
  width: 44
  height: 44
  max_num_objects: 10
model:  #this config file follows MONet implementation from IODINE paper
  _target_: monet_pytorch.model.Monet
  height: ${dataset.height}
  width: ${dataset.width}
  num_slots: ${dataset.max_num_objects}
  name: monet-iodine
  bg_sigma: 0.32
  fg_sigma: 0.1
  beta_kl: 0.43
  gamma: 0.5
  latent_size: 16
  input_channels: 3
  encoder:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: monet_pytorch.template.sequential_cnn.make_sequential_cnn_from_config
        channels: [44, 44, 32, 14]
        kernels: 3
        strides: 2
        paddings: 0
        input_channels: 4
        batchnorms: true
        bn_affines: false
        activations: relu
      - _target_: torch.nn.Flatten
        start_dim: 1
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: 256
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: ${prod:${model.latent_size},2}
  decoder:
    _target_: monet_pytorch.template.encoder_decoder.BroadcastDecoderNet
    w_broadcast: ${sum:${dataset.width},8}
    h_broadcast: ${sum:${dataset.height},8}
    net:
      _target_: monet_pytorch.template.sequential_cnn.make_sequential_cnn_from_config
      input_channels: ${sum:${model.latent_size},2} # latent size + 2
      channels: [32, 32, 64, 64, 4]  # last is 4 channels because rgb (3) + mask (1)
      kernels: [3, 3, 3, 3, 1]
      paddings: 0
      activations: [relu, relu, relu, relu, null]  #null means no activation function no activation
      batchnorms: [true, true, true, true, false]
      bn_affines: [false, false, false, false, false]
  unet:
    _target_: monet_pytorch.unet.UNet
    input_channels: ${model.input_channels}
    num_blocks: 5
    filter_start: 16
    mlp_size: 128""")

custom_monet: Monet = Monet.from_custom_config(custom_monet_config)

Model performances

This implementation reproduce very closely ARI MONet's values

Special thanks

I would like to thank @apra and @addtt for the help to fix code bugs and to improve model performances

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

monet-pytorch-1.1.0.tar.gz (15.1 kB view hashes)

Uploaded Source

Built Distribution

monet_pytorch-1.1.0-py3-none-any.whl (19.7 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page