
AutoMTL: A Programming Framework for Automated Multi-Task Learning

This is the code repository for our paper "AutoMTL: A Programming Framework for Automated Multi-Task Learning", submitted to MLSys 2022. The arXiv version will be made public on Tue, 26 Oct 2021.

Abstract

Multi-task learning (MTL) jointly learns a set of tasks. It is a promising approach to reduce training and inference time and storage costs while improving prediction accuracy and generalization performance for many computer vision tasks. However, a major barrier to the widespread adoption of MTL is the lack of systematic support for developing compact multi-task models given a set of tasks. In this paper, we aim to remove that barrier by developing AutoMTL, the first programming framework that automates MTL model development. AutoMTL takes as input an arbitrary backbone convolutional neural network and a set of tasks to learn, and automatically produces a multi-task model that simultaneously achieves high accuracy and a small memory footprint. As a programming framework, AutoMTL can facilitate the development of MTL-enabled computer vision applications and further improve task performance.

[Figure: AutoMTL overview]

Cite

Please cite our work if you find it helpful to your research. [TODO: cite info]

Description

Environment

conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch # Or higher
conda install protobuf
pip install opencv-python
pip install scikit-learn
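
To verify the environment, a quick import check (plain Python; these are the standard module names for the packages above):

python -c "import torch, torchvision, cv2, sklearn, google.protobuf; print(torch.__version__, torchvision.__version__)"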

Datasets

We conducted experiments on three popular datasets in multi-task learning (MTL): CityScapes [1], NYUv2 [2], and Tiny-Taskonomy [3]. You can download them here. For Tiny-Taskonomy, you will need to contact the authors directly; see their official website.

File Structure

├── data
│   ├── dataloader
│   │   ├── *_dataloader.py
│   ├── heads
│   │   ├── pixel2pixel.py
│   ├── metrics
│   │   ├── pixel2pixel_loss/metrics.py
├── framework
│   ├── layer_containers.py
│   ├── base_node.py
│   ├── layer_node.py
│   ├── mtl_model.py
│   ├── trainer.py
├── models
│   ├── *.prototxt
└── utils
    └── pytorch_to_caffe.py

Code Description

Our code can be divided into three parts: code for data, code for AutoMTL, and others.

  • For Data

    • Dataloaders *_dataloader.py: For each dataset, we offer a corresponding PyTorch dataloader with a specific task variable.
    • Heads pixel2pixel.py: The ASPP head [4] is implemented for the pixel-to-pixel vision tasks.
    • Metrics pixel2pixel_loss/metrics.py: Each task has its own criterion (loss) and evaluation metric.
  • AutoMTL

    • Multi-Task Model Generator mtl_model.py: Transforms a given backbone model (in prototxt format) and a dictionary of task-specific heads into a multi-task supermodel.
    • Trainer Tools trainer.py: Materializes a three-stage training pipeline to search for a good multi-task model for the given tasks. [Figure: three-stage training pipeline]
  • Others

    • Input Backbone *.prototxt: Typical vision backbone models including Deeplab-ResNet34 [4], MobileNetV2, and MNasNet.
    • Transfer to Prototxt pytorch_to_caffe.py: If you define your own customized backbone model with the PyTorch API, we also provide a tool to convert it to a prototxt file (see the sketch below).
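
A minimal sketch of the conversion flow, assuming the bundled converter follows the common pytorch_to_caffe interface; trans_net and save_prototxt are assumptions here, so check utils/pytorch_to_caffe.py for the exact entry points:

import torch
import torchvision
from utils import pytorch_to_caffe  # assumed import path

# Any backbone defined with the PyTorch API, e.g. a torchvision model.
backbone = torchvision.models.mobilenet_v2()
backbone.eval()
dummy_input = torch.rand(1, 3, 224, 224)  # trace with a representative input size

# Hypothetical calls modeled on the widely used pytorch_to_caffe tool;
# verify the function names against utils/pytorch_to_caffe.py.
pytorch_to_caffe.trans_net(backbone, dummy_input, 'mobilenet_v2')
pytorch_to_caffe.save_prototxt('models/mobilenet_v2.prototxt')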

How to Use

Set up Data

Each task has its own dataloaders for training and validation, a task-specific criterion (loss), an evaluation metric, and a model head. Here we take CityScapes as an example.

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1} # the number of classes in each task

You can also define your own dataloaders, criteria, and evaluation metrics. Please refer to the files in data/ to make sure your customized classes have the same output format as ours so that they fit our framework.
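
For instance, a customized criterion could follow the sketch below; the class name is hypothetical and the (prediction, ground truth) -> scalar loss call convention is inferred from the built-in criterions, so mirror CityScapesCriterions in data/ for the authoritative interface:

import torch.nn as nn

class MyTaskCriterion(nn.Module):  # hypothetical custom criterion
    def __init__(self, task):
        super().__init__()
        # Pick a per-task loss, mirroring the built-in CityScapes criterions.
        if task == 'segment_semantic':
            self.loss = nn.CrossEntropyLoss(ignore_index=255)  # 255 = assumed ignore label
        elif task == 'depth_zbuffer':
            self.loss = nn.L1Loss()

    def forward(self, pred, gt):
        # Assumed call convention: (prediction, ground truth) -> scalar loss.
        return self.loss(pred, gt)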

dataloader dictionary

trainDataloaderDict = {}
valDataloaderDict = {}
for task in tasks:
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)

criterion dictionary

criterionDict = {}
for task in tasks:
    criterionDict[task] = CityScapesCriterions(task)

evaluation metric dictionary

metricDict = {}
for task in tasks:
    metricDict[task] = CityScapesMetrics(task)

task-specific heads dictionary

headsDict = nn.ModuleDict() # must be nn.ModuleDict() instead of a plain Python dict so head parameters are registered with the model
for task in tasks:
    headsDict[task] = ASPPHeadNode(<feature_dim>, task_cls_num[task])

Construct Multi-Task Supermodel

prototxt = 'models/deeplab_resnet34_adashare.prototxt' # can be any CNN model
mtlmodel = MTLModel(prototxt, headsDict)
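
Assuming MTLModel is a standard nn.Module (which the PyTorch-based trainer suggests), you can sanity-check the constructed supermodel with ordinary PyTorch calls:

num_params = sum(p.numel() for p in mtlmodel.parameters())  # shared layers plus all task heads
print(f'supermodel parameters: {num_params:,}')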

3-stage Training

define the trainer

trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict)

pre-train phase

trainer.pre_train(iters=<total_iter>, lr=<init_lr>, savePath=<save_path>)

policy-train phase

loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005} # the weights for each task and the policy regularization term from the paper
trainer.alter_train_with_reg(iters=<total_iter>, policy_network_iters=<alter_iters>, policy_lr=<policy_lr>, network_lr=<network_lr>, 
                             loss_lambda=loss_lambda, savePath=<save_path>)

Notice that when training the policy and the model weights together, we alternately train them, switching every policy_network_iters iterations.
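
For intuition, the alternating schedule resembles the simplified loop below; this is an illustrative sketch with stand-in losses and parameters, not AutoMTL's actual implementation:

import torch

# Stand-ins for the policy logits and the network weights.
policy_params = [torch.nn.Parameter(torch.randn(4))]
network_params = [torch.nn.Parameter(torch.randn(4))]
policy_opt = torch.optim.Adam(policy_params, lr=0.01)
network_opt = torch.optim.SGD(network_params, lr=0.001)

total_iter, alter_iters = 100, 5
for it in range(total_iter):
    if (it // alter_iters) % 2 == 0:
        # Policy phase: task losses plus the 'policy' regularization term.
        loss = sum(p.abs().sum() for p in policy_params)  # stand-in loss
        policy_opt.zero_grad(); loss.backward(); policy_opt.step()
    else:
        # Network phase: task losses only.
        loss = sum(p.abs().sum() for p in network_params)  # stand-in loss
        network_opt.zero_grad(); loss.backward(); network_opt.step()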

post-train phase

trainer.post_train(iters=<total_iter>, lr=<init_lr>, 
                   loss_lambda=loss_lambda, savePath=<save_path>, reload=<policy_train_model_name>)

Note: Please refer to Example.ipynb for more details.

References

[1] Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt. The Cityscapes dataset for semantic urban scene understanding. CVPR, 3213-3223, 2016.

[2] Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob. Indoor segmentation and support inference from RGBD images. ECCV, 746-760, 2012.

[3] Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio. Taskonomy: Disentangling task transfer learning. CVPR, 3712-3722, 2018.

[4] Chen, Liang-Chieh and Papandreou, George and Kokkinos, Iasonas and Murphy, Kevin and Yuille, Alan L. DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. PAMI, 834-848, 2017.
