A PyTorch Implementation of "Watch Your Step: Learning Node Embeddings via Graph Attention" (NeurIPS 2018).

Overview

Attention Walk

Abstract

Graph embedding methods represent nodes in a continuous vector space, preserving different types of relational information from the graph. There are many hyper-parameters to these methods (e.g. the length of a random walk) which have to be manually tuned for every graph. In this paper, we replace previously fixed hyper-parameters with trainable ones that we automatically learn via backpropagation. In particular, we propose a novel attention model on the power series of the transition matrix, which guides the random walk to optimize an upstream objective. Unlike previous approaches to attention models, the method that we propose utilizes attention parameters exclusively on the data itself (e.g. on the random walk), and are not used by the model for inference. We experiment on link prediction tasks, as we aim to produce embeddings that best-preserve the graph structure, generalizing to unseen information. We improve state-of-the-art results on a comprehensive suite of real-world graph datasets including social, collaboration, and biological networks, where we observe that our graph attention model can reduce the error by up to 20%-40%. We show that our automatically-learned attention parameters can vary significantly per graph, and correspond to the optimal choice of hyper-parameter if we manually tune existing methods.
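
In other words, the expected co-occurrence statistics of a random walk are written as an attention-weighted sum of transition-matrix powers, with softmax weights learned by backpropagation. A minimal NumPy sketch of this quantity (names are illustrative, not the repository's code):

```python
import numpy as np

def expected_context(transition, attention_logits):
    """Attention-weighted sum of transition-matrix powers:
    sum_k softmax(q)_k * T^k for k = 1..C."""
    q = np.exp(attention_logits - attention_logits.max())
    q = q / q.sum()                       # softmax over context distances 1..C
    power = np.eye(transition.shape[0])
    expected = np.zeros_like(transition)
    for q_k in q:
        power = power @ transition        # advance to the next power T^k
        expected += q_k * power
    return expected
```

Here `transition` would be the row-normalized adjacency matrix and `attention_logits` would have one entry per context distance (the `--window-size` option below).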

This repository provides an implementation of Attention Walk as described in the paper:

Watch Your Step: Learning Node Embeddings via Graph Attention. Sami Abu-El-Haija, Bryan Perozzi, Rami Al-Rfou, Alexander A. Alemi. NeurIPS, 2018. [Paper]

The original TensorFlow implementation is available [here].

Requirements

The codebase is implemented in Python 3.5.2. Package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torchvision       0.3.0

Datasets

The code takes an input graph in a csv file. Every row indicates an edge between two nodes separated by a comma. The first row is a header. Nodes should be indexed starting with 0. Sample graphs for the `Twitch Brasilians` and `Wikipedia Chameleons` datasets are included in the `input/` directory.
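
For illustration, a file in this format can be read back into a `networkx` graph like so (a sketch that mirrors the format above; the repository's own loader may differ in details):

```python
import networkx as nx
import pandas as pd

# Header row, then one "id_1,id_2" edge per line, nodes indexed from 0.
edges = pd.read_csv("input/chameleon_edges.csv")
graph = nx.from_edgelist(edges.values.tolist())
```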

Options

Learning of the embedding is handled by the `src/main.py` script, which provides the following command-line arguments.

Input and output options

  --edge-path         STR   Input graph path.     Default is `input/chameleon_edges.csv`.
  --embedding-path    STR   Embedding path.       Default is `output/chameleon_AW_embedding.csv`.
  --attention-path    STR   Attention path.       Default is `output/chameleon_AW_attention.csv`.

Model options

  --dimensions           INT       Number of embedding dimensions.       Default is 128.
  --epochs               INT       Number of training epochs.            Default is 200.
  --window-size          INT       Skip-gram window size.                Default is 5.
  --learning-rate        FLOAT     Learning rate value.                  Default is 0.01.
  --beta                 FLOAT     Attention regularization parameter.   Default is 0.5.
  --gamma                FLOAT     Embedding regularization parameter.   Default is 0.5.
  --num-of-walks         INT       Number of walks per source node.      Default is 80.

Examples

The following commands learn a graph embedding and write the embedding to disk. The node representations are ordered by node ID.

Creating an Attention Walk embedding of the default dataset with the standard hyperparameter settings, saving the embedding at the default path.

```
python src/main.py
```

Creating an Attention Walk embedding of the default dataset with 256 dimensions.

python src/main.py --dimensions 256

Creating an Attention Walk embedding of the default dataset with a higher window size.

python src/main.py --window-size 20

Creating an embedding of another dataset, the Twitch Brasilians, and saving the outputs under custom file names.

python src/main.py --edge-path input/ptbr_edges.csv --embedding-path output/ptbr_AW_embedding.csv --attention-path output/ptbr_AW_attention.csv
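
After training, the saved embedding can be loaded for downstream use. A minimal sketch, assuming the default output path and that the CSV has one row per node with the node ID in the first column (an assumption; verify against the file's header):

```python
import pandas as pd

# Assumption: one row per node, node ID in the first column.
embedding = pd.read_csv("output/chameleon_AW_embedding.csv")
vectors = embedding.iloc[:, 1:].values  # drop the ID column
print(vectors.shape)                    # (number of nodes, embedding width)
```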

License


Comments
  • NaN parameters

    Thanks for your PyTorch code. I found that my parameters become NaN during training. The NaN parameters include model.left_factors, model.right_factors, and model.attention; all of their entries become NaN during training, and so does the loss. I'm trying to find the reason. I would appreciate it if you could give me some help or hints. (A debugging sketch follows this comments list.)

    opened by kkkkk001 9
  • Memory Error

    I'm getting OOM errors even with small files. The attached file link_network.txt throws the following error:

    Adjacency matrix powers: 100%|███████████████████████████████████████████████████████| 4/4 [00:00<00:00, 108.39it/s]
    Traceback (most recent call last):
      File "src\main.py", line 79, in <module>
        main()
      File "src\main.py", line 74, in main
        model = AttentionWalkTrainer(args)
      File "E:\AttentionWalk\src\attentionwalk.py", line 70, in __init__
        self.initialize_model_and_features()
      File "E:\AttentionWalk\src\attentionwalk.py", line 76, in initialize_model_and_features
        self.target_tensor = feature_calculator(self.args, self.graph)
      File "E:\AttentionWalk\src\utils.py", line 53, in feature_calculator
        target_matrices = np.array(target_matrices)
    MemoryError
    

    I guess this is due to the large indices of the nodes. Any workarounds for this? (A sparse-matrix workaround is sketched after this comments list.)

    opened by davidlenz 2
  • modified normalized_adjacency_matrix calculation

    As mentioned in this issue: https://github.com/benedekrozemberczki/AttentionWalk/issues/9

    Added normalization to the calculation; this prevents an unbalanced loss and keeps loss_on_mat from becoming extremely big when the node count of the data is big.

    opened by neilctwu 1
  • miscalculations of normalized adjacency matrix

    Thanks for sharing this awesome repo.

    The issue is that loss_on_target becomes extremely big while training with the original code, and I think this is due to a miscalculation of normalized_adjacency_matrix.

    In your original code, normalized_adjacency_matrix is calculated by:

    normalized_adjacency_matrix = degs.dot(adjacency_matrix)
    

    However, this matrix is not actually normalized; it is simply multiplied by the node degrees. I think normalized_adjacency_matrix should be modified to match its original definition (sketched in full after this comments list):

      normalized_adjacency_matrix = degs.power(-1/2)\
                                        .dot(adjacency_matrix)\
                                        .dot(degs.power(-1/2))
    

    This yields a much more reasonable loss (loss curve attached in the original issue).

    Am I understanding it correctly?

    opened by neilctwu 1
  • problem with being killed

    Hi, I tried to train the model with a new dataset that has about 60,000 nodes, but the process suddenly gets Killed. Do you have any idea why? Thanks :) (Screenshot attached in the original issue; see also the memory notes after this list.)

    opened by amy-hyunji 1
  • Directed weighted graphs

    Is it possible to use the code with directed and weighted graphs? The paper states the Attention Walk framework for unweighted graphs only, but I'd like to use it for such types of networks. Thank you for your attention.

    opened by federicoairoldi 1
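
A note on the NaN report above: one way to localize such a problem is PyTorch's anomaly detection together with gradient clipping. A minimal sketch of a generic training loop; `model`, `optimizer`, and the loss call are placeholders, not the repository's exact API:

```python
import torch

# Surface the first backward-pass operation that produces NaN/Inf gradients.
torch.autograd.set_detect_anomaly(True)

for epoch in range(200):                   # matches the default --epochs
    optimizer.zero_grad()                  # placeholder optimizer
    loss = model()                         # placeholder: the trainer's loss computation
    if torch.isnan(loss):
        print("Non-finite loss at epoch %d; stopping." % epoch)
        break
    loss.backward()
    # Keep one bad step from blowing up the factor matrices.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```

Lowering `--learning-rate` is another common remedy, and the normalization fix discussed above may remove the root cause entirely.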
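
On the MemoryError report: `np.array(target_matrices)` materializes every power of the normalized adjacency matrix in one dense array of shape (window_size, n, n), which fails for large n. One possible workaround is to keep each power sparse and process them one at a time; a sketch (the function name is illustrative, not the repository's):

```python
from scipy import sparse

def transition_matrix_powers(norm_adj, window_size):
    """Yield T, T^2, ..., T^window_size without stacking dense arrays."""
    norm_adj = sparse.csr_matrix(norm_adj)
    power = sparse.identity(norm_adj.shape[0], format="csr")
    for _ in range(window_size):
        power = power.dot(norm_adj)  # stays sparse while the power is sparse
        yield power
```

Note that on well-connected graphs the higher powers fill in quickly, so this only delays the blow-up; reducing `--window-size` also helps.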
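
The symmetric normalization D^{-1/2} A D^{-1/2} proposed in the two comments above can be sketched end-to-end with scipy as follows (a sketch of the fix, not the repository's exact code; it assumes every node has degree > 0):

```python
import networkx as nx
import numpy as np
from scipy import sparse

def normalized_adjacency(graph):
    """Symmetrically normalized adjacency matrix D^{-1/2} A D^{-1/2}."""
    adjacency = nx.adjacency_matrix(graph)
    degrees = np.asarray(adjacency.sum(axis=1)).flatten()
    d_inv_sqrt = sparse.diags(1.0 / np.sqrt(degrees))  # D^{-1/2}; assumes no isolated nodes
    return d_inv_sqrt.dot(adjacency).dot(d_inv_sqrt)
```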