Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

UTNet (Accepted at MICCAI 2021)

Introduction

Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that reduces the complexity of self-attention operation significantly from O(n^2) to approximate O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skipped connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without a need of pre-training. We have evaluated UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding the promise to generalize well on other medical image segmentations.

Supported models

UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continued ...

Getting Started

Currently, only the M&Ms dataset is supported.

Prerequisites

Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2

Preprocess

Resample all data to a spacing of 1.2 x 1.2 mm in the x-y plane. We don't change the spacing along the z-axis, as UTNet is a 2D network. Then put all the data into 'dataset/'.
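The resampling step only changes the in-plane grid. Below is a minimal sketch of the geometry calculation, assuming sizes and spacings are given in (x, y, z) order. `resampled_geometry` is a hypothetical helper written for illustration and is not part of this repo.

```python
TARGET_XY_SPACING = 1.2  # mm, the in-plane spacing named in the preprocessing step


def resampled_geometry(size, spacing, target_xy=TARGET_XY_SPACING):
    """Return (new_size, new_spacing) for in-plane resampling.

    size and spacing are (x, y, z) tuples. The z axis is left untouched
    because the network operates on 2D slices.
    """
    new_spacing = (target_xy, target_xy, spacing[2])
    # Keep the physical extent (size * spacing) roughly constant in x and y.
    new_xy = tuple(
        int(round(n * s / target_xy)) for n, s in zip(size[:2], spacing[:2])
    )
    new_size = new_xy + (size[2],)
    return new_size, new_spacing


if __name__ == "__main__":
    # A 256 x 256 x 12 volume at 1.5 x 1.5 x 10.0 mm becomes 320 x 320 x 12.
    print(resampled_geometry((256, 256, 12), (1.5, 1.5, 10.0)))
```

The resulting size and spacing could then be passed to a resampler such as SimpleITK's `Resample`, with an interpolator suited to the data (for example, nearest-neighbour interpolation for label maps).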

Training

The M&Ms dataset provides data from 4 vendors: vendors A and B are provided for training, while A, B, C and D are used for testing. The '--domain' option controls which vendors are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. For testing, all 4 vendors are always used.
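The vendor selection described here can be sketched as follows. This is an illustration only: `training_vendors` and `testing_vendors` are hypothetical helpers, and the actual parsing in train_deep.py may differ.

```python
ALL_VENDORS = ("A", "B", "C", "D")
TRAIN_VENDORS = ("A", "B")  # only vendors A and B are provided for training


def training_vendors(domain):
    """Map a --domain value such as 'A', 'B' or 'AB' to a list of vendors."""
    vendors = sorted(set(domain.upper()))
    unknown = [v for v in vendors if v not in TRAIN_VENDORS]
    if not vendors or unknown:
        raise ValueError(f"--domain must combine A and/or B, got {domain!r}")
    return vendors


def testing_vendors():
    """Testing always uses all four vendors, regardless of --domain."""
    return list(ALL_VENDORS)


if __name__ == "__main__":
    print(training_vendors("AB"), testing_vendors())
```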

UTNet

To train UTNet with the default settings, run:

python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss

Alternatively, use '-m UTNet_encoder' to apply transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.

To optimize UTNet in your own task, there are several hyperparameters to tune:

'--block_list': the resolution levels at which transformer blocks are applied. Each number is a number of downsamplings, e.g. 3,4 applies transformer blocks to the features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps requires much more computation.

'--num_blocks': the number of transformer blocks applied at each level, e.g. block_list='3,4' with num_blocks=2,4 applies 2 transformer blocks at the 3-times downsampled level and 4 transformer blocks at the 4-times downsampled level.

'--reduce_size': the size of the downsampling used for efficient attention. In our experiments, reduce_size 8 and 16 perform similarly, but 16 requires more computation, so we chose 8 as our default setting. 16 might perform better in other applications.

'--aux_loss': applies deep supervision during training. It adds some computational overhead but gives slightly better performance.
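The pairing between '--block_list' and '--num_blocks' can be sketched as below. `pair_blocks` is a hypothetical helper, not the repo's parser. It assumes single-digit levels and accepts both the compact form used in the commands ('1234') and the comma-separated form used in the description ('3,4').

```python
def pair_blocks(block_list, num_blocks):
    """Return {downsampling level: number of transformer blocks}."""
    # Assumption: every level is a single digit, so '1234' and '1,2,3,4'
    # describe the same four levels.
    levels = [int(c) for c in block_list.replace(",", "")]
    counts = [int(c) for c in num_blocks.split(",") if c]
    if len(levels) != len(counts):
        raise ValueError("block_list and num_blocks must have the same length")
    return dict(zip(levels, counts))


if __name__ == "__main__":
    print(pair_blocks("1234", "1,1,1,1"))  # one block at every level
    print(pair_blocks("3,4", "2,4"))       # 2 blocks at level 3, 4 at level 4
```

An empty block list paired with an empty num_blocks string yields an empty mapping, meaning no transformer blocks at any level.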

Here are some recommended parameter settings:

--block_list 1234 --num_blocks 1,1,1,1

Our default setting and the most efficient one. Suitable for tasks with limited training data where most errors occur at the boundary of the ROI, so high-resolution information is important.

--block_list 1234 --num_blocks 1,1,4,8

Similar to the previous one. The model capacity is larger because more transformer blocks are included, but it needs a larger dataset for training.

--block_list 234 --num_blocks 2,4,8

Suitable for tasks that have complex contexts and where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

Feel free to try other combinations of hyperparameters, such as base_chan, reduce_size and num_blocks at each level, to trade off capacity against efficiency for your own tasks and datasets.
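For a rough feel of the efficiency side of this trade-off, the sketch below counts the entries in a single attention map. It assumes that queries keep the full feature-map resolution while keys and values are downsampled to reduce_size x reduce_size tokens, and that each downsampling halves the feature-map side. The 256-pixel input size and both function names are illustrative assumptions, not values taken from this repo.

```python
def attention_entries(height, width, reduce_size=None):
    """Count the entries of one attention map for an H x W feature map.

    With reduce_size=None every token attends to every token (n * n).
    Otherwise keys/values are assumed to be reduced to reduce_size**2 tokens.
    """
    n = height * width
    k = n if reduce_size is None else reduce_size * reduce_size
    return n * k


def level_cost(input_size, level, reduce_size):
    """Return (full, reduced) attention-map sizes after `level` downsamplings."""
    side = input_size // (2 ** level)
    return attention_entries(side, side), attention_entries(side, side, reduce_size)


if __name__ == "__main__":
    for level in (1, 2, 3, 4):
        full, reduced = level_cost(256, level, reduce_size=8)
        print(f"level {level}: full={full:,} reduced={reduced:,}")
```

Under these assumptions, full attention at level 1 of a 256 x 256 input needs 268,435,456 entries per map, against 1,048,576 with reduce_size 8. Going from reduce_size 8 to 16 multiplies the reduced count by four.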

TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weights, please download them from the original repo. The configuration is not parsed from the command line, so to change the TransUNet configuration you need to edit it inside train_deep.py.

python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0

ResNet50-UTNet

For a fair comparison with TransUNet, we add the efficient attention proposed in UTNet to a ResNet50 backbone, which basically appends transformer blocks to the specified levels after the ResNet blocks. ResNet50-UTNet performs slightly better than the default UTNet on the M&Ms dataset.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

--block_list 23 --num_blocks 2,4

Suitable for tasks that have complex contexts and where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

ResNet50-UNet

If you don't use Transformer blocks in ResNet50-UTNet, it is effectively ResNet50-UNet. You can use it as a baseline to measure the improvement from the Transformer blocks, for a fair comparison with TransUNet and our UTNet.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list '' --gpu 0

SwinUNet

Download the pre-trained model from the original repo. Because Swin-Transformer's input size is tied to the window size and is hard to change after pre-training, we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.

python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224
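To illustrate why the crop size is tied to the window size, the sketch below checks whether an input size divides evenly into windows at every stage. The defaults (patch size 4, window size 7, 4 stages with resolution halved between stages) are assumptions based on the commonly used Swin-Tiny configuration and are not read from this repo. `swin_compatible` is a hypothetical helper.

```python
def swin_compatible(input_size, patch_size=4, window_size=7, num_stages=4):
    """Check that the token grid splits into whole windows at every stage."""
    if input_size % patch_size:
        return False
    side = input_size // patch_size  # token grid side after patch embedding
    for stage in range(num_stages):
        if side % window_size:
            return False
        if stage < num_stages - 1:
            # Patch merging halves the grid between stages (assumed).
            if side % 2:
                return False
            side //= 2
    return True


if __name__ == "__main__":
    print(swin_compatible(224))  # grid sides 56, 28, 14, 7 all divide by 7
    print(swin_compatible(256))  # grid side 64 does not divide by 7
```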

Citation

If you find this repo helpful, please cite our paper. Thanks!

@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}