
MEGAN: Multi Explanation Graph Attention Network


Architecture Overview

MEGAN: Multi Explanation Graph Attention Student

Explainable artificial intelligence (XAI) methods are expected to improve trust during human-AI interactions, provide tools for model analysis and extend human understanding of complex problems. Attention-based models are an important subclass of XAI methods, partly due to their full differentiability and the potential to improve explanations by means of explanation-supervised training. We propose the novel multi-explanation graph attention network (MEGAN). Our graph regression and classification model features multiple explanation channels, which can be chosen independently of the task specifications. We first validate our model on a synthetic graph regression dataset, where our model produces single-channel explanations with quality similar to GNNExplainer. Furthermore, we demonstrate the advantages of multi-channel explanations on one synthetic and two real-world datasets: The prediction of water solubility of molecular graphs and sentiment classification of movie reviews. We find that our model produces explanations consistent with human intuition, opening the way to learning from our model in less well-understood tasks.

Installation

Main Installation

Clone the repository from GitHub:

git clone https://github.com/awa59kst120df/graph_attention_student.git

Then in the main folder run a pip install:

cd graph_attention_student
pip3 install -e .

Afterwards, you can verify the installation by invoking the CLI:

python3 -m graph_attention_student.cli --version
python3 -m graph_attention_student.cli --help

Usage

Computational Experiments

It is possible to list, show and execute all the computational experiments using the command line interface (CLI).

NOTE: Most of the experiments have a long runtime, ranging from ~2 hours to ~2 days. Furthermore, all of the experiments that perform model training are currently configured to run on a GPU and might crash if the GPU cannot be detected or does not have enough VRAM. This setting can be changed in the corresponding experiment scripts.

All the available experiments can be listed like this:

python3 -m graph_attention_student.cli list

The details for a specific experiment can be viewed like this:

python3 -m graph_attention_student.cli info [experiment_name]

A new run of an experiment can be started like this. However, be aware that this might take some time.

python3 -m graph_attention_student.cli run [experiment_name]

Each experiment will create a new archive folder, which will contain all the artifacts (such as visual examples and the raw data) created during the run. The location of this archive folder can be found in the output generated during the experiment execution.

Archived Experiments

To view the detailed data used in the making of the paper, go to graph_attention_student/experiments. The subfolders in that folder contain the archived experiments. These include extensive examples for each repetition of the various experiments as well as all of the raw data collected during their execution.

MEGAN in code

The MEGAN model is implemented as the Megan class, which subclasses keras.Model. The implementation is based on the kgcnn library, which provides graph convolutional network layers for Keras. For further information on loading graph-structured data with kgcnn, visit: https://github.com/aimat-lab/gcnn_keras

This is a simple example of how to use the model in the regression case:

import tensorflow as tf
import tensorflow.keras as ks
from graph_attention_student.training import NoLoss
from graph_attention_student.models import Megan

model = Megan(
    # These lists define the number of layers and the number of hidden units per layer for the
    # various parts of the architecture.
    units=[9, 9, 9],  # The main graph convolution (attention) layers
    importance_units=[],  # The MLP that creates the node importances
    final_units=[5, 1],  # The final MLP; the last value has to match the output dimension (here 1)
    # Example for a regression problem: we need prior knowledge about the range in which the
    # target values of the dataset are expected to fall...
    regression_limits=(-3, +3),
    # ... as well as a reference value.
    regression_reference=0,
    # This controls the weight of the explanation-only train step (gamma)
    importance_factor=1.0,
    importance_multiplier=5,
    # This is the weight of the sparsity regularization
    sparsity_factor=0.1,
)

# The model output is actually a three tuple: (prediction, node_importances, edge_importances).
# This allows the importances to be trained in a supervised fashion. If we don't want that,
# we can simply supply the NoLoss function instead.
model.compile(
    loss=[ks.losses.MeanSquaredError(), NoLoss(), NoLoss()],
    loss_weights=[1, 1, 1],
    optimizer=ks.optimizers.Adam(0.001)
)

# model.fit() ...
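As a rough illustration, the snippet below sketches how graph inputs could be fed to the model, reusing the tf import and the model object from the example above. It assumes the common kgcnn convention of ragged tensors for node attributes, edge attributes and edge index pairs; the exact input order and shapes expected by Megan may differ, so treat this as a sketch rather than a reference:

# NOTE: hypothetical input preparation, assuming the usual kgcnn ragged-tensor
# convention (node attributes, edge attributes, edge index pairs per graph).
node_input = tf.ragged.constant([
    [[0.1] * 9, [0.2] * 9, [0.3] * 9],   # graph 1: 3 nodes with 9 features each
    [[0.4] * 9, [0.5] * 9],              # graph 2: 2 nodes
], ragged_rank=1, dtype=tf.float32)

edge_indices = tf.ragged.constant([
    [[0, 1], [1, 0], [1, 2], [2, 1]],    # graph 1: undirected edges as directed index pairs
    [[0, 1], [1, 0]],                    # graph 2
], ragged_rank=1, dtype=tf.int64)

edge_input = tf.ragged.constant([
    [[1.0]] * 4,                          # one scalar feature per edge
    [[1.0]] * 2,
], ragged_rank=1, dtype=tf.float32)

# A forward pass returns the three-tuple described above
prediction, node_importances, edge_importances = model([node_input, edge_input, edge_indices])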

Examples

The following are some cherry-picked examples that showcase the explanatory capabilities of the model.

RB-Motifs Dataset

This is a synthetic dataset, which consists of randomly generated graphs with nodes of different colors. Some of the graphs contain special sub-graph motifs, which are either blue-heavy or red-heavy structures. The blue-heavy sub-graphs contribute a certain negative value to the overall value of the graph, while red-heavy structures contribute a certain positive value.

This way, every graph has a certain value associated with it, which is between -3 and 3. The network was trained to predict this value for each graph.
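To make this labeling rule concrete, here is a purely illustrative sketch in which the motif names and contribution values are invented and do not reflect the actual dataset:

# Purely illustrative: motif names and contribution values are made up here.
MOTIF_CONTRIBUTIONS = {
    'red_motif': +1.5,    # red-heavy structures add a positive contribution
    'blue_motif': -1.5,   # blue-heavy structures add a negative contribution
}

def graph_value(contained_motifs):
    # The target value of a graph is the sum of the contributions of its motifs
    return sum(MOTIF_CONTRIBUTIONS[m] for m in contained_motifs)

print(graph_value(['red_motif', 'blue_motif']))  # 0.0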

Rb-Motifs Example

The example shows, from left to right: (1) the ground truth explanations, (2) a baseline MEGAN model trained only on the prediction task, (3) an explanation-supervised MEGAN model and (4) GNNExplainer explanations for a basic GCN network. While the baseline MEGAN and GNNExplainer focus on only one of the ground truth motifs, the explanation-supervised MEGAN model correctly finds both.

Water Solubility Dataset

This is the AqSolDB dataset, which consists of ~10,000 molecules together with their measured water solubility (logS) values.

The network was trained to predict the solubility value for each molecule.

Solubility Example

Movie Reviews

The MovieReviews dataset is originally a natural language processing dataset from the ERASER benchmark. The task is to classify the sentiment of ~2000 movie reviews collected from the IMDB database into the classes “positive” and “negative”. This dataset was converted into a graph dataset by considering all words as nodes of a graph and then connecting adjacent words by undirected edges with a sliding window of size 2. Words were converted into numeric feature vectors by using a pre-trained GloVe model.
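A rough sketch of such a conversion could look like the following. The tokenization, window handling and GloVe lookup used for the actual dataset may differ, and the embed() function here is only a hypothetical stand-in for a real GloVe embedding:

from typing import Callable, List, Tuple

def review_to_graph(tokens: List[str],
                    embed: Callable[[str], List[float]]) -> Tuple[list, list]:
    # One node per word, represented by its (stand-in) word embedding vector
    node_attributes = [embed(token) for token in tokens]
    # Connect neighboring words with undirected edges, stored as index pairs in
    # both directions (one reading of a "sliding window of size 2")
    edge_indices = []
    for i in range(len(tokens) - 1):
        edge_indices.append([i, i + 1])
        edge_indices.append([i + 1, i])
    return node_attributes, edge_indices

# Usage with a dummy embedding that maps each word to a single feature
nodes, edges = review_to_graph(['a', 'truly', 'great', 'movie'],
                               embed=lambda word: [float(len(word))])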

Example for a positive review:

Positive Movie Review

Example for a negative review:

Negative Movie Review

The examples show the explanation channel for the “negative” class on the left and the “positive” class on the right. Sentences with negative / positive adjectives are appropriately attributed to the corresponding channels.

