Public repo for the ICCV2021-CVAMD paper "Is it Time to Replace CNNs with Transformers for Medical Images?"

Overview

Is it Time to Replace CNNs with Transformers for Medical Images?

Accepted at ICCV-2021: Workshop on Computer Vision for Automated Medical Diagnosis (CVAMD)

Convolutional Neural Networks (CNNs) have reigned for a decade as the de facto approach to automated medical image diagnosis. Recently, vision transformers (ViTs) have appeared as a competitive alternative to CNNs, yielding similar levels of performance while possessing several interesting properties that could prove beneficial for medical imaging tasks. In this work, we explore whether it is time to move to transformer-based models or if we should keep working with CNNs - can we trivially switch to transformers? If so, what are the advantages and drawbacks of switching to ViTs for medical image diagnosis? We consider these questions in a series of experiments on three mainstream medical image datasets. Our findings show that, while CNNs perform better when trained from scratch, off-the-shelf vision transformers using default hyperparameters are on par with CNNs when pretrained on ImageNet, and outperform their CNN counterparts when pretrained using self-supervision.

Enviroment setup

To build using the docker file use the following command
docker build -f Dockerfile -t med_trans \
--build-arg UID=$(id -u) \
--build-arg GID=$(id -g) \
--build-arg USER=$(whoami) \
--build-arg GROUP=$(id -g -n) .

Usage:

  • Training: python classification.py
  • Training with DINO: python classification.py --dino
  • Testing (using json file): python classification.py --test
  • Testing (using saved checkpoint): python classification.py --checkpoint CheckpointName --test
  • Fine tune the learning rate: python classification.py --lr_finder

Configuration (json file)

  • dataset_params
    • dataset: Name of the dataset (ISIC2019, APTOS2019, DDSM)
    • data_location: Location that the datasets are located
    • train_transforms: Defines the augmentations for the training set
    • val_transforms: Defines the augmentations for the validation set
    • test_transforms: Defines the augmentations for the test set
  • dataloader_params: Defines the dataloader parameters (batch size, num_workers etc)
  • model_params
    • backbone_type: type of the backbone model (e.g. resnet50, deit_small)
    • transformers_params: Additional hyperparameters for the transformers
      • img_size: The size of the input images
      • patch_size: The patch size to use for patching the input
      • pretrained_type: If supervised it loads ImageNet weights that come from supervised learning. If dino it loads ImageNet weights that come from sefl-supervised learning with DINO.
    • pretrained: If True, it uses ImageNet pretrained weights
    • freeze_backbone: If True, it freezes the backbone network
    • DINO: It controls the hyperparameters for when training with DINO
  • optimization_params: Defines learning rate, weight decay, learning rate schedule etc.
    • optimizer: The default optimizer's parameters
      • type: The optimizer's type
      • autoscale_rl: If True it scales the learning rate based on the bach size
      • params: Defines the learning rate and the weght decay value
    • LARS_params: If use=True and bach size >= batch_act_thresh it uses LARS as optimizer
    • scheduler: Defines the learning rate schedule
      • type: A list of schedulers to use
      • params: Sets the hyperparameters of the optimizers
  • training_params: Defines the training parameters
    • model_name: The model's name
    • val_every: Sets the frequency of the valiidation step (epochs - float)
    • log_every: Sets the frequency of the logging (iterations - int)
    • save_best_model: If True it will save the bast model based on the validation metrics
    • log_embeddings: If True it creates U-maps on each validation step
    • knn_eval: If True, during validation it will also calculate the scores based on knn evalutation
    • grad_clipping: If > 0, it clips the gradients
    • use_tensorboard: If True, it will use tensorboard for logging instead of wandb
    • use_mixed_precision: If True, it will use mixed precision
    • save_dir: The dir to save the model's checkpoints etc.
  • system_params: Defines if GPUs are used, which GPUs etc.
  • log_params: Project and run name for the logger (we are using Weights & Biases by default)
  • lr_finder: Define the learning rate parameters
    • grid_search_params
      • min_pow, min_pow: The min and max power of 10 for the search
      • resolution: How many different learning rates to try
      • n_epochs: maximum epochs of the training session
      • random_lr: If True, it uses random learning rates withing the accepted range
      • keep_schedule: If True, it keeps the learning rate schedule
      • report_intermediate_steps: If True, it logs if validates throughout the training sessions
  • transfer_learning_params: Turns on or off transfer learning from pretrained models
    • use_pretrained: If True, it will use a pretrained model as a backbone
    • pretrained_model_name: The pretrained model's name
    • pretrained_path: If the prerained model's dir
Owner
Christos Matsoukas
PhD student in Deep Learning @ KTH Royal Institute of Technology
Christos Matsoukas
🤖 A Python library for learning and evaluating knowledge graph embeddings

PyKEEN PyKEEN (Python KnowlEdge EmbeddiNgs) is a Python package designed to train and evaluate knowledge graph embedding models (incorporating multi-m

PyKEEN 1.1k Jan 09, 2023
Codes for our paper "SentiLARE: Sentiment-Aware Language Representation Learning with Linguistic Knowledge" (EMNLP 2020)

SentiLARE: Sentiment-Aware Language Representation Learning with Linguistic Knowledge Introduction SentiLARE is a sentiment-aware pre-trained language

74 Dec 30, 2022
Algorithm to texture 3D reconstructions from multi-view stereo images

MVS-Texturing Welcome to our project that textures 3D reconstructions from images. This project focuses on 3D reconstructions generated using structur

Nils Moehrle 766 Jan 04, 2023
A Robust Unsupervised Ensemble of Feature-Based Explanations using Restricted Boltzmann Machines

A Robust Unsupervised Ensemble of Feature-Based Explanations using Restricted Boltzmann Machines Understanding the results of deep neural networks is

Johan van den Heuvel 2 Dec 13, 2021
TensorFlow, PyTorch and Numpy layers for generating Orthogonal Polynomials

OrthNet TensorFlow, PyTorch and Numpy layers for generating multi-dimensional Orthogonal Polynomials 1. Installation 2. Usage 3. Polynomials 4. Base C

Chuan 29 May 25, 2022
Taking A Closer Look at Domain Shift: Category-level Adversaries for Semantics Consistent Domain Adaptation

Taking A Closer Look at Domain Shift: Category-level Adversaries for Semantics Consistent Domain Adaptation (CVPR2019) This is a pytorch implementatio

Yawei Luo 280 Jan 01, 2023
Pairwise model for commonlit competition

Pairwise model for commonlit competition To run: - install requirements - create input directory with train_folds.csv and other competition data - cd

abhishek thakur 45 Aug 31, 2022
[Pedestron] Generalizable Pedestrian Detection: The Elephant In The Room. @ CVPR2021

Pedestron Pedestron is a MMdetection based repository, that focuses on the advancement of research on pedestrian detection. We provide a list of detec

Irtiza Hasan 594 Jan 05, 2023
Implementation of gaze tracking and demo

Predicting Customer Demand by Using Gaze Detecting and Object Tracking This project is the integration of gaze detecting and object tracking. Predict

2 Oct 20, 2022
Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Bayes-Newton Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Wil

AaltoML 165 Nov 27, 2022
Official PyTorch implementation of the paper "Deep Constrained Least Squares for Blind Image Super-Resolution", CVPR 2022.

Deep Constrained Least Squares for Blind Image Super-Resolution [Paper] This is the official implementation of 'Deep Constrained Least Squares for Bli

MEGVII Research 141 Dec 30, 2022
[CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment

RADN [CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment [Paper on arXiv] Overview Update [2021/5/7] add codes for W

IIGROUP 53 Dec 28, 2022
ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin et al., 2020).

ReConsider ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin

Facebook Research 47 Jul 26, 2022
Context Axial Reverse Attention Network for Small Medical Objects Segmentation

CaraNet: Context Axial Reverse Attention Network for Small Medical Objects Segmentation This repository contains the implementation of a novel attenti

401 Dec 23, 2022
Python inverse kinematics for your robot model based on Pinocchio.

Python inverse kinematics for your robot model based on Pinocchio.

Stéphane Caron 50 Dec 22, 2022
Morphable Detector for Object Detection on Demand

Morphable Detector for Object Detection on Demand (ICCV 2021) PyTorch implementation of the paper Morphable Detector for Object Detection on Demand. I

9 Feb 23, 2022
Semi-Supervised 3D Hand-Object Poses Estimation with Interactions in Time

Semi Hand-Object Semi-Supervised 3D Hand-Object Poses Estimation with Interactions in Time (CVPR 2021).

96 Dec 27, 2022
Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment (ICCV2021)

Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment This is a pytorch project for the paper Seeing Dynamic Scene i

DV Lab 21 Nov 28, 2022
CAPITAL: Optimal Subgroup Identification via Constrained Policy Tree Search

CAPITAL: Optimal Subgroup Identification via Constrained Policy Tree Search This repository is the official implementation of CAPITAL: Optimal Subgrou

Hengrui Cai 0 Oct 19, 2021
ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation

ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation This repository provides a PyTorch implementation of ADSPM. Requirements Pyth

24 Jul 24, 2022