Accepted at ICCV-2021: Workshop on Computer Vision for Automated Medical Diagnosis (CVAMD)

Overview

Is it Time to Replace CNNs with Transformers for Medical Images?

Convolutional Neural Networks (CNNs) have reigned for a decade as the de facto approach to automated medical image diagnosis. Recently, vision transformers (ViTs) have appeared as a competitive alternative to CNNs, yielding similar levels of performance while possessing several interesting properties that could prove beneficial for medical imaging tasks. In this work, we explore whether it is time to move to transformer-based models or if we should keep working with CNNs - can we trivially switch to transformers? If so, what are the advantages and drawbacks of switching to ViTs for medical image diagnosis? We consider these questions in a series of experiments on three mainstream medical image datasets. Our findings show that, while CNNs perform better when trained from scratch, off-the-shelf vision transformers using default hyperparameters are on par with CNNs when pretrained on ImageNet, and outperform their CNN counterparts when pretrained using self-supervision.

Environment setup

To build the Docker image from the Dockerfile, run the following command:
docker build -f Dockerfile -t med_trans \
--build-arg UID=$(id -u) \
--build-arg GID=$(id -g) \
--build-arg USER=$(whoami) \
--build-arg GROUP=$(id -g -n) .

Usage:

  • Training: python classification.py
  • Training with DINO: python classification.py --dino
  • Testing (using json file): python classification.py --test
  • Testing (using saved checkpoint): python classification.py --checkpoint CheckpointName --test
  • Fine tune the learning rate: python classification.py --lr_finder
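As a rough illustration of what a learning rate search over powers of 10 can look like, the sketch below generates candidate learning rates from a minimum and maximum exponent, a resolution, and a random flag, mirroring the lr_finder settings described in the configuration section. The function name `candidate_learning_rates`, the `max_pow` and `seed` arguments, and the log-uniform spacing are assumptions made for this example; the repository's own search may differ.

```python
import random


def candidate_learning_rates(min_pow, max_pow, resolution, random_lr=False, seed=0):
    """Return `resolution` learning rates between 10**min_pow and 10**max_pow.

    Hypothetical helper: log-spaced grid by default, log-uniform random
    samples when `random_lr` is True.
    """
    if resolution < 1:
        raise ValueError("resolution must be at least 1")
    if min_pow > max_pow:
        raise ValueError("min_pow must not exceed max_pow")

    if random_lr:
        # Sample exponents uniformly, so rates are uniform on a log scale.
        rng = random.Random(seed)
        return sorted(10.0 ** rng.uniform(min_pow, max_pow) for _ in range(resolution))

    if resolution == 1:
        return [10.0 ** min_pow]

    # Evenly spaced exponents, including both endpoints.
    step = (max_pow - min_pow) / (resolution - 1)
    return [10.0 ** (min_pow + i * step) for i in range(resolution)]


if __name__ == "__main__":
    for lr in candidate_learning_rates(min_pow=-5, max_pow=-1, resolution=5):
        print(f"{lr:.0e}")
```

Spacing the candidates on a log scale is a common choice because useful learning rates typically span several orders of magnitude.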

Configuration (json file)

  • dataset_params
    • dataset: Name of the dataset (ISIC2019, APTOS2019, DDSM)
    • data_location: Directory where the datasets are located
    • train_transforms: Defines the augmentations for the training set
    • val_transforms: Defines the augmentations for the validation set
    • test_transforms: Defines the augmentations for the test set
  • dataloader_params: Defines the dataloader parameters (batch size, num_workers etc)
  • model_params
    • backbone_type: type of the backbone model (e.g. resnet50, deit_small)
    • transformers_params: Additional hyperparameters for the transformers
      • img_size: The size of the input images
      • patch_size: The patch size to use for patching the input
      • pretrained_type: If supervised, it loads ImageNet weights obtained with supervised learning. If dino, it loads ImageNet weights obtained with self-supervised learning using DINO.
    • pretrained: If True, it uses ImageNet pretrained weights
    • freeze_backbone: If True, it freezes the backbone network
    • DINO: Controls the hyperparameters used when training with DINO
  • optimization_params: Defines learning rate, weight decay, learning rate schedule etc.
    • optimizer: The default optimizer's parameters
      • type: The optimizer's type
      • autoscale_rl: If True, it scales the learning rate based on the batch size
      • params: Defines the learning rate and the weight decay value
    • LARS_params: If use=True and batch size >= batch_act_thresh, it uses LARS as the optimizer
    • scheduler: Defines the learning rate schedule
      • type: A list of schedulers to use
      • params: Sets the hyperparameters of the schedulers
  • training_params: Defines the training parameters
    • model_name: The model's name
    • val_every: Sets the frequency of the validation step (epochs - float)
    • log_every: Sets the frequency of the logging (iterations - int)
    • save_best_model: If True, it will save the best model based on the validation metrics
    • log_embeddings: If True, it creates UMAP plots of the embeddings on each validation step
    • knn_eval: If True, during validation it will also calculate the scores based on kNN evaluation
    • grad_clipping: If > 0, it clips the gradients
    • use_tensorboard: If True, it will use tensorboard for logging instead of wandb
    • use_mixed_precision: If True, it will use mixed precision
    • save_dir: The dir to save the model's checkpoints etc.
  • system_params: Defines if GPUs are used, which GPUs etc.
  • log_params: Project and run name for the logger (we are using Weights & Biases by default)
  • lr_finder: Define the learning rate parameters
    • grid_search_params
      • min_pow, max_pow: The min and max power of 10 for the search
      • resolution: How many different learning rates to try
      • n_epochs: maximum epochs of the training session
      • random_lr: If True, it uses random learning rates within the accepted range
      • keep_schedule: If True, it keeps the learning rate schedule
      • report_intermediate_steps: If True, it logs and validates throughout the training sessions
  • transfer_learning_params: Turns on or off transfer learning from pretrained models
    • use_pretrained: If True, it will use a pretrained model as a backbone
    • pretrained_model_name: The pretrained model's name
    • pretrained_path: The pretrained model's directory
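
To make the layout above more concrete, here is a minimal sketch that builds a partial configuration as a Python dictionary, serializes it to JSON, and reports which top-level sections are still missing. The section and field names come from the list above; every value, the inner keys `batch_size`, `lr` and `weight_decay`, the optimizer type "Adam", the placeholder path, and the helper `missing_sections` are assumptions for illustration only and may not match the file the repository actually expects.

```python
import json

# Top-level sections named in the configuration list.
TOP_LEVEL_SECTIONS = [
    "dataset_params", "dataloader_params", "model_params",
    "optimization_params", "training_params", "system_params",
    "log_params", "lr_finder", "transfer_learning_params",
]

# Illustrative values only; they are not the repository's defaults.
config = {
    "dataset_params": {
        "dataset": "ISIC2019",
        "data_location": "/path/to/datasets",  # placeholder path
    },
    "dataloader_params": {"batch_size": 64, "num_workers": 4},  # assumed keys
    "model_params": {
        "backbone_type": "deit_small",
        "transformers_params": {
            "img_size": 224,
            "patch_size": 16,
            "pretrained_type": "supervised",
        },
        "pretrained": True,
        "freeze_backbone": False,
    },
    "optimization_params": {
        "optimizer": {
            "type": "Adam",  # assumed optimizer name
            "params": {"lr": 1e-4, "weight_decay": 1e-5},  # assumed keys
        },
    },
    "training_params": {
        "model_name": "deit_small_isic",
        "val_every": 1.0,   # epochs (float)
        "log_every": 100,   # iterations (int)
        "save_best_model": True,
    },
}


def missing_sections(cfg):
    """Return the documented top-level sections absent from `cfg`."""
    return [name for name in TOP_LEVEL_SECTIONS if name not in cfg]


if __name__ == "__main__":
    print(json.dumps(config, indent=2))
    print("Sections still to fill in:", missing_sections(config))
```

A check like `missing_sections` is one cheap way to catch an incomplete file before starting a long training run.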
Owner
Christos Matsoukas
PhD student in Deep Learning @ KTH Royal Institute of Technology