PyTorch implementation of: Michieli U. and Zanuttigh P., "Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations", CVPR 2021.

Related tags

Deep LearningSDR
Overview

Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations

This is the official PyTorch implementation of our work: "Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations" published at CVPR 2021.

In this paper, we present some novel approaches constraining the feature space for continual learning semantic segmentation models. The evaluation on Pascal VOC2012 and on ADE20K validated our method.

Paper
5-min video
slides
poster
teaser

Requirements

This repository uses the following libraries:

  • Python (3.7.6)
  • Pytorch (1.4.0) [tested up to 1.7.1]
  • torchvision (0.5.0)
  • tensorboardX (2.0)
  • matplotlib (3.1.1)
  • numpy (1.18.1)
  • apex (0.1) [optional]
  • inplace-abn (1.0.7) [optional]

We also assume to have installed pytorch.distributed package.

All the dependencies are listed in the requirements.txt file which can be used in conda as:
conda create --name <env> --file requirements.txt

How to download data

In this project we use two dataset, ADE20K and Pascal-VOC 2012. We provide the scripts to download them in 'data/download_<dataset_name>.sh'. The script takes no inputs and you should use it in the target directory (where you want to download data).

How to perform training

The most important file is run.py, that is in charge to start the training or test procedure. To run it, simpy use the following command:

python -m torch.distributed.launch --nproc_per_node=<num_GPUs> run.py --data_root <data_folder> --name <exp_name> .. other args ..

The default is to use a pretraining for the backbone used, which is the one officially released by PyTorch models and it will be downloaded automatically. If you don't want to use pretrained, please use --no-pretrained.

There are many options and you can see them all by using --help option. Some of them are discussed in the following:

  • please specify the data folder using: --data_root <data_root>
  • dataset: --dataset voc (Pascal-VOC 2012) | ade (ADE20K)
  • task: --task <task>, where tasks are
    • 15-5, 15-5s, 19-1 (VOC), 100-50, 100-10, 50, 100-50b, 100-10b, 50b (ADE, b indicates the order)
  • step (each step is run separately): --step <N>, where N is the step number, starting from 0
  • (only for Pascal-VOC) disjoint is default setup, to enable overlapped: --overlapped
  • learning rate: --lr 0.01 (for step 0) | 0.001 (for step > 0)
  • batch size: --batch_size 8 (Pascal-VOC 2012) | 4 (ADE20K)
  • epochs: --epochs 30 (Pascal-VOC 2012) | 60 (ADE20K)
  • method: --method <method name>, where names are
    • FT, LWF, LWF-MC, ILT, EWC, RW, PI, MIB, CIL, SDR
      Note that method overwrites other parameters, but can be used as a kickstart to use default parameters for each method (see more on this in the hyperparameters section below)

For all the details please follow the information provided using the help option.

Example training commands

We provide some example scripts in the *.slurm and *.bat files.
For instance, to run the step 0 of 19-1 VOC2012 you can run:

python -u -m torch.distributed.launch 1> 'outputs/19-1/output_19-1_step0.txt' 2>&1 \
--nproc_per_node=1 run.py \
--batch_size 8 \
--logdir logs/19-1/ \
--dataset voc \
--name FT \
--task 19-1 \
--step 0 \
--lr 0.001 \
--epochs 30 \
--debug \
--sample_num 10 \
--unce \
--loss_de_prototypes 1 \
--where_to_sim GPU_windows

Note: loss_de_prototypes is set to 1 only for having the prototypes computed in the 0-th step (no distillation is actually computed of course).

Then, the step 1 of the same scenario can be computed simply as:

python -u -m torch.distributed.launch 1> 'outputs/19-1/output_19-1_step1.txt'  2>&1 \
--nproc_per_node=1 run.py \
--batch_size 8 \
--logdir logs/19-1/ \
--dataset voc \
--task 19-1 \
--step 1 \
--lr 0.0001 \
--epochs 30 \
--debug \
--sample_num 10 \
--where_to_sim GPU_windows \
--method SDR \
--step_ckpt 'logs/19-1/19-1-voc_FT/19-1-voc_FT_0.pth'

The results obtained are reported inside the outputs/ and logs/ folder, which can be downloaded here, and are 0.4% of mIoU higher than those reported in the main paper due to a slightly changed hyperparameter.

To run other approaches it is sufficient to change the --method parameter into one of the following: FT, LWF, LWF-MC, ILT, EWC, RW, PI, MIB, CIL, SDR.

Note: for the best results, the hyperparameters may change. Please see further details on the hyperparameters section below.

Once you trained the model, you can see the result on tensorboard (we perform the test after the whole training) or on the output files. or you can test it by using the same script and parameters but using the option

--test

that will skip all the training procedure and test the model on test data.

Do you want to try our constraints on your codebase or task?

If you want to try our novel constraints on your codebase or on a different problem you can check the utils/loss.py file. Here, you can take the definitions of the different losses and embed them into your codebase
The names of the variables could be interpreted as:

  • targets-- ground truth map,
  • outputs-- segmentation map output from the current network
  • outputs_old-- segmentation map output from the previous network
  • features-- features taken from the end of the currently-trained encoder,
  • features_old-- features taken from the end of the previous encoder [used for distillation on the encoder on ILT, but not used on SDR],
  • prototypes-- prototypical feature representations
  • incremental_step -- index of the current incremental step (0 if first non-incremental training is performed)
  • classes_old-- index of previous classes

Range for the Hyper-parameters

For what concerns the hyperparameters of our approach:

  • The parameter for the distillation loss is in the same range of that of MiB,
  • Prototypes matching: lambda was searched in range 1e-1 to 1e-3,
  • Contrastive learning (or clustering): lambda was searched in the range of 1e-2 to 1e-3,
  • Features sparsification: lambda was searched in the range of 1e-3 to 1e-5 A kick-start could be to use KD 10, PM 1e-2, CL 1e-3 and FS 1e-4.
    The best parameters may vary across datasets and incremental setup. However, we typically did a grid search and kept it fixed across learning steps.

So, writing explicitly all the parameters, the command would look something like the following:

python -u -m torch.distributed.launch 1> 'outputs/19-1/output_19-1_step1_custom.txt'  2>&1 \
--nproc_per_node=1 run.py \
--batch_size 8 \
--logdir logs/19-1/ \
--dataset voc \
--task 19-1 \
--step 1 \
--lr 0.0001 \
--epochs 30 \
--debug \
--sample_num 10 \
--where_to_sim GPU_windows \
--unce \
--loss_featspars $loss_featspars \
--lfs_normalization $lfs_normalization \
--lfs_shrinkingfn $lfs_shrinkingfn \
--lfs_loss_fn_touse $lfs_loss_fn_touse \
--loss_de_prototypes $loss_de_prototypes \
--loss_de_prototypes_sumafter \
--lfc_sep_clust $lfc_sep_clust \
--loss_fc $loss_fc \
--loss_kd $loss_kd \
--step_ckpt 'logs/19-1/19-1-voc_FT/19-1-voc_FT_0.pth'

Cite us

If you use this repository, please consider to cite

   @inProceedings{michieli2021continual,
   author = {Michieli, Umberto and Zanuttigh, Pietro},
   title  = {Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations},
   booktitle = {Computer Vision and Pattern Recognition (CVPR)},
   year      = {2021},
   month     = {June}
   }

And our previous works ILT and its journal extension.

Acknowledgements

We gratefully acknowledge the authors of MiB paper for the insightful discussion and for providing the open source codebase, which has been the starting point for our work.
We also acknowledge the authors of CIL for providing their code even before the official release.

Owner
Multimedia Technology and Telecommunication Lab
Department of Information Engineering, University of Padova
Multimedia Technology and Telecommunication Lab
Code for "OctField: Hierarchical Implicit Functions for 3D Modeling (NeurIPS 2021)"

OctField(Jittor): Hierarchical Implicit Functions for 3D Modeling Introduction This repository is code release for OctField: Hierarchical Implicit Fun

55 Dec 08, 2022
Code accompanying the paper "ProxyFL: Decentralized Federated Learning through Proxy Model Sharing"

ProxyFL Code accompanying the paper "ProxyFL: Decentralized Federated Learning through Proxy Model Sharing" Authors: Shivam Kalra*, Junfeng Wen*, Jess

Layer6 Labs 14 Dec 06, 2022
The Multi-Mission Maximum Likelihood framework (3ML)

PyPi Conda The Multi-Mission Maximum Likelihood framework (3ML) A framework for multi-wavelength/multi-messenger analysis for astronomy/astrophysics.

The Multi-Mission Maximum Likelihood (3ML) 62 Dec 30, 2022
Memory-Augmented Model Predictive Control

Memory-Augmented Model Predictive Control This repository hosts the source code for the journal article "Composing MPC with LQR and Neural Networks fo

Fangyu Wu 1 Jun 19, 2022
NeWT: Natural World Tasks

NeWT: Natural World Tasks This repository contains resources for working with the NeWT dataset. ❗ At this time the binary tasks are not publicly avail

Visipedia 26 Oct 18, 2022
Collections for the lasted paper about multi-view clustering methods (papers, codes)

Multi-View Clustering Papers Collections for the lasted paper about multi-view clustering methods (papers, codes). There also exists some repositories

Andrew Guan 10 Sep 20, 2022
Codes to pre-train T5 (Text-to-Text Transfer Transformer) models pre-trained on Japanese web texts

t5-japanese Codes to pre-train T5 (Text-to-Text Transfer Transformer) models pre-trained on Japanese web texts. The following is a list of models that

Kimio Kuramitsu 1 Dec 13, 2021
A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!)

EfficientNet PyTorch Quickstart Install with pip install efficientnet_pytorch and load a pretrained EfficientNet with: from efficientnet_pytorch impor

Luke Melas-Kyriazi 7.2k Jan 06, 2023
Colab notebook and additional materials for Python-driven analysis of redlining data in Philadelphia

RedliningExploration The Google Colaboratory file contained in this repository contains work inspired by a project on educational inequality in the Ph

Benjamin Warren 1 Jan 20, 2022
Official implementation of our paper "Learning to Bootstrap for Combating Label Noise"

Learning to Bootstrap for Combating Label Noise This repo is the official implementation of our paper "Learning to Bootstrap for Combating Label Noise

21 Apr 09, 2022
Unofficial implementation of Proxy Anchor Loss for Deep Metric Learning

Proxy Anchor Loss for Deep Metric Learning Unofficial pytorch, tensorflow and mxnet implementations of Proxy Anchor Loss for Deep Metric Learning. Not

Geonmo Gu 3 Jun 09, 2021
MultiTaskLearning - Multi Task Learning for 3D segmentation

Multi Task Learning for 3D segmentation Perception stack of an Autonomous Drivin

2 Sep 22, 2022
Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics.

Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics. By Andres Milioto @ University of Bonn. (for the new P

Photogrammetry & Robotics Bonn 314 Dec 30, 2022
A repo for Causal Imitation Learning under Temporally Correlated Noise

CausIL A repo for Causal Imitation Learning under Temporally Correlated Noise. Running Experiments To re-train an expert, run: python experts/train_ex

Gokul Swamy 5 Nov 01, 2022
a short visualisation script for pyvideo data

PyVideo Speakers A CLI that visualises repeat speakers from events listed in https://github.com/pyvideo/data Not terribly efficient, but you know. Ins

Katie McLaughlin 3 Nov 24, 2021
A quick recipe to learn all about Transformers

Transformers have accelerated the development of new techniques and models for natural language processing (NLP) tasks.

DAIR.AI 772 Dec 31, 2022
✔️ Visual, reactive testing library for Julia. Time machine included.

PlutoTest.jl (alpha release) Visual, reactive testing library for Julia A macro @test that you can use to verify your code's correctness. But instead

Pluto 68 Dec 20, 2022
State-to-Distribution (STD) Model

State-to-Distribution (STD) Model In this repository we provide exemplary code on how to construct and evaluate a state-to-distribution (STD) model fo

<a href=[email protected]"> 2 Apr 07, 2022
This is the offical website for paper ''Category-consistent deep network learning for accurate vehicle logo recognition''

The Pytorch Implementation of Category-consistent deep network learning for accurate vehicle logo recognition This is the offical website for paper ''

Wanglong Lu 28 Oct 29, 2022
A compendium of useful, interesting, inspirational usage of pandas functions, each example will be an ipynb file

Pandas_by_examples A compendium of useful/interesting/inspirational usage of pandas functions, each example will be an ipynb file What is this reposit

Guangyuan(Frank) Li 32 Nov 20, 2022