Skip to main content

Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf

Project description

GCViT: Global Context Vision Transformer

python tensorflow

Open In Colab Open In Kaggle

Tensorflow 2.0 Implementation of GCViT

This library implements GCViT using Tensorflow 2.0 specifically in tf.keras.Model manner to get PyTorch flavor.

Update

Model

  • Architecture:
  • Local Vs Global Attention:

Result

Official codebase had some issue which has been fixed recently (12 August 2022). Here's the result of ported weights on ImageNetV2-Test data,

Model Acc@1 Acc@5 #Params
GCViT-XXTiny 0.663 0.873 12M
GCViT-XTiny 0.685 0.885 20M
GCViT-Tiny 0.708 0.899 28M
GCViT-Small 0.720 0.901 51M
GCViT-Base 0.731 0.907 90M
GCViT-Large 0.734 0.913 202M

Installation

pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf

Usage

Load model using following codes,

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

Simple code to check model's prediction,

from skimage.data import chelsea
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 512)

For feature map:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)

Feature map:

(None, 7, 7, 512)

Kaggle Models

These pre-trained models can also be loaded using Kaggle Models. Setting from_kaggle=True will enforce model to load weights from Kaggle Models with downloading, thus can be used without internet in Kaggle.

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True, from_kaggle=True)

Live-Demo

  • For live demo on Image Classification & Grad-CAM, with ImageNet weights, click powered by 🤗 Space and Gradio. here's an example,

Example

For working training example checkout these notebooks on Google Colab Open In Colab & Kaggle Open In Kaggle.

Here is grad-cam result after training on Flower Classification Dataset,

To Do

  • Remove tensorflow_addons
  • Segmentation Pipeline
  • Support for Kaggle Models
  • New updated weights have been added.
  • Working training example in Colab & Kaggle.
  • GradCAM showcase.
  • Gradio Demo.
  • Build model with tf.keras.Model.
  • Port weights from official repo.
  • Support for TPU.

Acknowledgement

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gcvit-1.1.2.tar.gz (15.9 kB view hashes)

Uploaded Source

Built Distribution

gcvit-1.1.2-py3-none-any.whl (17.1 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page