A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019).

ClusterGCN

Abstract

Graph convolutional network (GCN) has been successfully applied to many graph-based applications; however, training a large-scale GCN remains challenging. Current SGD-based algorithms suffer from either a high computational cost that exponentially grows with number of GCN layers, or a large space requirement for keeping the entire graph and the embedding of each node in memory. In this paper, we propose Cluster-GCN, a novel GCN algorithm that is suitable for SGD-based training by exploiting the graph clustering structure. Cluster-GCN works as the following: at each step, it samples a block of nodes that associate with a dense subgraph identified by a graph clustering algorithm, and restricts the neighborhood search within this subgraph. This simple but effective strategy leads to significantly improved memory and computational efficiency while being able to achieve comparable test accuracy with previous algorithms. To test the scalability of our algorithm, we create a new Amazon2M data with 2 million nodes and 61 million edges which is more than 5 times larger than the previous largest publicly available dataset (Reddit). For training a 3-layer GCN on this data, Cluster-GCN is faster than the previous state-of-the-art VR-GCN (1523 seconds vs 1961 seconds) and using much less memory (2.2GB vs 11.2GB). Furthermore, for training 4 layer GCN on this data, our algorithm can finish in around 36 minutes while all the existing GCN training algorithms fail to train due to the out-of-memory issue. Furthermore, Cluster-GCN allows us to train much deeper GCN without much time and memory overhead, which leads to improved prediction accuracy -- using a 5-layer Cluster-GCN, we achieve state-of-the-art test F1 score 99.36 on the PPI dataset, while the previous best result was 98.71.

This repository provides a PyTorch implementation of ClusterGCN as described in the paper:

Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks Wei-Lin Chiang, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, Cho-Jui Hsieh. KDD, 2019. [Paper]
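The batching idea in the abstract can be illustrated with a minimal plain-Python sketch. It turns a node-to-cluster assignment (for example one produced by METIS) into per-cluster training subgraphs. `cluster_subgraphs` is a hypothetical helper, not the repository's code. It keeps only edges whose endpoints share a cluster and relabels each cluster's nodes from 0, which is one simple way to realize the restricted neighborhood search described above.

```python
from collections import defaultdict


def cluster_subgraphs(edges, membership):
    """Split a graph into cluster-induced subgraphs (hypothetical helper).

    edges: iterable of undirected (u, v) pairs with integer node ids.
    membership: dict mapping node id -> cluster id.
    Returns {cluster_id: (nodes, local_edges)}, where `nodes` is the sorted
    list of global node ids in the cluster and `local_edges` refers to
    positions in that list.
    """
    nodes_by_cluster = defaultdict(list)
    for node, cluster in membership.items():
        nodes_by_cluster[cluster].append(node)

    local_index = {}  # global node id -> position inside its own cluster
    subgraphs = {}
    for cluster, nodes in nodes_by_cluster.items():
        nodes.sort()
        for position, node in enumerate(nodes):
            local_index[node] = position
        subgraphs[cluster] = (nodes, [])

    for u, v in edges:
        # Keep an edge only if both endpoints share a cluster; edges between
        # clusters are ignored when training on one cluster at a time.
        if membership[u] == membership[v]:
            subgraphs[membership[u]][1].append((local_index[u], local_index[v]))
    return subgraphs


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
membership = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
for cluster, (nodes, local_edges) in sorted(cluster_subgraphs(edges, membership).items()):
    print(cluster, nodes, local_edges)
# 0 [0, 1, 2] [(0, 1), (1, 2)]
# 1 [3, 4, 5] [(0, 1), (1, 2)]
```

Each subgraph can then be fed to an ordinary GCN as a mini-batch. The repository builds its layers with PyTorch Geometric (`GCNConv`) rather than plain lists.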

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx           1.11
tqdm               4.28.1
numpy              1.15.4
pandas             0.23.4
texttable          1.5.0
scipy              1.1.0
argparse           1.1.0
torch              0.4.1
torch-geometric    0.3.1
metis              0.2a.4
scikit-learn       0.20
torch_spline_conv  1.0.4
torch_sparse       0.2.2
torch_scatter      1.0.4
torch_cluster      1.1.5
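The codebase targets fairly old releases, so comparing an environment against the list above before debugging import errors can save time. The following sketch is an illustration, not part of the repository. It assumes the PyPI distribution names shown in `PINNED`, using dashed spellings such as `torch-sparse`; some Python versions may need the underscore spelling used on disk instead. It falls back to `pkg_resources` on interpreters without `importlib.metadata`, and it leaves out argparse, which ships with the standard library.

```python
try:
    from importlib import metadata  # Python 3.8+

    def installed_version(name):
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            return None
except ImportError:  # older interpreters, e.g. the Python 3.5 used here
    import pkg_resources

    def installed_version(name):
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None

# Versions from the Requirements list (argparse omitted: it is in the stdlib).
PINNED = {
    "networkx": "1.11",
    "tqdm": "4.28.1",
    "numpy": "1.15.4",
    "pandas": "0.23.4",
    "texttable": "1.5.0",
    "scipy": "1.1.0",
    "torch": "0.4.1",
    "torch-geometric": "0.3.1",
    "metis": "0.2a.4",
    "scikit-learn": "0.20",
    "torch-spline-conv": "1.0.4",
    "torch-sparse": "0.2.2",
    "torch-scatter": "1.0.4",
    "torch-cluster": "1.1.5",
}


def version_mismatches(pinned):
    """Return {name: (expected, installed or None)} for packages that differ.

    The comparison is an exact string match, so "0.20" and "0.20.0" count
    as different and should be checked by eye.
    """
    mismatches = {}
    for name, expected in pinned.items():
        installed = installed_version(name)
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


for name, (expected, installed) in sorted(version_mismatches(PINNED).items()):
    print("{}: expected {}, found {}".format(name, expected, installed or "nothing"))
```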

Installing metis on Ubuntu:

sudo apt-get install libmetis-dev
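The `metis` Python package is a ctypes wrapper around this shared library. When it cannot find the library, it raises `Could not locate METIS dll. Please set the METIS_DLL environment variable to its full path.` The sketch below sets that variable from Python before `metis` is imported. `set_metis_dll` is a hypothetical helper, and the Ubuntu path in the usage comment is an assumption; `dpkg -L libmetis-dev` lists where the package actually installed the library. Exporting `METIS_DLL` in the shell before running `src/main.py` works just as well.

```python
import os


def set_metis_dll(path):
    """Export METIS_DLL so the `metis` wrapper can find libmetis (hypothetical helper).

    Call this before `import metis`; the wrapper resolves the shared library
    when the module is imported.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError("no METIS shared library at {}".format(path))
    os.environ["METIS_DLL"] = os.path.abspath(path)
    return os.environ["METIS_DLL"]


# Usage (the path below is an assumption for Ubuntu's libmetis-dev package):
# set_metis_dll("/usr/lib/x86_64-linux-gnu/libmetis.so")
# import metis
```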

Datasets

The code takes the **edge list** of the graph in a csv file. Every row indicates an edge between two nodes separated by a comma. The first row is a header. Nodes should be indexed starting with 0. A sample graph for `Pubmed` is included in the `input/` directory. In addition to the edge list, there is a csv file with the sparse features and another one with the target variable.
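The edge list can be read into adjacency lists with the standard library. `read_adjacency` is hypothetical; the repository itself builds a NetworkX graph. It assumes exactly the layout described above (a header row followed by `u,v` rows with nodes numbered from 0), treats edges as undirected, and skips self-loops because METIS' graph format does not allow them. The header text in the example is made up. As far as we know, a list of neighbor lists like this is also one of the inputs the `metis` wrapper's `part_graph` accepts besides a NetworkX graph.

```python
import csv


def read_adjacency(lines, num_nodes=None):
    """Build undirected adjacency lists from an edge-list csv (hypothetical helper).

    lines: an open file or any iterable of csv lines, header first.
    num_nodes: pass it to keep isolated nodes with the highest ids; otherwise
    the node count is the largest id seen in the edge list plus one.
    """
    reader = csv.reader(lines)
    next(reader)  # skip header row
    edges = [(int(row[0]), int(row[1])) for row in reader if row]
    if num_nodes is None:
        num_nodes = 1 + max(max(u, v) for u, v in edges)
    neighbors = [set() for _ in range(num_nodes)]
    for u, v in edges:
        if u != v:  # METIS' graph format does not allow self-loops
            neighbors[u].add(v)
            neighbors[v].add(u)
    return [sorted(adjacent) for adjacent in neighbors]


# with open("input/edges.csv") as handle:
#     adjacency = read_adjacency(handle)
print(read_adjacency(["id_1,id_2", "0,1", "1,2", "2,0", "2,3"]))
# [[1, 2], [0, 2], [0, 1, 3], [2]]
```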

The **feature matrix** is sparse and is stored as a csv. Features are indexed consecutively from 0. The feature matrix csv is structured as follows:

| NODE ID | FEATURE ID | Value |
|---------|------------|-------|
| 0       | 3          | 0.2   |
| 0       | 7          | 0.5   |
| 1       | 17         | 0.8   |
| 1       | 4          | 5.4   |
| 1       | 38         | 1.3   |
| ...     | ...        | ...   |
| n       | 3          | 0.9   |
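The rows above can be collected per node with the csv module. This sketch is illustrative only: `read_sparse_features` and `to_dense` are hypothetical names, the header text in the example is made up, and the number of features is inferred from the largest feature id, relying on features being indexed consecutively from 0.

```python
import csv


def read_sparse_features(lines):
    """Parse (node id, feature id, value) rows into {node: {feature: value}}.

    Also returns the number of feature columns, taken as the largest
    feature id + 1 since features are indexed consecutively from 0.
    """
    reader = csv.reader(lines)
    next(reader)  # skip header row
    features = {}
    num_features = 0
    for row in reader:
        if not row:
            continue
        node, feature, value = int(row[0]), int(row[1]), float(row[2])
        features.setdefault(node, {})[feature] = value
        num_features = max(num_features, feature + 1)
    return features, num_features


def to_dense(features, num_nodes, num_features):
    """Expand the sparse dict into a num_nodes x num_features list of rows."""
    return [
        [features.get(node, {}).get(f, 0.0) for f in range(num_features)]
        for node in range(num_nodes)
    ]


rows = ["node_id,feature_id,value", "0,3,0.2", "0,7,0.5", "1,17,0.8", "1,4,5.4", "1,38,1.3"]
features, num_features = read_sparse_features(rows)
print(num_features)  # 39
```

Dense rows are only practical for small graphs. For a real run, a `scipy.sparse.coo_matrix` built from the same (node, feature, value) triples is the more memory-friendly choice.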

The **target vector** is a csv with two columns and a header: the first column contains the node identifiers and the second the targets. This csv is sorted by node identifier, and the target column contains class memberships indexed from zero.

| NODE ID | Target |
|---------|--------|
| 0       | 3      |
| 1       | 1      |
| 2       | 0      |
| 3       | 1      |
| ...     | ...    |
| n       | 3      |
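The following sketch loads the target vector and checks the properties stated above. `read_targets` is hypothetical. It assumes one row per node, so that the sorted node ids are exactly 0, 1, ..., n-1, and it takes the number of classes to be the largest label plus one.

```python
import csv


def read_targets(lines):
    """Return (targets, num_classes) from a two-column target csv.

    targets[i] is the class of node i. Raises ValueError if the rows are not
    node ids 0..n-1 in order or if a label is negative.
    """
    reader = csv.reader(lines)
    next(reader)  # skip header row
    rows = [(int(row[0]), int(row[1])) for row in reader if row]
    node_ids = [node for node, _ in rows]
    if node_ids != list(range(len(rows))):
        raise ValueError("target csv must list node ids 0..n-1 in order")
    targets = [target for _, target in rows]
    if min(targets) < 0:
        raise ValueError("class labels must be indexed from zero")
    return targets, max(targets) + 1


print(read_targets(["node_id,target", "0,3", "1,1", "2,0", "3,1"]))
# ([3, 1, 0, 1], 4)
```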

Options

The training of a ClusterGCN model is handled by the `src/main.py` script, which provides the following command line arguments.

Input and output options

  --edge-path       STR    Edge list csv.         Default is `input/edges.csv`.
  --features-path   STR    Features csv.          Default is `input/features.csv`.
  --target-path     STR    Target classes csv.    Default is `input/target.csv`.

Model options

  --clustering-method   STR     Clustering method.             Default is `metis`.
  --cluster-number      INT     Number of clusters.            Default is 10. 
  --seed                INT     Random seed.                   Default is 42.
  --epochs              INT     Number of training epochs.     Default is 200.
  --test-ratio          FLOAT   Training set ratio.            Default is 0.9.
  --learning-rate       FLOAT   Adam learning rate.            Default is 0.01.
  --dropout             FLOAT   Dropout rate value.            Default is 0.5.
  --layers              LST     Layer sizes.                   Default is [16, 16, 16]. 
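The option list above maps naturally onto `argparse`. The sketch below is not the repository's parser, only a plausible declaration that reproduces the documented flags, help texts and defaults. Using `nargs="+"` for `--layers` is an assumption that matches the `--layers 64 64` usage in the examples.

```python
import argparse


def build_parser():
    """Declare the documented command line options (sketch, not the repo's parser)."""
    parser = argparse.ArgumentParser(description="Train a ClusterGCN model.")
    # Input and output options.
    parser.add_argument("--edge-path", type=str, default="input/edges.csv", help="Edge list csv.")
    parser.add_argument("--features-path", type=str, default="input/features.csv", help="Features csv.")
    parser.add_argument("--target-path", type=str, default="input/target.csv", help="Target classes csv.")
    # Model options.
    parser.add_argument("--clustering-method", type=str, default="metis", help="Clustering method.")
    parser.add_argument("--cluster-number", type=int, default=10, help="Number of clusters.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    parser.add_argument("--test-ratio", type=float, default=0.9, help="Training set ratio.")
    parser.add_argument("--learning-rate", type=float, default=0.01, help="Adam learning rate.")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate value.")
    # nargs="+" lets `--layers 64 64` produce the list [64, 64].
    parser.add_argument("--layers", type=int, nargs="+", default=[16, 16, 16], help="Layer sizes.")
    return parser


args = build_parser().parse_args(["--layers", "64", "64", "--learning-rate", "0.1"])
print(args.layers, args.learning_rate)  # [64, 64] 0.1
```

argparse converts dashed flags to underscored attributes, so `--cluster-number` is read back as `args.cluster_number`.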

Examples

The following commands train a neural network and score it on the test set.

Training a model on the default dataset:

$ python src/main.py

Training a ClusterGCN model for 100 epochs:

$ python src/main.py --epochs 100

Increasing the learning rate and the dropout:

$ python src/main.py --learning-rate 0.1 --dropout 0.9

Training a model with a different layer structure:

$ python src/main.py --layers 64 64

Training a model with random clustering:

$ python src/main.py --clustering-method random
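The two clustering methods can be compared by their edge cut, the number of edges whose endpoints end up in different clusters. METIS' k-way partitioning tries to keep the edge cut small while balancing cluster sizes. A random assignment ignores the graph structure, so each cluster-restricted batch loses more neighbors. The sketch assumes the random option assigns nodes to clusters uniformly at random. `random_clustering` and `edge_cut` are hypothetical helpers, and the repository's own implementation may differ.

```python
import random


def random_clustering(nodes, cluster_number, seed=42):
    """Assign each node to one of `cluster_number` clusters uniformly at random."""
    rng = random.Random(seed)  # local generator: the result depends only on `seed`
    return {node: rng.randrange(cluster_number) for node in nodes}


def edge_cut(adjacency, membership):
    """Count undirected edges whose endpoints fall in different clusters."""
    return sum(
        1
        for u, neighbors in enumerate(adjacency)
        for v in neighbors
        if u < v and membership[u] != membership[v]
    )


# A 6-node path graph as adjacency lists: 0-1-2-3-4-5.
path = [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4]]
contiguous = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
print(edge_cut(path, contiguous))  # 1: only the 2-3 edge is cut
print(edge_cut(path, random_clustering(range(6), 2)))
```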

License

Comments
  • Segmentation fault While running main.py on Ubuntu

    While running main.py, I get a segmentation fault on Ubuntu.

    python3 main.py --epochs 100

    +-------------------+---------------------------------------------------------+
    | Parameter         | Value                                                   |
    +===================+=========================================================+
    | Cluster number    | 10                                                      |
    | Clustering method | metis                                                   |
    | Dropout           | 0.500                                                   |
    | Edge path         | /home/User/Desktop/ClusterGCN-master/input/edges.csv    |
    | Epochs            | 100                                                     |
    | Features path     | /home/User/Desktop/ClusterGCN-master/input/features.csv |
    | Layers            | [16, 16, 16]                                            |
    | Learning rate     | 0.010                                                   |
    | Seed              | 42                                                      |
    | Target path       | /home/User/Desktop/ClusterGCN-master/input//target.csv  |
    | Test ratio        | 0.900                                                   |
    +-------------------+---------------------------------------------------------+

    Metis graph clustering started.

    Segmentation fault

    opened by alamsaqib 4
  • ImportError: No module named 'torch_spline_conv'

    I followed the installation instructions, but the error above occurred.

    After checking the site-packages folder, I could not find torch_spline_conv. I will search around to find out why this is happening, but thought you might have some insights.

    Any help is appreciated.

    The complete trace is as follows:

      File "src/main.py", line 4, in <module>
        from clustergcn import ClusterGCNTrainer
      File "/media/anuj/Softwares & Study Material/Study Material/MS Stuff/RA/ClusterGCN/src/clustergcn.py", line 5, in <module>
        from layers import StackedGCN
      File "/media/anuj/Softwares & Study Material/Study Material/MS Stuff/RA/ClusterGCN/src/layers.py", line 2, in <module>
        from torch_geometric.nn import GCNConv
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/__init__.py", line 1, in <module>
        from .conv import *  # noqa
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/conv/__init__.py", line 1, in <module>
        from .spline_conv import SplineConv
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/conv/spline_conv.py", line 3, in <module>
        from torch_spline_conv import SplineConv as Conv
    ImportError: No module named 'torch_spline_conv'
    
    
    opened by 1byxero 2
  • For ppi

    Hello. Thanks for your work and code. It's great that Cluster-GCN achieves such strong performance on the PPI dataset, but it seems that you have not open-sourced the code for PPI node classification.

    Do you first select the best model on the validation set and then evaluate it on the unseen test set? I notice that GraphStar is now the SOTA; however, they don't use the validation set and select the best model directly on the test set.

    Could you share the PPI code with us and describe in the README how the dataset is split? It's important for others who want to follow your great work.

    opened by guochengqian 2
  • Metis hits a Segmentation fault when running _METIS_PartGraphKway

    • I'm using the default test input files.

    • I've attached pdb screenshot during the run.

    • Environment: Ubuntu 18.04 Anaconda (Python 3.7.3),
      torch-geometric==1.3.0 torch-scatter==1.3.0 torch-sparse==0.4.0 torch-spline-conv==1.1.0 metis==0.2a.4

    [Screenshot: PDB error, 2019-07-04 13-56-16]

    [Screenshot: requirements.txt, 2019-07-04 14-02-14]

    opened by poppingtonic 2
  • The error of metis, Segmentation fault (core dumped)

    I can partition the graph with the random clustering method, but with METIS the program terminates abnormally. What causes this? I changed "IDXTYPEWIDTH = os.getenv('METIS_IDXTYPEWIDTH', '32')" in metis.py (line 31) to "IDXTYPEWIDTH = os.getenv('METIS_IDXTYPEWIDTH', '64')", but it did not help.

    python src/main.py

    +-------------------+----------------------+
    | Parameter         | Value                |
    +===================+======================+
    | Cluster number    | 10                   |
    | Clustering method | metis                |
    | Dropout           | 0.500                |
    | Edge path         | ./input/edges.csv    |
    | Epochs            | 200                  |
    | Features path     | ./input/features.csv |
    | Layers            | [16, 16, 16]         |
    | Learning rate     | 0.010                |
    | Seed              | 42                   |
    | Target path       | ./input/target.csv   |
    | Test ratio        | 0.900                |
    +-------------------+----------------------+

    Metis graph clustering started.

    Segmentation fault (core dumped)

    opened by yiyang-wang 1
  • TypeError: object of type 'int' has no len()

    Hello, when I run main.py, I get the following error:

      File "D:\anaconda3.4\lib\site-packages\pymetis\__init__.py", line 44, in _prepare_graph
        for i in range(len(adjacency)):
    TypeError: object of type 'int' has no len()

    I installed the pymetis package to work around the missing metis.dll, and this error occurs inside pymetis\__init__.py. Do you know how to solve it?

    opened by tanjia123456 1
  • RuntimeError: Could not locate METIS dll.

    Hello, when I run main.py, the following error message appears:

        raise RuntimeError('Could not locate METIS dll. Please set the METIS_DLL environment variable to its full path.')
    RuntimeError: Could not locate METIS dll. Please set the METIS_DLL environment variable to its full path.

    Do you know how to solve it?

    opened by tanjia123456 1
  • Runtime error about metis

    At the beginning of training, when the full graph is partitioned, the call "metis.part_graph(self.graph, self.args.cluster_number)" throws an error:

    Traceback (most recent call last):
      File "C:/Users/xieRu/Desktop/ML/ClusterGCN/src/main.py", line 30, in <module>
        main()
      File "C:/Users/xieRu/Desktop/ML/ClusterGCN/src/main.py", line 19, in main
        clustering_machine.decompose()
      File "C:\Users\xieRu\Desktop\ML\ClusterGCN\src\clustering.py", line 38, in decompose
        self.metis_clustering()
      File "C:\Users\xieRu\Desktop\ML\ClusterGCN\src\clustering.py", line 56, in metis_clustering
        (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)
      File "D:\Program\Anaconda\lib\site-packages\metis.py", line 800, in part_graph
        _METIS_PartGraphKway(*args)
      File "D:\Program\Anaconda\lib\site-packages\metis.py", line 677, in _METIS_PartGraphKway
        adjwgt, nparts, tpwgts, ubvec, options, objval, part)
    OSError: exception: access violation writing 0x000001B0B9C0E000

    But when I test the metis package on its own as follows, it works:

        import metis
        from networkx import karate_club_graph

        zkc = karate_club_graph()
        graph_clustering = metis.part_graph(zkc)

    So what happened?

    opened by ByskyXie 1
  • some question about code

    It seems that your code does not consider the connections between clusters or the normalization described in the paper. Will you add these two options?

    opened by thunderbird0902 1
  • About installation

    Hi there, thank you for your great work; I've finally got the code running. To make the installation instructions in README.md more precise and complete, you may want to add the following dependencies:

    • torch_spline_conv == 1.0.4
    • torch_sparse == 0.2.2
    • torch_scatter == 1.0.4
    • torch_cluster == 1.1.5 (strict)

    opened by dkdk-ddk 1
  • Cannot run main.py

    src/main.py --epochs 100

    +-------------------+----------------------+
    | Parameter         | Value                |
    +===================+======================+
    | Cluster number    | 10                   |
    | Clustering method | metis                |
    | Dropout           | 0.500                |
    | Edge path         | ./input/edges.csv    |
    | Epochs            | 100                  |
    | Features path     | ./input/features.csv |
    | Layers            | [16, 16, 16]         |
    | Learning rate     | 0.010                |
    | Seed              | 42                   |
    | Target path       | ./input/target.csv   |
    | Test ratio        | 0.900                |
    +-------------------+----------------------+

    Metis graph clustering started.

    Traceback (most recent call last):
      File "src/main.py", line 24, in <module>
        main()
      File "src/main.py", line 18, in main
        clustering_machine.decompose()
      File "/Users/linmiao/gits/ClusterGCN/src/clustering.py", line 38, in decompose
        self.metis_clustering()
      File "/Users/linmiao/gits/ClusterGCN/src/clustering.py", line 56, in metis_clustering
        (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)
      File "/usr/local/lib/python3.7/site-packages/metis.py", line 765, in part_graph
        graph = networkx_to_metis(graph)
      File "/usr/local/lib/python3.7/site-packages/metis.py", line 574, in networkx_to_metis
        for i in H.node:
    AttributeError: 'Graph' object has no attribute 'node'

    opened by linkerlin 1
  • issues about the metis algorithm

    (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)

    Thanks for your awesome code. Could you please explain how METIS performs the graph partition? The self.graph here doesn't include any information about edge weights or feature attributes.

    opened by immortal13 2
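The earlier question about connections between clusters presumably refers to the scheme the Cluster-GCN paper calls stochastic multiple partitions. Rather than training on a single cluster, several clusters are drawn at random for each batch, and the edges between them are added back. A minimal plain-Python sketch of that batching step follows. `multi_cluster_batches` is hypothetical and not part of this repository, and `adjacency` is assumed to be a list of neighbor lists indexed by node id.

```python
import random


def multi_cluster_batches(adjacency, membership, clusters_per_batch, seed=0):
    """Yield (nodes, edges) batches that each combine several clusters.

    adjacency: list of neighbor lists indexed by node id.
    membership: dict mapping node id -> cluster id.
    Every edge whose endpoints both lie in the batch is kept, including
    edges between different clusters of the same batch.
    """
    clusters = sorted(set(membership.values()))
    random.Random(seed).shuffle(clusters)  # a different seed gives a different grouping
    for start in range(0, len(clusters), clusters_per_batch):
        chosen = set(clusters[start:start + clusters_per_batch])
        nodes = sorted(n for n, c in membership.items() if c in chosen)
        in_batch = set(nodes)
        edges = [(u, v) for u in nodes for v in adjacency[u] if v in in_batch and u < v]
        yield nodes, edges


path = [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4]]  # 6-node path graph
membership = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}
for nodes, edges in multi_cluster_batches(path, membership, clusters_per_batch=2, seed=1):
    print(nodes, edges)
```

Passing a different `seed` each epoch regroups the clusters, so different between-cluster edges are seen over the course of training.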
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.