Cross-Task Consistency Learning Framework for Multi-Task Learning

Tested on

  • numpy(v1.19.1)
  • opencv-python(v4.4.0.42)
  • torch(v1.7.0)
  • torchvision(v0.8.0)
  • tqdm(v4.48.2)
  • matplotlib(v3.3.1)
  • seaborn(v0.11.0)
  • pandas(v1.1.2)

Data

Cityscapes (CS)

Download the Cityscapes dataset and place it in ./data/cityscapes. The folder should contain the following subfolders:

  • RGB images in folder leftImg8bit
  • Segmentation in folder gtFine
  • Disparity maps in folder disparity
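
Before training, it can help to confirm that the folder layout matches the list above. The following is a minimal sketch, not part of the repository: the helper name `missing_subfolders` is hypothetical, and it only checks that the three subfolders exist, not their contents.

```python
from pathlib import Path

# Subfolders the README lists for Cityscapes.
CITYSCAPES_SUBFOLDERS = ("leftImg8bit", "gtFine", "disparity")


def missing_subfolders(root, required=CITYSCAPES_SUBFOLDERS):
    """Return the required subfolders that do not exist under root."""
    root = Path(root)
    return [name for name in required if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_subfolders("./data/cityscapes")
    if missing:
        print("Missing Cityscapes subfolders:", ", ".join(missing))
    else:
        print("Cityscapes layout looks complete.")
```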

NYU

We use the preprocessed NYUv2 dataset provided by this repo. Download the dataset and place it in ./data/nyu.

Model

The model consists of one encoder (ResNet) and two decoders, one for each task. The decoders output the predictions for each task ("direct predictions"), which are fed to the TaskTransferNet.
The objective of the TaskTransferNet is to predict the other task given a prediction image as input (segmentation prediction -> depth prediction, and vice versa). I refer to these outputs as "transferred predictions".
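
The data flow described above can be sketched without any framework. This is only an illustration under assumptions: the six callables are hypothetical stand-ins for the repository's modules, and modelling the TaskTransferNet as two separate mappings (one per direction) is a guess rather than something the README states.

```python
def forward_pass(image, encoder, seg_decoder, depth_decoder,
                 seg_to_depth, depth_to_seg):
    """Return direct and transferred predictions for one input.

    Every callable here is a hypothetical placeholder, not the repository's API.
    """
    features = encoder(image)               # shared encoder (ResNet in the README)
    direct_seg = seg_decoder(features)      # direct segmentation prediction
    direct_depth = depth_decoder(features)  # direct depth prediction
    # Task transfer: predict each task from the other task's direct prediction.
    transferred_depth = seg_to_depth(direct_seg)
    transferred_seg = depth_to_seg(direct_depth)
    return {
        "direct_seg": direct_seg,
        "direct_depth": direct_depth,
        "transferred_seg": transferred_seg,
        "transferred_depth": transferred_depth,
    }


if __name__ == "__main__":
    # Toy callables on plain numbers, only to show which output feeds which input.
    out = forward_pass(
        2,
        encoder=lambda x: x + 1,
        seg_decoder=lambda f: f * 10,
        depth_decoder=lambda f: f * 100,
        seg_to_depth=lambda s: -s,
        depth_to_seg=lambda d: -d,
    )
    print(out)
```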

Loss function

When computing the losses, the direct predictions are compared with the target while the transferred predictions are compared with the direct predictions so that they "align themselves".
The total loss consists of 4 different losses:

  • direct segmentation loss: CrossEntropyLoss()
  • direct depth loss: L1() or MSE() or logL1() or SmoothL1()
  • transferred segmentation loss: CrossEntropyLoss() or KLDivergence()
  • transferred depth loss: L1() or SSIM()
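
One plausible way to combine the four terms is a weighted sum in which the `alpha` and `gamma` flags scale the transferred depth and transferred segmentation losses. The sketch below assumes that form; the README does not spell out the exact expression, and the weighting would presumably differ when `uncertainty_weights` or `gradnorm` is enabled.

```python
def total_loss(direct_seg, direct_depth, transferred_seg, transferred_depth,
               alpha=0.01, gamma=0.01):
    """Weighted sum of the four loss terms (assumed form, not confirmed by the README).

    alpha scales the transferred depth loss,
    gamma scales the transferred segmentation loss.
    """
    return (direct_seg
            + direct_depth
            + gamma * transferred_seg
            + alpha * transferred_depth)


if __name__ == "__main__":
    # Made-up scalar loss values, using the Cityscapes default weights of 0.01.
    print(total_loss(0.5, 0.2, 1.0, 2.0))
```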

* Label smoothing: "smooths" the one-hot target distribution by taking some of the probability from the correct class and distributing it among the other classes.
* SSIM: Structural Similarity loss
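
As an illustration of label smoothing as described in the note above, the sketch below moves a fraction `smoothing` of the probability from the correct class and spreads it evenly over the remaining classes. The function name and the even split are assumptions; another common convention spreads the mass uniformly over all classes, including the correct one.

```python
def smooth_one_hot(target_class, num_classes, smoothing=0.0):
    """Return a smoothed target distribution as a list of probabilities."""
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2")
    if not 0.0 <= smoothing < 1.0:
        raise ValueError("smoothing must be in [0, 1)")
    # Each incorrect class receives an equal share of the removed probability.
    off_value = smoothing / (num_classes - 1)
    probs = [off_value] * num_classes
    probs[target_class] = 1.0 - smoothing
    return probs


if __name__ == "__main__":
    # With smoothing=0.0 (the default flag value) this is a plain one-hot vector.
    print(smooth_one_hot(1, 4, smoothing=0.0))
    print(smooth_one_hot(1, 4, smoothing=0.1))
```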

Flags

The flags are the same for both datasets. The flags and their usage are listed below.

| Flag Name | Usage | Comments |
|---|---|---|
| input_path | Path to dataset | default is data/cityscapes (CS) or data/nyu (NYU) |
| height | height of prediction | default: 128 (CS) or 288 (NYU) |
| width | width of prediction | default: 256 (CS) or 384 (NYU) |
| epochs | # of epochs | default: 250 (CS) or 100 (NYU) |
| enc_layers | which encoder to use | default: 34, can choose from 18, 34, 50, 101, 152 |
| use_pretrain | toggle on to use pretrained encoder weights | available for both datasets |
| batch_size | batch size | default: 8 (CS) or 6 (NYU) |
| scheduler_step_size | step size for scheduler | default: 80 (CS) or 60 (NYU), note that we use StepLR |
| scheduler_gamma | decay rate of scheduler | default: 0.5 |
| alpha | weight of adding transferred depth loss | default: 0.01 (CS) or 0.0001 (NYU) |
| gamma | weight of adding transferred segmentation loss | default: 0.01 (CS) or 0.0001 (NYU) |
| label_smoothing | amount of label smoothing | default: 0.0 |
| lp | loss fn for direct depth loss | default: L1, can choose from L1, MSE, logL1, smoothL1 |
| tdep_loss | loss fn for transferred depth loss | default: L1, can choose from L1 or SSIM |
| tseg_loss | loss fn for transferred segmentation loss | default: cross, can choose from cross or kl |
| batch_norm | toggle to enable batch normalization layer in TaskTransferNet | slightly improves segmentation task |
| wider_ttnet | toggle to double the # of channels in TaskTransferNet | |
| uncertainty_weights | toggle to use uncertainty weights (Kendall, et al. 2018) | we used this for best results |
| gradnorm | toggle to use GradNorm (Chen, et al. 2018) | |
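
To make the flag semantics concrete, here is a hedged sketch of how a subset of these flags could be declared with `argparse`, using the Cityscapes defaults listed above. The types, the `choices` lists, and the function name `build_parser` are assumptions; the repository's actual parser may be written differently.

```python
import argparse


def build_parser():
    """Build a parser for a subset of the flags, with Cityscapes defaults."""
    parser = argparse.ArgumentParser(description="Illustrative subset of the training flags")
    parser.add_argument("--input_path", type=str, default="data/cityscapes")
    parser.add_argument("--height", type=int, default=128)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=250)
    parser.add_argument("--enc_layers", type=int, default=34,
                        choices=[18, 34, 50, 101, 152])
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--alpha", type=float, default=0.01)
    parser.add_argument("--gamma", type=float, default=0.01)
    parser.add_argument("--lp", type=str, default="L1",
                        choices=["L1", "MSE", "logL1", "smoothL1"])
    # Toggles are off unless the flag is passed on the command line.
    parser.add_argument("--uncertainty_weights", action="store_true")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--uncertainty_weights", "--enc_layers", "50"])
    print(args.enc_layers, args.uncertainty_weights, args.epochs)
```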

Training

Cityscapes

For the Cityscapes dataset, there are two versions of the segmentation task: a 7-class task and a 19-class task (use the flag 'num_classes' to switch between them; the default is 7).
So far, the results are near state-of-the-art (SOTA) for the 7-class segmentation task combined with depth estimation.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.01.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 250 epochs with batch size 8. The learning rate was halved every 80 epochs.
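
The learning-rate schedule described above (start at 0.0001, halve every 80 epochs) follows the usual step-decay rule. The sketch below computes it in plain Python so the values per epoch are easy to inspect; the function name `step_lr` is hypothetical, and it mirrors the StepLR rule rather than calling the scheduler itself.

```python
def step_lr(epoch, initial_lr=1e-4, step_size=80, gamma=0.5):
    """Learning rate at a given epoch under step decay."""
    return initial_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Over 250 epochs the rate is halved at epochs 80, 160, and 240.
    for epoch in (0, 79, 80, 160, 240, 249):
        print(epoch, step_lr(epoch))
```

For the NYU defaults the same function would be called with `step_size=60`.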

To reproduce the results, run:

python main_cross_cs.py --uncertainty_weights

NYU

Our results are state-of-the-art (SOTA) on the NYU dataset.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.0001.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 100 epochs with batch size 6. The learning rate was halved every 60 epochs.

To reproduce the results, run:

python main_cross_nyu.py --uncertainty_weights

Comparisons

Evaluation metrics are the following:

Segmentation

  • Pixel accuracy (Pix Acc): percentage of pixels with the correct label
  • mIoU: mean Intersection over Union
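
A minimal sketch of the two segmentation metrics on flattened label lists follows. It is not the repository's evaluation code: ignore labels are not handled, and skipping classes that appear in neither prediction nor target when averaging is an assumption.

```python
def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted label equals the target label."""
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)


def mean_iou(pred, target, num_classes):
    """Mean of per-class intersection over union."""
    ious = []
    for c in range(num_classes):
        intersection = sum(p == c and t == c for p, t in zip(pred, target))
        union = sum(p == c or t == c for p, t in zip(pred, target))
        if union > 0:  # skip classes absent from both prediction and target
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0


if __name__ == "__main__":
    pred = [0, 0, 1, 1, 2, 2]
    target = [0, 1, 1, 1, 2, 0]
    print(pixel_accuracy(pred, target))  # 4 of 6 pixels are correct
    print(mean_iou(pred, target, 3))     # per-class IoU: 1/3, 2/3, 1/2
```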

Depth

  • Absolute Error (Abs)
  • Absolute Relative Error (Abs Rel): Absolute error divided by ground truth depth
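
The two depth metrics can be sketched the same way on flattened depth lists. Excluding pixels whose ground-truth depth is not positive is an assumption made here to avoid dividing by zero; the repository may mask invalid pixels differently.

```python
def _valid_pairs(pred, target):
    """Keep only pixels with a positive ground-truth depth (assumed validity mask)."""
    return [(p, t) for p, t in zip(pred, target) if t > 0]


def abs_error(pred, target):
    """Mean absolute difference between predicted and ground-truth depth."""
    pairs = _valid_pairs(pred, target)
    return sum(abs(p - t) for p, t in pairs) / len(pairs)


def abs_rel_error(pred, target):
    """Mean absolute difference divided by the ground-truth depth."""
    pairs = _valid_pairs(pred, target)
    return sum(abs(p - t) / t for p, t in pairs) / len(pairs)


if __name__ == "__main__":
    pred = [1.0, 2.0, 4.0]
    target = [2.0, 2.0, 0.0]  # the last pixel has no valid ground truth
    print(abs_error(pred, target))      # (1.0 + 0.0) / 2
    print(abs_rel_error(pred, target))  # (0.5 + 0.0) / 2
```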

The results are the following:

Cityscapes

| Models | mIoU | Pix Acc | Abs | Abs Rel |
|---|---|---|---|---|
| MTAN | 53.04 | 91.11 | 0.0144 | 33.63 |
| KD4MTL | 52.71 | 91.54 | 0.0139 | 27.33 |
| PCGrad | 53.59 | 91.45 | 0.0171 | 31.34 |
| AdaMT-Net | 62.53 | 94.16 | 0.0125 | 22.23 |
| Ours | 66.51 | 93.56 | 0.0122 | 19.40 |

NYU

| Models | mIoU | Pix Acc | Abs | Abs Rel |
|---|---|---|---|---|
| MTAN* | 21.07 | 55.70 | 0.6035 | 0.2472 |
| MTAN† | 20.10 | 53.73 | 0.6417 | 0.2758 |
| KD4MTL* | 20.75 | 57.90 | 0.5816 | 0.2445 |
| KD4MTL† | 22.44 | 57.32 | 0.6003 | 0.2601 |
| PCGrad* | 20.17 | 56.65 | 0.5904 | 0.2467 |
| PCGrad† | 21.29 | 54.07 | 0.6705 | 0.3000 |
| AdaMT-Net* | 21.86 | 60.35 | 0.5933 | 0.2456 |
| AdaMT-Net† | 20.61 | 58.91 | 0.6136 | 0.2547 |
| Ours† | 30.31 | 63.02 | 0.5954 | 0.2235 |

*: Trained on 3 tasks (segmentation, depth, and surface normal)
†: Trained on 2 tasks (segmentation and depth)
Italic: Reproduced by ourselves

Scores for models trained on 3 tasks on the NYU dataset are shown for reference only.

Papers referred

MTAN: [paper][github]
KD4MTL: [paper][github]
PCGrad: [paper][github (tensorflow)][github (pytorch)]
AdaMT-Net: [paper]

Owner
Aki Nakano
Student at the University of Tokyo pursuing master's degree. Joined UC Berkeley Summer Session 2019. Researching deep learning. Python/R