Unofficial PyTorch implementation of RefineGAN

Overview

RefineGAN

Unofficial PyTorch implementation of RefineGAN (https://arxiv.org/abs/1709.00753) for compressed sensing MRI (CS-MRI) reconstruction. The official code, written with tensorpack, can be found at https://github.com/tmquan/RefineGAN

To Do

  • Run the original tensorpack code (sorry, I can't run tensorpack on my GPU)
  • PyTorch implementation and experiments on brain images with a radial mask
  • Bug fixed. The mean PSNR of the zero-filled images is not exactly the same as the value in the original paper, although the model's improvement is similar
  • Experiments on different masks

Install

Python>=3.7.11 is required, with all packages in requirements.txt installed, including PyTorch>=1.10.0:

git clone https://github.com/hellopipu/RefineGAN.git
cd RefineGAN
pip install -r requirements.txt

How to use

for training:

cd run_sh
sh train.sh

The model will be saved in the folder weight, and TensorBoard information will be saved in the folder log. You can change arguments in the script, such as --mask_type and --sampling_rate, for different experiment settings.
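The experiments use a radial undersampling mask whose density is controlled by --sampling_rate. Below is a minimal NumPy sketch of one way such a mask can be built: spokes through the k-space center at evenly spaced angles, added until the sampled fraction reaches the target rate. `radial_mask` is a hypothetical helper, not this repo's code. It assumes a centered k-space layout (apply np.fft.ifftshift before using it with an unshifted np.fft.fft2 output), and the repo may instead load precomputed masks.

```python
import numpy as np


def radial_mask(height, width, sampling_rate, max_lines=1000):
    """Hypothetical radial k-space mask (centered layout, True = sampled).

    Spokes pass through the k-space center at evenly spaced angles; the number
    of spokes grows until the sampled fraction reaches `sampling_rate`.
    """
    if not 0.0 < sampling_rate <= 1.0:
        raise ValueError("sampling_rate must be in (0, 1]")
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    radius = np.hypot(height, width) / 2.0                  # long enough to reach the corners
    t = np.linspace(-radius, radius, int(4 * radius) + 1)   # about 2 samples per pixel
    mask = np.zeros((height, width), dtype=bool)
    for n_lines in range(1, max_lines + 1):
        mask = np.zeros((height, width), dtype=bool)
        for angle in np.arange(n_lines) * np.pi / n_lines:  # angles in [0, pi)
            ys = np.rint(cy + t * np.sin(angle)).astype(int)
            xs = np.rint(cx + t * np.cos(angle)).astype(int)
            inside = (ys >= 0) & (ys < height) & (xs >= 0) & (xs < width)
            mask[ys[inside], xs[inside]] = True
        if mask.mean() >= sampling_rate:                     # enough k-space covered
            break
    return mask


# Example: roughly 20% of a 256x256 k-space grid.
m = radial_mask(256, 256, 0.20)
print(m.shape, round(float(m.mean()), 3))
```

Rebuilding the mask for every spoke count keeps the angles evenly spaced. This is simple, but it costs quadratic time in the number of spokes, which is fine for a one-off mask.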

for tensorboard:

Check the training curves during training:

tensorboard --logdir log

The training logs from my experiments are already in the log folder.

for testing:

Test after training, or download my trained model weights from Google Drive.

cd run_sh
sh test.sh

for visualization:

cd run_sh
sh visualize.sh

training curves

Sampling rates: 10% (light orange), 20% (dark blue), 30% (dark orange), 40% (light blue). You can check more loss curves from my experiments using TensorBoard.

Loss curves: loss_G_loss_total and loss_recon_img_Aa

PSNR on the training set over 500 epochs, compared with the results shown in the original paper.

PSNR curves: my_train_psnr (this implementation) and paper_train_psnr (original paper)

Test results

Mean PSNR (dB) on the validation dataset with a radial mask at different sampling rates; batch_size is set to 4.

model          10%      20%      30%      40%
zero-filled    22.296   25.806   28.997   31.699
RefineGAN      32.705   36.734   39.961   42.903
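The zero-filled row is the usual CS-MRI baseline. Only the acquired k-space samples are kept, the missing ones are set to zero, and the result is inverse-Fourier-transformed. A minimal NumPy sketch of that baseline, assuming a real-valued magnitude image and a boolean mask in centered k-space layout. `zero_filled` is a hypothetical helper, and the repo's exact preprocessing (normalization, complex input) may differ.

```python
import numpy as np


def zero_filled(image, mask):
    """Zero-filled reconstruction: keep only the acquired k-space samples
    (mask == True, centered layout), set the rest to zero, inverse FFT."""
    kspace = np.fft.fftshift(np.fft.fft2(image))        # centered k-space
    undersampled = np.where(mask, kspace, 0)            # missing samples -> 0
    recon = np.fft.ifft2(np.fft.ifftshift(undersampled))
    return np.abs(recon)                                # magnitude image
```

Because the missing samples are simply zero, the output contains undersampling artifacts (streaks, for a radial mask). These artifacts are what the generator learns to remove, which is why the gap between the two rows is large at every sampling rate.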

Test cases visualization

Each row shows one sampling rate (10%, 20%, 30%, 40%); from left to right: mask, zero-filled image, prediction, ground truth, error (zero-filled) and error (prediction).

Notes on RefineGAN

  • Data processing before training: each complex value is represented as two channels (real and imaginary parts), and each channel is rescaled to [-1, 1]; accordingly, the last layer of the generator is tanh()
  • The generator uses residual learning for the reconstruction task
  • The generator is a cascade of two U-Nets; each U-Net combines encoder and decoder features by addition rather than concatenation
  • Each U-Net is followed by a data-consistency (DC) module, although the paper doesn't mention it
  • The last layer of the generator is a tanh layer on the two-channel output, so when the output is converted back to the original pixel scale and its magnitude (abs) is computed, pixel values may exceed 255; they must be clipped when calculating PSNR
  • During training, two random image samples A and B are drawn at each iteration. RefineGAN computes a large number of losses (possibly redundant), including reconstruction losses on the outputs of the different generator stages in both the image domain and the frequency domain, a total variation loss and a WGAN loss
  • One special loss is D_loss_AB: D is trained only to distinguish real samples from fake samples, so it should work not only for (real A, fake A) or (real B, fake B) pairs but also for (real A, fake B) inputs
  • WGAN-GP (WGAN with gradient penalty) may be used to improve performance
  • A small batch size may be better: in my experiments, batch_size=4 performed better than batch_size=16
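The two-channel representation and the PSNR clipping described in the notes can be sketched together in NumPy. This assumes pixel values in [0, 255] and a plain divide-by-255 scaling, which keeps each channel in [-1, 1] as long as |Re| and |Im| do not exceed 255. The repo's exact normalization may differ, and `to_two_channel`, `from_two_channel` and `psnr` are hypothetical helpers. The sketch shows why a saturated tanh output can exceed 255 once the magnitude is taken.

```python
import numpy as np

SCALE = 255.0  # assumption: pixel values lie in [0, 255]


def to_two_channel(z):
    """Complex (H, W) array -> real (2, H, W) array of [real, imag] in [-1, 1]."""
    return np.stack([z.real, z.imag]) / SCALE


def from_two_channel(x):
    """Real (2, H, W) array (e.g. a tanh output) -> magnitude image on the pixel scale."""
    z = (x[0] + 1j * x[1]) * SCALE
    return np.abs(z)  # up to 255 * sqrt(2) when both channels saturate at +/-1


def psnr(pred, target, max_val=255.0):
    """PSNR in dB; pred is clipped to [0, max_val] first, as the notes require."""
    pred = np.clip(pred, 0.0, max_val)
    mse = np.mean((pred - np.asarray(target, dtype=np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / mse)


# A tanh output saturated at +1 on both channels:
mag = from_two_channel(np.ones((2, 4, 4)))
print(mag.max())                           # about 360.62, above 255
print(psnr(mag, np.full((4, 4), 255.0)))   # inf: clipping maps it back to 255
```

Without the clip, the same saturated prediction would be penalized for overshooting a range that the ground-truth images can never reach.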

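A data-consistency step is commonly implemented by overwriting the network's k-space with the measured samples wherever k-space was acquired. Below is a minimal NumPy sketch of that "hard" DC variant, assuming complex images, a centered k-space layout and a boolean mask. `data_consistency` is a hypothetical name. Whether this repo uses the hard variant or a noise-weighted (soft) one is an assumption.

```python
import numpy as np


def data_consistency(pred, measured_kspace, mask):
    """Hard data consistency (hypothetical helper).

    pred:            complex (H, W) image estimate from a U-Net stage
    measured_kspace: complex (H, W) undersampled k-space, centered layout
    mask:            bool (H, W), True where k-space was acquired
    """
    k_pred = np.fft.fftshift(np.fft.fft2(pred))
    # Trust the measurements where they exist; keep the network's guess elsewhere.
    k_dc = np.where(mask, measured_kspace, k_pred)
    return np.fft.ifft2(np.fft.ifftshift(k_dc))
```

In the cascade described in the notes, this step would run after each U-Net (U-Net, DC, U-Net, DC). The generator's two-channel output would first be converted to a complex image.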
I would appreciate it if you could point out any implementation mistakes in the code.

Owner
xinby17
research interest: Medical Image Analysis, Computer Vision