
Explanation techniques for Transformer-based architectures


.. -*- mode: rst -*-

|pypi_version|_ |pypi_downloads|_

.. |pypi_version| image:: https://img.shields.io/pypi/v/explainable-transformers.svg
.. _pypi_version: https://pypi.python.org/pypi/explainable-transformers/

.. |pypi_downloads| image:: https://pepy.tech/badge/explainable-transformers/month
.. _pypi_downloads: https://pepy.tech/project/explainable-transformers

.. image:: artwork/cover.png
   :alt: Vision Transformers explanation

========================
explainable-transformers
========================

Explanation and interpretation techniques for Transformer-based architectures.


Installation
------------

Requirements:

- opencv-python
- numpy
- torch
- tqdm

.. code:: bash

    pip install explainable-transformers

Usage examples
--------------

See ``notebook/`` for complete examples of how to create visual representations of the explanations.

For Vision Transformers, wrap a PyTorch model with ``VisionTransformerWrapper``:

.. code:: python

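    import torch
    import torch.nn as nn
    from PIL import Image
    from torchvision import transforms  # assumption: torchvision is not a listed requirement
    from transformers import ViTModel

    # import the explainer module
    from explainable_transformers.image_explainer import VisionTransformerWrapper

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # define the last layer for classification
    class PreTrainedViT(nn.Module):
        def __init__(self, vit_model, d_model, classes):
            ...

        def forward(self, x):
            ...


    # load the pre-trained model; output_attentions=True exposes the attention maps
    pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k',
                                                    add_pooling_layer=False, output_attentions=True)

    model = PreTrainedViT(pretrained_vit_model, d_model=768, classes=10)

    # create the ViT wrapper and register the hooks on its 12 attention layers
    vit_wrapper = VisionTransformerWrapper(model, device, num_attn_layers=12)
    vit_wrapper.register_hook()

    # assumed preprocessing: 224x224 resize and 0.5 mean/std normalization for this checkpoint
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # explain a prediction using .generate_visualization(img)
    image = Image.open('images/dogbird.png').convert('RGB')
    processed_image = transform(image)
    cat_exp, _ = vit_wrapper.generate_visualization(processed_image)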
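
For ``google/vit-base-patch16-224-in21k``, a 224x224 input is cut into 16x16 patches, giving a 14x14 grid of 196 patch tokens plus a [CLS] token. Assuming an explanation is available as one relevance score per patch token (the exact format returned by ``generate_visualization`` is not documented here), the minimal NumPy sketch below shows one way to turn those scores into an image-sized heatmap. ``patch_relevance_to_heatmap`` is a hypothetical helper, not part of the library.

.. code:: python

    import numpy as np


    def patch_relevance_to_heatmap(relevance, patch_size=16):
        """Hypothetical helper: map per-patch relevance scores to a pixel heatmap.

        Assumes `relevance` holds one score per patch token (no [CLS] entry)
        and that the patches form a square grid.
        """
        relevance = np.asarray(relevance, dtype=np.float64)
        grid = int(round(np.sqrt(relevance.size)))
        if grid * grid != relevance.size:
            raise ValueError("expected a square number of patch scores")

        # Lay the scores out on the patch grid (14x14 for ViT-B/16 at 224x224).
        patch_map = relevance.reshape(grid, grid)

        # Nearest-neighbour upsampling: each score fills a patch_size x patch_size block.
        heatmap = np.kron(patch_map, np.ones((patch_size, patch_size)))

        # Min-max normalise to [0, 1] so the map can be overlaid on the image.
        span = heatmap.max() - heatmap.min()
        if span > 0:
            heatmap = (heatmap - heatmap.min()) / span
        else:
            heatmap = np.zeros_like(heatmap)
        return heatmap


    # Example: random scores for the 196 patches of a 224x224 image.
    scores = np.random.default_rng(0).random(196)
    heatmap = patch_relevance_to_heatmap(scores)
    print(heatmap.shape)  # (224, 224)

Nearest-neighbour upsampling keeps the sketch dependency-free; a smoother bilinear resize (for example ``cv2.resize`` with ``cv2.INTER_LINEAR``, since OpenCV is already a requirement) is a common alternative.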

For text Transformers, the wrapper currently needs to know where the attention component sits inside the model, given as a module path.

.. code:: python

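    # first the imports
    import torch
    from transformers import BertTokenizer, BertForSequenceClassification

    from explainable_transformers.utils import *
    from explainable_transformers import NLPTransformerWrapper

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_LAYERS = 12  # bert-base has 12 encoder layers

    # example model and input (checkpoint, label count, sentence and class are placeholders)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
    encoding = tokenizer('A sentence to explain.', return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    true_class = 1

    # for text, we provide the NLP wrapper; the attention component is accessed as:
    #   - BERT or RoBERTa: '.encoder.layer.#.attention.self.dropout'
    #   - XLNet: '.layer.#.rel_attn.dropout'
    nlp_wrapper = NLPTransformerWrapper(model, device, NUM_LAYERS, 'bert', 'classifier',
                                        '.encoder.layer.#.attention.self.dropout')
    nlp_wrapper.register_hook()

    explanation = nlp_wrapper.generate_explanation(input_ids, attention_mask,
                                                   class_index=true_class, start_layer=NUM_LAYERS-1)
    explanation = explanation.detach().cpu().numpy()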
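
The ``#`` in the module path appears to act as a placeholder for the layer index. Assuming that reading, and assuming the path is resolved relative to the model's base sub-module (``bert`` in the call above), the hypothetical helper ``expand_layer_paths`` below expands the template into one module name per layer. It only illustrates the naming scheme and is not the library's internal code.

.. code:: python

    def expand_layer_paths(template, num_layers, prefix="bert", placeholder="#"):
        """Hypothetical helper: turn a path template into per-layer module names.

        Assumes `placeholder` marks the layer index and that the template is
        relative to the sub-module named `prefix` (e.g. ``model.bert``).
        """
        if placeholder not in template:
            raise ValueError(f"template has no {placeholder!r} placeholder")
        # One fully qualified module name per layer, indices 0 .. num_layers - 1.
        return [prefix + template.replace(placeholder, str(i)) for i in range(num_layers)]


    paths = expand_layer_paths(".encoder.layer.#.attention.self.dropout", num_layers=12)
    print(paths[0])   # bert.encoder.layer.0.attention.self.dropout
    print(paths[-1])  # bert.encoder.layer.11.attention.self.dropout

With a loaded PyTorch model, each name can then be looked up in ``dict(model.named_modules())`` and given a hook through ``register_forward_hook``.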

Citation
--------

If you use any of these techniques, please cite the respective authors.

Currently, we provide PyTorch implementations of the following approaches:

Transformer Interpretability Beyond Attention Visualization (`paper <https://arxiv.org/abs/2012.09838>`_):

1. Text Transformers: BERT, RoBERTa, and XLNet
2. Vision Transformers

.. code:: bibtex

    @InProceedings{Chefer_2021_CVPR,
        author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
        title     = {Transformer Interpretability Beyond Attention Visualization},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2021},
        pages     = {782-791}
    }
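
Roughly, the approach by Chefer et al. listed above weights each layer's attention relevance by its gradient with respect to the target class, keeps the positive part, averages over heads, adds the identity matrix to account for skip connections, and multiplies the per-layer maps together. The NumPy sketch below shows only this aggregation step, under a stated simplification: it weights the raw attention maps by their gradients, whereas the paper uses relevance maps obtained with layer-wise relevance propagation (LRP). The function name and array shapes are assumptions, and ``start_layer`` is modelled loosely on the argument of the same name in the text example.

.. code:: python

    import numpy as np


    def gradient_weighted_rollout(attentions, gradients, start_layer=0):
        """Aggregate per-layer attention into token-to-token relevance.

        attentions, gradients: lists with one array per layer, each of shape
        (num_heads, num_tokens, num_tokens). Simplification: the raw attention
        map stands in for the LRP relevance map used in the paper.
        """
        num_tokens = attentions[0].shape[-1]
        identity = np.eye(num_tokens)
        joint = None
        for attn, grad in zip(attentions[start_layer:], gradients[start_layer:]):
            # Keep positive gradient-weighted contributions, then average the heads.
            layer_map = np.clip(grad * attn, 0.0, None).mean(axis=0)
            # Add the identity to account for the residual (skip) connection.
            layer_map = layer_map + identity
            # Compose with earlier layers: later layers multiply on the left.
            joint = layer_map if joint is None else layer_map @ joint
        return joint


    rng = np.random.default_rng(0)
    layers, heads, tokens = 12, 12, 5
    attns = [rng.random((heads, tokens, tokens)) for _ in range(layers)]
    grads = [rng.standard_normal((heads, tokens, tokens)) for _ in range(layers)]

    relevance = gradient_weighted_rollout(attns, grads, start_layer=layers - 1)
    cls_scores = relevance[0, 1:]  # contribution of each non-[CLS] token to [CLS]
    print(relevance.shape)  # (5, 5)

With ``start_layer`` set to the last layer, only that layer's map (plus the identity) is used; smaller values compose more layers.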

License
-------

explainable-transformers is released under the 3-clause BSD license and builds on other open-source implementations, in particular `Chefer's Transformer-Explainability <https://github.com/hila-chefer/Transformer-Explainability>`_.

We also use `nlp_understanding <https://github.com/ENSAE-CKW/nlp_understanding>`_ to generate the heatmaps.

E-mail me (wilson_jr at outlook dot com) if you would like to contribute.
