
Counterfactual explanations for GNNs based on the visual graph dataset format


Made with Python (Python 3.8)


VGD Counterfactuals

A library for generating and, more importantly, easily visualizing counterfactuals for Graph Neural Networks (GNNs) based on the VisualGraphDatasets dataset format.

What are Counterfactuals?

Counterfactuals are a method of explaining the predictions of complex machine learning models. For a given model prediction, a counterfactual is an input element that is as similar as possible to the original input but causes the largest possible deviation from the original prediction. Counterfactuals act as a kind of "counterexample" to the model's behavior and can help to understand its decision boundary.

This package deals with graph counterfactuals. They are generated by maximizing a customizable distance function on the prediction output over all immediate neighbors of the original graph, where the neighbors are defined by the allowed, domain-specific graph edit operations.
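Written as a formula, the search described above can be summarized as follows. The notation is introduced here only for illustration: f is the model, N(G) is the set of graphs reachable from the original graph G by a single allowed edit, and d is the distance function on predictions.

```latex
G^{*} = \arg\max_{G' \in \mathcal{N}(G)} \; d\big(f(G),\, f(G')\big)
```

Returning the k best counterfactuals instead of a single one amounts to keeping the k neighbors G' with the largest values of d(f(G), f(G')).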

Installation

First, clone the repository:

git clone https://github.com/the16thpythonist/vgd_counterfactuals

Then run a pip install from the repository's root folder:

cd vgd_counterfactuals
python3 -m pip install .

Afterwards, you can verify the installation by invoking the CLI:

python3 -m vgd_counterfactuals.cli --version
python3 -m vgd_counterfactuals.cli --help

Usage

Quickstart

The generation of counterfactual graphs is implemented by the CounterfactualGenerator class. Instantiating a generator requires the following four main components:

  • processing: A visual_graph_dataset "Processing" object. Processing objects implement the functionality needed to convert a domain-specific graph representation (such as a SMILES string for molecules) into the full graph structure used by the machine learning models. A suitable Processing implementation is shipped with each visual graph dataset.

  • model: The model to be explained. This model has to implement the visual_graph_dataset “PredictGraph” interface to ensure that the model can be directly queried with the vgd GraphDict representation of graph elements.

  • neighborhood_func: A function which receives the domain-specific representation of a graph as input and returns a list of the domain-specific representations of all immediate neighbors of that graph. Its implementation is highly specific to each application domain.

  • distance_func: A function which receives two arguments, the prediction for the original element and the prediction for a neighbor, and returns a single numeric value for the distance between the two predictions. The generator maximizes this distance measure.
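To illustrate how neighborhood_func and distance_func work together, here is a minimal, library-independent sketch of the kind of search the generator performs. It assumes a toy domain in which a "graph" is a string of node labels and a single edit replaces one label. All names (toy_neighborhood, toy_model, absolute_distance, top_k_counterfactuals) are hypothetical and not part of vgd_counterfactuals; the real CounterfactualGenerator additionally applies the Processing step and writes visualizations, which this sketch omits.

```python
from typing import Callable, List, Tuple

# Hypothetical toy domain: a "graph" is a string of node labels taken from
# ALPHABET, and its immediate neighbors are all strings that differ in exactly
# one position (a stand-in for a single domain-specific graph edit).
ALPHABET = "CNO"


def toy_neighborhood(value: str) -> List[str]:
    """Return every string obtained by replacing exactly one character."""
    neighbors = []
    for i, char in enumerate(value):
        for replacement in ALPHABET:
            if replacement != char:
                neighbors.append(value[:i] + replacement + value[i + 1:])
    return neighbors


def toy_model(value: str) -> float:
    """Hypothetical regression model: a weighted count of N and O labels."""
    return float(value.count("N") + 2 * value.count("O"))


def absolute_distance(original: float, modified: float) -> float:
    """Distance between two scalar predictions."""
    return abs(original - modified)


def top_k_counterfactuals(
    original: str,
    predict: Callable[[str], float],
    neighborhood_func: Callable[[str], List[str]],
    distance_func: Callable[[float, float], float],
    k: int,
) -> List[Tuple[str, float]]:
    """Score all immediate neighbors by prediction distance and keep the top k."""
    original_prediction = predict(original)
    scored = [
        (neighbor, distance_func(original_prediction, predict(neighbor)))
        for neighbor in neighborhood_func(original)
    ]
    # Largest distance first: these neighbors change the prediction the most.
    # The sort is stable, so ties keep their enumeration order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


print(top_k_counterfactuals("CCC", toy_model, toy_neighborhood, absolute_distance, k=3))
# prints [('OCC', 2.0), ('COC', 2.0), ('CCO', 2.0)]
```

Because the distance function is passed in as a parameter, the same search loop can favor different kinds of counterfactuals (for example, a signed difference instead of an absolute one) without changing the neighborhood enumeration.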

Once the generator has been instantiated, its generate method can be used to create counterfactuals for any number of input elements.

The following mock example shows how all of these components fit together. For more information, have a look at the example modules in the examples folder of the repository.

import tempfile

from visual_graph_datasets.processing.molecules import MoleculeProcessing

from vgd_counterfactuals.base import CounterfactualGenerator
from vgd_counterfactuals.testing import MockModel
from vgd_counterfactuals.generate.molecules import get_neighborhood

processing = MoleculeProcessing()
model = MockModel()

generator = CounterfactualGenerator(
    processing=processing,
    model=model,
    neighborhood_func=get_neighborhood,
    distance_func=lambda orig, mod: abs(orig - mod),
)

with tempfile.TemporaryDirectory() as path:
    # The "generate" method will create all the possible neighbors of the
    # given "original" element, then query the model to predict the
    # output for each of them, and sort them by their distance to the original.
    # The top k elements will be turned into a temporary visual graph dataset
    # within the given folder "path". That means two files will be created in
    # that folder per element: a metadata JSON file and a visualization PNG file.
    # Returns the index data map of the loaded visual graph dataset.
    index_data_map = generator.generate(
        original='CCCCCC',
        # Path to the folder into which to save the vgd element files
        path=path,
        # The number of counterfactuals to be returned.
        # Elements will be sorted by their distance.
        k_results=10,
    )

    # The keys of the resulting dict are the integer indices and the values
    # are dicts themselves which describe the corresponding vgd elements.
    # These dicts contain for example the absolute path to the PNG file,
    # the full graph representation and additional metadata.
    print(f'generated {len(index_data_map)} counterfactuals:')
    for index, data in index_data_map.items():
        print(f' * {data["metadata"]["name"]}'
              f' - distance: {data["metadata"]["distance"]:.2f}')
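The returned index data map can be post-processed with plain Python. The sketch below keeps only counterfactuals whose distance exceeds a threshold. It assumes nothing beyond the "metadata" keys "name" and "distance" used in the quickstart; the function name and the example values are made up for illustration. It re-sorts by distance so that it does not depend on the iteration order of the dict.

```python
from typing import Any, Dict, List


def names_above_threshold(
    index_data_map: Dict[int, Dict[str, Any]], threshold: float
) -> List[str]:
    """Return the names of all counterfactuals whose distance exceeds the
    threshold, ordered by descending distance (hypothetical helper)."""
    selected = [
        data["metadata"]
        for data in index_data_map.values()
        if data["metadata"]["distance"] > threshold
    ]
    selected.sort(key=lambda metadata: metadata["distance"], reverse=True)
    return [metadata["name"] for metadata in selected]


# Made-up stand-in for the dict returned by generator.generate(); only the
# "metadata" -> "name" / "distance" keys from the quickstart are assumed.
example_map = {
    0: {"metadata": {"name": "neighbor_0", "distance": 1.5}},
    1: {"metadata": {"name": "neighbor_1", "distance": 0.4}},
    2: {"metadata": {"name": "neighbor_2", "distance": 2.0}},
}
print(names_above_threshold(example_map, threshold=1.0))
# prints ['neighbor_2', 'neighbor_0']
```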

Credits

  • PyComex is a micro framework which simplifies the setup, processing and management of computational experiments. It is also used to auto-generate the command line interface that can be used to interact with these experiments.

  • VisualGraphDatasets is a library which deals with the VGD dataset format. In this format, a graph dataset for machine learning is represented by a folder in which each graph is stored as two files: a metadata JSON file containing the full graph representation and additional metadata, and a PNG visualization of the graph. The library aims to provide a framework for explainable graph machine learning that is easier to use and produces more reproducible results.
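To make the folder layout described above concrete, here is a small stdlib-only sketch that pairs each metadata JSON file with its PNG visualization. It assumes, as a hypothetical convention, that the two files of an element share the same file stem (for example "0.json" and "0.png"); it is not the loader shipped with visual_graph_datasets, and the function name load_vgd_folder is made up.

```python
import json
from pathlib import Path
from typing import Any, Dict


def load_vgd_folder(folder: str) -> Dict[str, Dict[str, Any]]:
    """Map each element's file stem to its parsed metadata and PNG path.

    Assumes (hypothetically) that an element's JSON and PNG files share a stem.
    """
    elements: Dict[str, Dict[str, Any]] = {}
    for json_path in sorted(Path(folder).glob("*.json")):
        png_path = json_path.with_suffix(".png")
        if not png_path.exists():
            # Skip incomplete elements that lack a visualization.
            continue
        with json_path.open() as file:
            metadata = json.load(file)
        elements[json_path.stem] = {
            "metadata": metadata,
            "image_path": str(png_path.resolve()),
        }
    return elements
```

Keeping the image as a path rather than loading its pixels keeps the sketch dependency-free; an image library can open the PNG later when a visualization is actually needed.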



Download files


Source Distribution

vgd_counterfactuals-0.3.2.tar.gz (429.0 kB)

Built Distribution

vgd_counterfactuals-0.3.2-py3-none-any.whl (436.2 kB, Python 3)
