Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf

GCViT: Global Context Vision Transformer

Tensorflow 2.0 Implementation of GCViT

This library implements GCViT in TensorFlow 2.0, built on tf.keras.Model so that the code has a PyTorch-like feel.

Model

  • Architecture:
  • Local Vs Global Attention:

Result

The official codebase had some issues, which were fixed recently (27 July 2022). Here are the results of the ported weights on ImageNetV2-Test data:

Model         Acc@1 (%)  Acc@5 (%)  #Params
GCViT-XXTiny  63         85         12M
GCViT-XTiny   66         87         20M
GCViT-Tiny    69         89         28M
GCViT-Small   69         89         51M
GCViT-Base    71         90         90M
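
For readers unfamiliar with the metrics, Acc@1 and Acc@5 are top-1 and top-5 accuracy: a prediction counts as correct if the true label is among the k highest-scoring classes. The sketch below is a minimal NumPy illustration on made-up scores. The function name top_k_accuracy and the toy data are hypothetical; they are not part of gcvit, nor the evaluation script used for the table.

```python
import numpy as np

def top_k_accuracy(probs, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    # Indices of the k largest scores per row (order inside the top k is irrelevant).
    top_k = np.argpartition(probs, -k, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Toy scores for 4 samples over 6 classes (made-up numbers, not model output).
probs = np.array([
    [0.60, 0.10, 0.10, 0.10, 0.05, 0.05],
    [0.10, 0.20, 0.50, 0.10, 0.05, 0.05],
    [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],
    [0.05, 0.05, 0.10, 0.20, 0.25, 0.35],
])
labels = np.array([0, 1, 5, 5])

acc1 = top_k_accuracy(probs, labels, k=1)
acc5 = top_k_accuracy(probs, labels, k=5)
print(f"Acc@1 = {acc1:.2f}, Acc@5 = {acc5:.2f}")
```

With real data, probs would be the model's output for a batch and labels the ground-truth class indices.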

Installation

pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf

Usage

Load a model with the following code:

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

A simple check of the model's prediction:

import tensorflow as tf
from skimage.data import chelsea

# Chelsea the cat, normalized with ImageNet statistics ('torch' mode)
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch')
img = tf.image.resize(img, (224, 224))[None,]  # resize & add batch dimension
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]
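
The preprocess_input(..., mode='torch') call above scales pixels to [0, 1] and then normalizes each channel with the ImageNet mean and standard deviation. A minimal NumPy sketch of the same arithmetic follows, for illustration only. The helper name preprocess_torch_mode is hypothetical, and a dummy white image stands in for a real photo.

```python
import numpy as np

# ImageNet channel statistics (RGB order) used by the 'torch' preprocessing mode.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_torch_mode(image_uint8):
    """Scale RGB pixels from [0, 255] to [0, 1], then normalize per channel."""
    x = image_uint8.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD

# A dummy all-white 4x4 RGB image stands in for a real photo.
white = np.full((4, 4, 3), 255, dtype=np.uint8)
out = preprocess_torch_mode(white)
print(out.shape, out[0, 0])
```

Feeding raw [0, 255] pixels without this normalization is a common cause of poor predictions from pretrained weights.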

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 512)

For feature map:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)

Feature map:

(None, 7, 7, 512)
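
The two shapes above differ only in the spatial dimensions: the (7, 7, 512) feature map collapses to a 512-dimensional vector once the spatial axes are pooled. The sketch below assumes global average pooling, which is the usual choice for this kind of backbone but is not confirmed here for gcvit. A random array stands in for the output of forward_features.

```python
import numpy as np

# Stand-in for a batch of 2 feature maps with shape (batch, height, width, channels).
rng = np.random.default_rng(0)
feature_map = rng.standard_normal((2, 7, 7, 512)).astype(np.float32)

# Global average pooling: average over the two spatial axes, keep batch and channels.
pooled = feature_map.mean(axis=(1, 2))
print(feature_map.shape, "->", pooled.shape)
```

In Keras, the equivalent layer is tf.keras.layers.GlobalAveragePooling2D.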

Live-Demo

  • A live demo of image classification and Grad-CAM with ImageNet weights is available, powered by 🤗 Space and Gradio.

Example

For a working training example, check out the notebooks on Google Colab and Kaggle.

Here is the Grad-CAM result after training on the Flower Classification Dataset:
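
For context, Grad-CAM builds its heatmap by weighting each channel of a feature map by the spatially averaged gradient of the class score, summing over channels, and keeping only the positive part. The following is a minimal NumPy sketch of that computation, not the code used in the notebooks. The function name grad_cam is hypothetical, and random arrays stand in for real activations and gradients, which in practice would come from the model (for example via tf.GradientTape).

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM heatmap for one image.

    activations, gradients: arrays of shape (H, W, C), where gradients holds
    the derivative of the class score with respect to the activations.
    """
    # Channel weights: gradients averaged over the spatial dimensions.
    weights = gradients.mean(axis=(0, 1))                         # (C,)
    # Weighted sum over channels, then ReLU to keep positive evidence only.
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)   # (H, W)
    # Normalize to [0, 1] for display; leave an all-zero map untouched.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Random stand-ins with the same shape as a (7, 7, 512) feature map.
rng = np.random.default_rng(0)
acts = rng.random((7, 7, 512))
grads = rng.standard_normal((7, 7, 512))
heatmap = grad_cam(acts, grads)
print(heatmap.shape)
```

The resulting 7x7 map is typically resized to the input resolution and overlaid on the image.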

To Do

  • New updated weights have been added.
  • Working training example in Colab & Kaggle.
  • GradCAM showcase.
  • Gradio Demo.
  • Build model with tf.keras.Model.
  • Port weights from official repo.
  • Support for TPU.

Acknowledgement

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}
