This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Overview

Predicting Patient Outcomes with Graph Representation Learning

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning. You can watch a video of the spotlight talk at W3PHIAI (AAAI workshop) here:

Watch the video

Citation

If you use this code or the models in your research, please cite the following:

@misc{rocheteautong2021,
      title={Predicting Patient Outcomes with Graph Representation Learning}, 
      author={Emma Rocheteau and Catherine Tong and Petar Veličković and Nicholas Lane and Pietro Liò},
      year={2021},
      eprint={2101.03940},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Motivation

Recent work on predicting patient outcomes in the Intensive Care Unit (ICU) has focused heavily on the physiological time series data, largely ignoring sparse data such as diagnoses and medications. When they are included, they are usually concatenated in the late stages of a model, which may struggle to learn from rarer disease patterns. Instead, we propose a strategy to exploit diagnoses as relational information by connecting similar patients in a graph. To this end, we propose LSTM-GNN for patient outcome prediction tasks: a hybrid model combining Long Short-Term Memory networks (LSTMs) for extracting temporal features and Graph Neural Networks (GNNs) for extracting the patient neighbourhood information. We demonstrate that LSTM-GNNs outperform the LSTM-only baseline on length of stay prediction tasks on the eICU database. More generally, our results indicate that exploiting information from neighbouring patient cases using graph neural networks is a promising research direction, yielding tangible returns in supervised learning performance on Electronic Health Records.

Pre-Processing Instructions

eICU Pre-Processing

  1. To run the sql files you must have the eICU database set up: https://physionet.org/content/eicu-crd/2.0/.

  2. Follow the instructions: https://eicu-crd.mit.edu/tutorials/install_eicu_locally/ to ensure the correct connection configuration.

  3. Replace the eICU_path in paths.json to a convenient location in your computer, and do the same for eICU_preprocessing/create_all_tables.sql using find and replace for '/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/'. Leave the extra '/' at the end.

  4. In your terminal, navigate to the project directory, then type the following commands:

    psql 'dbname=eicu user=eicu options=--search_path=eicu'
    

    Inside the psql console:

    \i eICU_preprocessing/create_all_tables.sql
    

    This step might take a couple of hours.

    To quit the psql console:

    \q
    
  5. Then run the pre-processing scripts in your terminal. This will need to run overnight:

    python3 -m eICU_preprocessing.run_all_preprocessing
    

Graph Construction

To make the graphs, you can use the following scripts:

This is to make most of the graphs that we use. You can alter the arguments given to this script.

python3 -m graph_construction.create_graph --freq_adjust --penalise_non_shared --k 3 --mode k_closest

Write the diagnosis strings into eICU_data folder:

python3 -m graph_construction.get_diagnosis_strings

Get the bert embeddings:

python3 -m graph_construction.bert

Create the graph from the bert embeddings:

python3 -m graph_construction.create_bert_graph --k 3 --mode k_closest

Alternatively, you can request to download our graphs using this link: https://drive.google.com/drive/folders/1yWNLhGOTPhu6mxJRjKCgKRJCJjuToBS4?usp=sharing

Training the ML Models

Before proceeding to training the ML models, do the following.

  1. Define data_dir, graph_dir, log_path and ray_dir in paths.json to convenient locations.

  2. Run the following to unpack the processed eICU data into mmap files for easy loading during training. The mmap files will be saved in data_dir.

    python3 -m src.dataloader.convert
    

The following commands train and evaluate the models introduced in our paper.

N.B.

  • The models are structured using pytorch-lightning. Graph neural networks and neighbourhood sampling are implemented using pytorch-geometric.

  • Our models assume a default graph which is made with k=3 under a k-closest scheme. If you wish to use other graphs, refer to read_graph_edge_list in src/dataloader/pyg_reader.py to add a reference handle to version2filename for your graph.

  • The default task is In-House-Mortality Prediction (ihm), add --task los to the command to perform the Length-of-Stay Prediction (los) task instead.

  • These commands use the best set of hyperparameters; To use other hyperparameters, remove --read_best from the command and refer to src/args.py.

a. LSTM-GNN

The following runs the training and evaluation for LSTM-GNN models. --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_lstmgnn --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstmgnn_search --bilstm --ts_mask --add_flat --class_weights  --gnn_name gat --add_diag

b. Dynamic LSTM-GNN

The following runs the training & evaluation for dynamic LSTM-GNN models. --gnn_name can be set as gcn, gat, or mpnn.

python3 -m train_dynamic --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.dynamic_lstmgnn_search --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn

c. GNN

The following runs the GNN models (with neighbourhood sampling). --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_gnn --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.ns_gnn_search --ts_mask --add_flat --class_weights --gnn_name gat --add_diag

d. LSTM (Baselines)

The following runs the baseline bi-LSTMs. To remove diagnoses from the input vector, remove --add_diag from the command.

python3 -m train_ns_lstm --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstm_search --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag
Owner
Emma Rocheteau
Computer Science PhD Student at Cambridge
Emma Rocheteau
Bringing Computer Vision and Flutter together , to build an awesome app !!

Bringing Computer Vision and Flutter together , to build an awesome app !! Explore the Directories Flutter · Machine Learning Table of Contents About

Padmanabha Banerjee 14 Apr 07, 2022
magiCARP: Contrastive Authoring+Reviewing Pretraining

magiCARP: Contrastive Authoring+Reviewing Pretraining Welcome to the magiCARP API, the test bed used by EleutherAI for performing text/text bi-encoder

EleutherAI 43 Dec 29, 2022
I-SECRET: Importance-guided fundus image enhancement via semi-supervised contrastive constraining

I-SECRET This is the implementation of the MICCAI 2021 Paper "I-SECRET: Importance-guided fundus image enhancement via semi-supervised contrastive con

13 Dec 02, 2022
Libtorch yolov3 deepsort

Overview It is for my undergrad thesis in Tsinghua University. There are four modules in the project: Detection: YOLOv3 Tracking: SORT and DeepSORT Pr

Xu Wei 226 Dec 13, 2022
MoCoPnet - Deformable 3D Convolution for Video Super-Resolution

MoCoPnet: Exploring Local Motion and Contrast Priors for Infrared Small Target Super-Resolution Pytorch implementation of local motion and contrast pr

Xinyi Ying 28 Dec 15, 2022
Spectralformer: Rethinking hyperspectral image classification with transformers

The code in this toolbox implements the "Spectralformer: Rethinking hyperspectral image classification with transformers". More specifically, it is detailed as follow.

Danfeng Hong 104 Jan 04, 2023
NeuralDiff: Segmenting 3D objects that move in egocentric videos

NeuralDiff: Segmenting 3D objects that move in egocentric videos Project Page | Paper + Supplementary | Video About This repository contains the offic

Vadim Tschernezki 14 Dec 05, 2022
A CNN implementation using only numpy. Supports multidimensional images, stride, etc.

A CNN implementation using only numpy. Supports multidimensional images, stride, etc. Speed up due to heavy use of slicing and mathematical simplification..

2 Nov 30, 2021
Few-Shot Graph Learning for Molecular Property Prediction

Few-shot Graph Learning for Molecular Property Prediction Introduction This is the source code and dataset for the following paper: Few-shot Graph Lea

Zhichun Guo 94 Dec 12, 2022
A web application that provides real time temperature and humidity readings of a house.

About A web application which provides real time temperature and humidity readings of a house. If you're interested in the data collected so far click

Ben Thompson 3 Jan 28, 2022
Hypersim: A Photorealistic Synthetic Dataset for Holistic Indoor Scene Understanding

The Hypersim Dataset For many fundamental scene understanding tasks, it is difficult or impossible to obtain per-pixel ground truth labels from real i

Apple 1.3k Jan 04, 2023
Exploiting a Zoo of Checkpoints for Unseen Tasks

Exploiting a Zoo of Checkpoints for Unseen Tasks This repo includes code to reproduce all results in the above Neurips paper, authored by Jiaji Huang,

Baidu Research 8 Sep 06, 2022
Evolutionary Scale Modeling (esm): Pretrained language models for proteins

Evolutionary Scale Modeling This repository contains code and pre-trained weights for Transformer protein language models from Facebook AI Research, i

Meta Research 1.6k Jan 09, 2023
HODEmu, is both an executable and a python library that is based on Ragagnin 2021 in prep.

HODEmu HODEmu, is both an executable and a python library that is based on Ragagnin 2021 in prep. and emulates satellite abundance as a function of co

Antonio Ragagnin 1 Oct 13, 2021
Pytorch implementation of AngularGrad: A New Optimization Technique for Angular Convergence of Convolutional Neural Networks

AngularGrad Optimizer This repository contains the oficial implementation for AngularGrad: A New Optimization Technique for Angular Convergence of Con

mario 124 Sep 16, 2022
DCA - Official Python implementation of Delaunay Component Analysis algorithm

Delaunay Component Analysis (DCA) Official Python implementation of the Delaunay

Petra Poklukar 9 Sep 06, 2022
Implementation of Sequence Generative Adversarial Nets with Policy Gradient

SeqGAN Requirements: Tensorflow r1.0.1 Python 2.7 CUDA 7.5+ (For GPU) Introduction Apply Generative Adversarial Nets to generating sequences of discre

Lantao Yu 2k Dec 29, 2022
Learning Temporal Consistency for Low Light Video Enhancement from Single Images (CVPR2021)

StableLLVE This is a Pytorch implementation of "Learning Temporal Consistency for Low Light Video Enhancement from Single Images" in CVPR 2021, by Fan

99 Dec 19, 2022
Vrcwatch - Supply the local time to VRChat as Avatar Parameters through OSC

English: README-EN.md VRCWatch VRCWatch は、VRChat 内のアバター向けに現在時刻を送信するためのプログラムです。 使

Kosaki Mezumona 17 Nov 30, 2022
A denoising diffusion probabilistic model synthesises galaxies that are qualitatively and physically indistinguishable from the real thing.

Realistic galaxy simulation via score-based generative models Official code for 'Realistic galaxy simulation via score-based generative models'. We us

Michael Smith 32 Dec 20, 2022