Code for NeurIPS 2020 article "Contrastive learning of global and local features for medical image segmentation with limited annotations"

Overview

Contrastive learning of global and local features for medical image segmentation with limited annotations

The code is for the article "Contrastive learning of global and local features for medical image segmentation with limited annotations" which got accepted as an Oral presentation at NeurIPS 2020 (33rd international conference on Neural Information Processing Systems). With the proposed pre-training method using Contrastive learning, we get competitive segmentation performance with just 2 labeled training volumes compared to a benchmark that is trained with many labeled volumes.
https://arxiv.org/abs/2006.10511

Observations / Conclusions:

  1. For medical image segmentation, the proposed contrastive pre-training strategy incorporating domain knowledge present naturally across medical volumes yields better performance than baseline, other pre-training methods, semi-supervised, and data augmentation methods.
  2. Proposed local contrastive loss, an extension of global loss, provides an additional boost in performance by learning distinctive local-level representation to distinguish between neighbouring regions.
  3. The proposed pre-training strategy is complementary to semi-supervised and data augmentation methods. Combining them yields a further boost in accuracy.

Authors:
Krishna Chaitanya (email),
Ertunc Erdil,
Neerav Karani,
Ender Konukoglu.

Requirements:
Python 3.6.1,
Tensorflow 1.12.0,
rest of the requirements are mentioned in the "requirements.txt" file.

I) To clone the git repository.
git clone https://github.com/krishnabits001/domain_specific_dl.git

II) Install python, required packages and tensorflow.
Then, install python packages required using below command or the packages mentioned in the file.
pip install -r requirements.txt

To install tensorflow
pip install tensorflow-gpu=1.12.0

III) Dataset download.
To download the ACDC Cardiac dataset, check the website :
https://www.creatis.insa-lyon.fr/Challenge/acdc.

To download the Medical Decathlon Prostate dataset, check the website :
http://medicaldecathlon.com/

To download the MMWHS Cardiac dataset, check the website :
http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/

All the images were bias corrected using N4 algorithm with a threshold value of 0.001. For more details, refer to the "N4_bias_correction.py" file in scripts.
Image and label pairs are re-sampled (to chosen target resolution) and cropped/zero-padded to a fixed size using "create_cropped_imgs.py" file.

IV) Train the models.
Below commands are an example for ACDC dataset.
The models need to be trained sequentially as follows (check "train_model/pretrain_and_fine_tune_script.sh" script for commands)
Steps :

  1. Step 1: To pre-train the encoder with global loss by incorporating proposed domain knowledge when defining positive and negative pairs.
    cd train_model/
    python pretr_encoder_global_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --global_loss_exp_no=2 --n_parts=4 --temp_fac=0.1 --bt_size=12

  2. Step 2: After step 1, we pre-train the decoder with proposed local loss to aid segmentation task by learning distinctive local-level representations.
    python pretr_decoder_local_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --bt_size=12

  3. Step 3: We use the pre-trained encoder and decoder weights as initialization and fine-tune to segmentation task using limited annotations.
    python ft_pretr_encoder_decoder_net_local_loss.py --dataset=acdc --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

To train the baseline with affine and random deformations & intensity transformations for comparison, use the below code file.
cd train_model/
python tr_baseline.py --dataset=acdc --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

V) Config files contents.
One can modify the contents of the below 2 config files to run the required experiments.
experiment_init directory contains 2 files.
Example for ACDC dataset:

  1. init_acdc.py
    --> contains the config details like target resolution, image dimensions, data path where the dataset is stored and path to save the trained models.
  2. data_cfg_acdc.py
    --> contains an example of data config details where one can set the patient ids which they want to use as train, validation and test images.

Bibtex citation:

@article{chaitanya2020contrastive,
  title={Contrastive learning of global and local features for medical image segmentation with limited annotations},
  author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
Owner
Krishna Chaitanya
Doctoral Student, ETH Zurich
Krishna Chaitanya
Code for the paper "Learning-Augmented Algorithms for Online Steiner Tree"

Learning-Augmented Algorithms for Online Steiner Tree This is the code for the paper "Learning-Augmented Algorithms for Online Steiner Tree". Requirem

0 Dec 09, 2021
Neural-net-from-scratch - A simple Neural Network from scratch in Python using the Pymathrix library

A Simple Neural Network from scratch A Simple Neural Network from scratch in Pyt

Youssef Chafiqui 2 Jan 07, 2022
Pytorch implementation of Each Part Matters: Local Patterns Facilitate Cross-view Geo-localization https://arxiv.org/abs/2008.11646

[TCSVT] Each Part Matters: Local Patterns Facilitate Cross-view Geo-localization LPN [Paper] NEWs Prerequisites Python 3.6 GPU Memory = 8G Numpy 1.

46 Dec 14, 2022
Multi-Template Mouse Brain MRI Atlas (MBMA): both in-vivo and ex-vivo

Multi-template MRI mouse brain atlas (both in vivo and ex vivo) Mouse Brain MRI atlas (both in-vivo and ex-vivo) (repository relocated from the origin

8 Nov 18, 2022
A Transformer-Based Feature Segmentation and Region Alignment Method For UAV-View Geo-Localization

University1652-Baseline [Paper] [Slide] [Explore Drone-view Data] [Explore Satellite-view Data] [Explore Street-view Data] [Video Sample] [中文介绍] This

Zhedong Zheng 335 Jan 06, 2023
Graduation Project

Gesture-Detection-and-Depth-Estimation This is my graduation project. (1) In this project, I use the YOLOv3 object detection model to detect gesture i

ChaosAT 1 Nov 23, 2021
A very simple tool for situations where optimization with onnx-simplifier would exceed the Protocol Buffers upper file size limit of 2GB, or simply to separate onnx files to any size you want.

sne4onnx A very simple tool for situations where optimization with onnx-simplifier would exceed the Protocol Buffers upper file size limit of 2GB, or

Katsuya Hyodo 10 Aug 30, 2022
Ejemplo Algoritmo Viterbi - Example of a Viterbi algorithm applied to a hidden Markov model on DNA sequence

Ejemplo Algoritmo Viterbi Ejemplo de un algoritmo Viterbi aplicado a modelo ocul

Mateo Velásquez Molina 1 Jan 10, 2022
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Pawel Dziemiach 1 Dec 18, 2021
Meta-Learning Sparse Implicit Neural Representations (NeurIPS 2021)

Meta-SparseINR Official PyTorch implementation of "Meta-learning Sparse Implicit Neural Representations" (NeurIPS 2021) by Jaeho Lee*, Jihoon Tack*, N

Jaeho Lee 41 Nov 10, 2022
NeRD: Neural Reflectance Decomposition from Image Collections

NeRD: Neural Reflectance Decomposition from Image Collections Project Page | Video | Paper | Dataset Implementation for NeRD. A novel method which dec

Computergraphics (University of Tübingen) 195 Dec 29, 2022
TAug :: Time Series Data Augmentation using Deep Generative Models

TAug :: Time Series Data Augmentation using Deep Generative Models Note!!! The package is under development so be careful for using in production! Fea

35 Dec 06, 2022
A real world application of a Recurrent Neural Network on a binary classification of time series data

What is this This is a real world application of a Recurrent Neural Network on a binary classification of time series data. This project includes data

Josep Maria Salvia Hornos 2 Jan 30, 2022
Real-world Anomaly Detection in Surveillance Videos- pytorch Re-implementation

Real world Anomaly Detection in Surveillance Videos : Pytorch RE-Implementation This repository is a re-implementation of "Real-world Anomaly Detectio

seominseok 62 Dec 08, 2022
SCAN: Learning to Classify Images without Labels, incl. SimCLR. [ECCV 2020]

Learning to Classify Images without Labels This repo contains the Pytorch implementation of our paper: SCAN: Learning to Classify Images without Label

Wouter Van Gansbeke 1.1k Dec 30, 2022
PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

neural-combinatorial-rl-pytorch PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning. I have implemented the basic

Patrick E. 454 Jan 06, 2023
Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks

SSTNet Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks(ICCV2021) by Zhihao Liang, Zhihao Li, Songcen Xu, Mingkui Tan, Kui J

83 Nov 29, 2022
a reimplementation of Optical Flow Estimation using a Spatial Pyramid Network in PyTorch

pytorch-spynet This is a personal reimplementation of SPyNet [1] using PyTorch. Should you be making use of this work, please cite the paper according

Simon Niklaus 269 Jan 02, 2023
Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Commonsense Question Answering

Path-Generator-QA This is a Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Common

Peifeng Wang 33 Dec 05, 2022
Repository for the Bias Benchmark for QA dataset.

BBQ Repository for the Bias Benchmark for QA dataset. Authors: Alicia Parrish, Angelica Chen, Nikita Nangia, Vishakh Padmakumar, Jason Phang, Jana Tho

ML² AT CILVR 18 Nov 18, 2022