Geometric Deep Learning Extension Library for TensorFlow and PyTorch
TensorFlow or PyTorch? Both!
GraphGallery
GraphGallery is a gallery of state-of-the-art graph neural networks for TensorFlow 2.x and PyTorch. NOTE: Version 0.3.0 is still in testing.
Installation
- Build from source (latest version)
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
python setup.py install
- Or using pip (stable version)
pip install -U graphgallery
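After either install route, it can help to confirm that the package is visible to the current interpreter. The sketch below uses only the Python standard library; the helper name `installed_version` is an illustrative assumption and not part of GraphGallery.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# "graphgallery" is the distribution name used by the pip command above.
version = installed_version("graphgallery")
print(version if version is not None else "graphgallery is not installed")
```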
Implementations
In detail, the following methods are currently implemented:
Semi-supervised models
General models
- ChebyNet from Michaël Defferrard et al, 📝Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering, NIPS'16. [🌋 TF]
- GCN from Thomas N. Kipf et al, 📝Semi-Supervised Classification with Graph Convolutional Networks, ICLR'17. [🌋 TF], [🔥 Torch]
- GraphSAGE from William L. Hamilton et al, 📝Inductive Representation Learning on Large Graphs, NIPS'17. [🌋 TF]
- FastGCN from Jie Chen et al, 📝FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling, ICLR'18. [🌋 TF]
- LGCN from Hongyang Gao et al, 📝Large-Scale Learnable Graph Convolutional Networks, KDD'18. [🌋 TF]
- GAT from Petar Veličković et al, 📝Graph Attention Networks, ICLR'18. [🌋 TF]
- SGC from Felix Wu et al, 📝Simplifying Graph Convolutional Networks, ICML'19. [🌋 TF]
- GWNN from Bingbing Xu et al, 📝Graph Wavelet Neural Network, ICLR'19. [🌋 TF]
- GMNN from Meng Qu et al, 📝Graph Markov Neural Networks, ICML'19. [🌋 TF]
- ClusterGCN from Wei-Lin Chiang et al, 📝Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks, KDD'19. [🌋 TF]
- DAGNN from Meng Liu et al, 📝Towards Deeper Graph Neural Networks, KDD'20. [🌋 TF]
Defense models
- RobustGCN from Dingyuan Zhu et al, 📝Robust Graph Convolutional Networks Against Adversarial Attacks, KDD'19. [🌋 TF]
- SBVAT/OBVAT from Zhijie Deng et al, 📝Batch Virtual Adversarial Training for Graph Convolutional Networks, ICML'19. [🌋 TF], [🌋 TF]
Unsupervised models
- DeepWalk from Bryan Perozzi et al, 📝DeepWalk: Online Learning of Social Representations, KDD'14. [🌋 TF]
- Node2vec from Aditya Grover et al, 📝node2vec: Scalable Feature Learning for Networks, KDD'16. [🌋 TF]
Quick Start
Datasets
from graphgallery.data import Planetoid
# set `verbose=False` to avoid these printed tables
data = Planetoid('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split()
# idx_train: training indices: 1D Numpy array
# idx_val: validation indices: 1D Numpy array
# idx_test: testing indices: 1D Numpy array
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 1433), labels(2708,))
Currently, the supported datasets are:
>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')
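To make the shape of the index arrays concrete, here is a minimal NumPy sketch of a random, disjoint node split. It is only an illustration: the helper `random_split` and its fractions are assumptions, and it does not reproduce the split that `data.split()` returns.

```python
import numpy as np


def random_split(num_nodes, train_frac=0.1, val_frac=0.1, seed=123):
    """Hypothetical helper: shuffle node ids and cut them into three disjoint 1D arrays."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    n_train = int(train_frac * num_nodes)
    n_val = int(val_frac * num_nodes)
    idx_train = perm[:n_train]
    idx_val = perm[n_train:n_train + n_val]
    idx_test = perm[n_train + n_val:]
    return idx_train, idx_val, idx_test


# 2708 is the number of nodes in the Cora graph shown above.
idx_train, idx_val, idx_test = random_split(2708)
print(idx_train.shape, idx_val.shape, idx_test.shape)
```

Each array is one-dimensional and no node id appears in more than one of them, which is the property the training and testing calls rely on.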
Example of GCN model
from graphgallery.nn.models import GCN
model = GCN(graph, attr_transformer="normalize_attr", device="CPU", seed=123)
# build your GCN model with default hyper-parameters
model.build()
# train your model. here idx_train and idx_val are numpy arrays
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
On the Cora dataset:
<Loss = 1.0161 Acc = 0.9500 Val_Loss = 1.4101 Val_Acc = 0.7740 >: 100%|██████████| 100/100 [00:01<00:00, 118.02it/s]
Test loss 1.4123, Test accuracy 81.20%
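The format specifiers in the `print` call above are standard Python: `.5` keeps five significant digits, and `.2%` multiplies by 100 and appends a percent sign. A small standalone check, using made-up values rather than real model output:

```python
# Illustrative values only; they are not produced by a model here.
loss, accuracy = 1.41234567, 0.812

line = f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}'
print(line)  # Test loss 1.4123, Test accuracy 81.20%
```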
Customization
- Build your model. You can use the following statements to build it:
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')
# two hidden layers with 32 and 64 hidden units, both with ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')
# two hidden layers with 32 and 64 hidden units, with ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
# other parameters such as `dropouts` and `l2_norms` (if present) work the SAME way.
- Train your model
# train with validation
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# train without validation
>>> his = model.train(idx_train, verbose=1, epochs=100)
Here `his` is a TensorFlow `History`-like instance.
- Test your model
>>> loss, accuracy = model.test(idx_test)
>>> print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
Test loss 1.4124, Test accuracy 81.20%
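The `build` calls in the first step above accept either a single value or a list for `hiddens` and `activations`. One plausible way to read such arguments is to repeat a single value so that every hidden layer gets a setting. The sketch below illustrates that idea with a hypothetical helper, `expand_per_layer`; it is an assumption about the behavior, not GraphGallery's actual code.

```python
def expand_per_layer(value, num_layers):
    """Hypothetical helper: return one setting per hidden layer.

    A single value is repeated for every layer; a list or tuple must
    already contain exactly one entry per layer.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != num_layers:
            raise ValueError(f"expected {num_layers} values, got {len(value)}")
        return list(value)
    return [value] * num_layers


hiddens = [32, 64]
print(expand_per_layer('relu', len(hiddens)))           # ['relu', 'relu']
print(expand_per_layer(['relu', 'elu'], len(hiddens)))  # ['relu', 'elu']
```

Raising an error on a length mismatch is a design choice of this sketch; it surfaces a mistyped list early instead of silently dropping or padding values.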
Visualization
NOTE: install the SciencePlots package for a better preview.
- Accuracy
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['acc'])
    plt.plot(his.history['val_acc'])
    plt.legend(['Train Accuracy', 'Val Accuracy'])
    plt.ylabel('Accuracy')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()
- Loss
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['loss'])
    plt.plot(his.history['val_loss'])
    plt.legend(['Train Loss', 'Val Loss'])
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()
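The same `his.history` dictionary can also be inspected without plotting, for example to find the epoch with the best validation accuracy. The sketch below uses a stand-in object with made-up numbers so that it runs on its own; the keys `acc` and `val_acc` are the ones used in the plots above, and `best_epoch` is a hypothetical helper.

```python
from types import SimpleNamespace


def best_epoch(history, metric='val_acc'):
    """Hypothetical helper: return (1-based epoch, value) of the highest entry for a metric."""
    values = history[metric]
    best = max(range(len(values)), key=values.__getitem__)
    return best + 1, values[best]


# Stand-in for the object returned by model.train; the numbers are made up.
his = SimpleNamespace(history={
    'acc': [0.30, 0.62, 0.85, 0.95],
    'val_acc': [0.28, 0.55, 0.78, 0.77],
})

epoch, value = best_epoch(his.history)
print(f'best val_acc {value:.2%} at epoch {epoch}')  # best val_acc 78.00% at epoch 3
```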
Using TensorFlow/PyTorch Backend
>>> import graphgallery
>>> graphgallery.backend()
TensorFlow 2.1.0 Backend
>>> graphgallery.set_backend("pytorch")
PyTorch 1.6.0+cu101 Backend
GCN using PyTorch backend
# The following code is the same as with the TensorFlow backend
>>> from graphgallery.nn.models import GCN
>>> model = GCN(graph, attr_transformer="normalize_attr", device="GPU", seed=123)
>>> model.build()
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
loss 0.57, acc 96.43%, val_loss 1.04, val_acc 78.20%: 100%|██████████| 100/100 [00:00<00:00, 210.90it/s]
>>> loss, accuracy = model.test(idx_test)
>>> print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
Test loss 1.0271, Test accuracy 81.10%
How to add your custom datasets
TODO
How to define your custom models
TODO
More Examples
Please refer to the examples directory.
TODO Lists
- Add Docstrings and Documentation (Building)
- Add PyTorch models support
- Support for graph classification and link prediction tasks
- Support for heterogeneous graphs
Acknowledgement
This project is motivated by PyTorch Geometric, TensorFlow Geometric and StellarGraph, as well as the authors' original implementations. Thanks for their excellent work!