Code for Parameter Prediction for Unseen Deep Architectures (NeurIPS 2021)


Parameter Prediction for Unseen Deep Architectures (NeurIPS 2021)

Authors: Boris Knyazev, Michal Drozdzal, Graham Taylor, Adriana Romero-Soriano


This repository contains the code to train and evaluate Graph HyperNetworks (GHNs), as well as DeepNets-1M, the dataset of neural architectures proposed in our paper to train and evaluate GHNs. Our improved GHNs trained on DeepNets-1M can predict parameters for diverse networks, even networks that are very different from those used to train the GHNs (e.g. ResNet-50). A GHN predicts the parameters in a single forward pass, which on average takes < 1 second on either GPU or CPU!


Requirements and installation

The main requirements are:

  • Python 3.6+
  • PyTorch 1.9+
  • NetworkX

For graph visualizations we use pygraphviz. For it to work, graphviz may need to be installed first: sudo apt-get install graphviz graphviz-dev.

To install the ppuda package:

pip install .  # use pip install -e . to install an editable version

To install all packages required by every file:

pip install -r requirements.txt

It's also possible to use conda to install this by running:

conda create --name ppuda --file requirements.txt --channel defaults --channel anaconda --channel conda-forge --channel pytorch

Available GHNs

We release five GHNs. Below are top-1 accuracies on CIFAR-10 and top-5 accuracies on ImageNet using the parameters predicted by one of the GHNs.

Model             ResNet-50*   Best architecture (index)†
MLP-CIFAR-10      17.7         60.2 (167)
GHN-1-CIFAR-10    19.2         59.9 (179)
GHN-2-CIFAR-10    58.6         77.1 (210)
GHN-1-ImageNet     6.9         32.1 (295)
GHN-2-ImageNet     5.3         48.3 (85)

* ResNet-50 is an unseen architecture: the trained GHNs have not observed it, or even similar architectures, during training. Our GHNs can still predict good parameters for such unseen architectures. On ImageNet the accuracy is low, but the predicted parameters are still very useful as an initialization for fine-tuning.

† Index of the architecture in the test split of DeepNets-1M.

Each GHN checkpoint takes just a few megabytes and is stored in the checkpoints folder of this repository.

Minimal example: predict parameters for ResNet-50

ImageNet:

from ppuda.ghn.nn import GHN2
import torchvision.models as models

ghn = GHN2('imagenet')      # load our GHN-2 trained on ImageNet
model = models.resnet50()   # ResNet-50 or any other torchvision model
model = ghn(model)          # predict parameters in < 1 second on GPU/CPU

# That's it! The model can now be evaluated on ImageNet to obtain top5=5.2%.

CIFAR-10:

from ppuda.ghn.nn import GHN2
import torchvision.models as models

# On CIFAR-10, we have an additional step of adjusting 
# the first layer(s) of the network for a 32x32 image size,
# since torchvision models expect a 224x224 input, 
# while GHNs on CIFAR-10 were trained on 32x32 inputs.

from ppuda.utils import adjust_net

ghn = GHN2('cifar10')                    # load our GHN-2 trained on CIFAR-10
model = models.resnet50(num_classes=10)  # ResNet-50 
model = adjust_net(model)                # adjust to a 32x32 input
model = ghn(model)                       # predict parameters in < 1 second on GPU/CPU

# That's it! The model can now be evaluated on CIFAR-10 to obtain top1=58.6%.
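
Why the input size matters can be seen from the standard output-size formula floor((n + 2p - k) / s) + 1. The sketch below traces the spatial resolution through the usual ResNet-50 layout (a 7x7 stride-2 convolution, a 3x3 stride-2 max-pool, then one stride-2 downsampling in each of the last three stages). It only illustrates the arithmetic; it does not show what adjust_net actually changes, and the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_feature_size(size: int) -> int:
    """Trace the spatial size through a standard ResNet-50 stem and stages."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the resolution; layer2, layer3 and layer4 each
    # downsample once with a stride-2 3x3 convolution
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(resnet50_feature_size(224))  # 7
print(resnet50_feature_size(32))   # 1
```

A 224x224 input ends as a 7x7 feature map, while a 32x32 input collapses to 1x1 before the classifier, which is why the first layer(s) are adapted for CIFAR-10.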

Full examples for ResNet-50 and other torchvision models are in examples/torch_models.py and examples/all_torch_models.py. More examples are in the examples folder.

Note 1: For networks with batch norm, the running statistics of the batch norm layers are not predicted, since these statistics are not trainable parameters. To evaluate such networks, our code therefore computes batch statistics on the evaluation set with batch size = 64. Networks without batch norm (e.g. those in our BN-Free split) have the same accuracy regardless of the batch size.
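
The dependence on batch composition can be seen in a minimal pure-Python sketch, assuming one scalar feature per example and omitting the learnable affine weight and bias: each value is normalized with the mean and biased variance of its own batch, which is what a PyTorch BatchNorm layer does in training mode. The helper name is hypothetical and not part of ppuda.

```python
import statistics
from typing import List


def batch_norm_with_batch_stats(batch: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize one feature using the statistics of the current batch only.

    No running statistics are involved (as when they are not available);
    the learnable affine weight and bias are omitted for brevity.
    """
    mean = statistics.fmean(batch)
    var = statistics.pvariance(batch, mu=mean)  # biased variance, as used for normalization
    return [(x - mean) / (var + eps) ** 0.5 for x in batch]


# the same input value 1.0 is mapped to different outputs in different batches
small = batch_norm_with_batch_stats([1.0, 2.0, 3.0])
other = batch_norm_with_batch_stats([1.0, 1.0, 4.0])
print(round(small[0], 4), round(other[0], 4))
```

Because the outputs depend on which examples share a batch, accuracy with batch norm can vary with the evaluation batch size, unlike networks without batch norm.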

Note 2: To evaluate/train on ImageNet, first prepare the ImageNet dataset following the data instructions below.

Data

DeepNets-1M

To train or evaluate on DeepNets-1M, first download the dataset file by running ./data/download.sh.

To generate a new DeepNets-1M dataset, the following command can be used:

python experiments/net_generator.py train 1000000 ./data

The dataset generated using this command should be close to our training dataset. Other splits can be regenerated by specifying the split as the first argument.

CIFAR-10

CIFAR-10 is downloaded automatically and is saved in the --data_dir folder (default is ./data).

ImageNet

We implemented a simple wrapper around torchvision.datasets.ImageNet. By default, the ImageNet root folder imagenet is expected to be inside the ./data folder with the following structure:

./data
│   imagenet
│   │   train
│   │   │   n01440764
│   │   │   n01443537
│   │   │   ...
│   │   val
│   │   │   n01440764
│   │   │   n01443537
│   │   │   ...
│   │   ILSVRC2012_devkit_t12.tar.gz
│   deepnets1m_train.hdf5
│   deepnets1m_train_meta.json
│   ...

Both imagenet/train and imagenet/val must contain a separate folder for each class. Follow the official instructions to obtain the ImageNet (ILSVRC 2012) data.
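
Before a long run, the layout above can be sanity-checked. This is a minimal sketch using only the folder and file names shown in the tree; the function name is hypothetical, and it does not validate image contents or the class list itself.

```python
from pathlib import Path
from typing import List


def check_imagenet_layout(data_dir: str = "./data") -> List[str]:
    """Return a list of problems found in the expected ImageNet folder layout."""
    root = Path(data_dir) / "imagenet"
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        # each split must contain one sub-folder per class (e.g. n01440764)
        if not any(p.is_dir() for p in split_dir.iterdir()):
            problems.append(f"no class folders in: {split_dir}")
    devkit = root / "ILSVRC2012_devkit_t12.tar.gz"
    if not devkit.is_file():
        problems.append(f"missing file: {devkit}")
    return problems


for problem in check_imagenet_layout():
    print(problem)
```

An empty result means the expected folders and the devkit archive are in place.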

Reproducing main results

The arguments of our scripts are described in config.py. The default hyperparameters are based on our paper. Below are examples of running the scripts and overriding the default hyperparameters.

DeepNets-1M results

Training GHN

  • GHN-1 on CIFAR-10: python experiments/train_ghn.py --name ghn1

  • GHN-2 on CIFAR-10: python experiments/train_ghn.py -m 8 -n -v 50 --ln --name ghn2

  • MLP on CIFAR-10: python experiments/train_ghn.py -m 8 -n -v 50 --ln -H mlp --name mlp

where -m 8 sets the meta batch size to 8, -n normalizes the predicted parameters, -v 50 adds virtual edges to the graphs with 50 as the maximum shortest path length, --ln adds layer normalization before decoding the parameters, -H mlp uses an MLP instead of a GatedGNN, and --name ghn2 sets the name of the directory where trained checkpoints are saved (combined with --save_dir to obtain the full path).
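
To make the -v option more concrete, here is a minimal sketch, assuming the computational graph is a DAG given as adjacency lists and that a virtual edge links any two nodes whose directed shortest-path distance s satisfies 1 < s <= the maximum (50 for -v 50). This is an illustration, not the ppuda implementation, and how GHN-2 weights messages along these edges is not shown.

```python
from collections import deque
from typing import Dict, List, Set, Tuple


def shortest_path_lengths(adj: Dict[int, List[int]], source: int) -> Dict[int, int]:
    """Breadth-first search distances from source along directed edges."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj.get(u, []):
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def virtual_edges(adj: Dict[int, List[int]], max_len: int) -> Set[Tuple[int, int, int]]:
    """Return (u, v, s) for node pairs at shortest-path distance 1 < s <= max_len."""
    edges = set()
    for u in adj:
        for v, s in shortest_path_lengths(adj, u).items():
            if 1 < s <= max_len:  # s == 1 is an existing real edge
                edges.add((u, v, s))
    return edges


chain = {0: [1], 1: [2], 2: [3]}  # toy graph: 0 -> 1 -> 2 -> 3
print(sorted(virtual_edges(chain, max_len=2)))   # [(0, 2, 2), (1, 3, 2)]
print(sorted(virtual_edges(chain, max_len=50)))  # [(0, 2, 2), (0, 3, 3), (1, 3, 2)]
```

With a small maximum only nearby nodes are linked; a large value such as 50 lets distant layers of a deep network exchange information directly.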

To train on ImageNet, use -d imagenet. To train GHNs on multiple GPUs (e.g. with a large meta batch size), add --multigpu to use all available CUDA devices (make sure to set CUDA_VISIBLE_DEVICES appropriately).

For example, to train GHN-2 on ImageNet with 4 GPUs: export CUDA_VISIBLE_DEVICES=0,1,2,3; python experiments/train_ghn.py -m 8 -n -v 50 --ln --name ghn2_imagenet -d imagenet --multigpu
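
The mapping from the options described above to the training commands can be summarized with a small helper. This is a hypothetical convenience sketch that only assembles argument lists from the flags documented here; the authoritative list of arguments is in config.py.

```python
import shlex
from typing import List, Optional


def train_ghn_cmd(name: str, meta_batch: Optional[int] = None, normalize: bool = False,
                  virtual_edges: Optional[int] = None, layernorm: bool = False,
                  hypernet: Optional[str] = None, dataset: Optional[str] = None,
                  multigpu: bool = False) -> List[str]:
    """Assemble an experiments/train_ghn.py command from the documented flags."""
    cmd = ["python", "experiments/train_ghn.py"]
    if meta_batch is not None:
        cmd += ["-m", str(meta_batch)]      # meta batch size
    if normalize:
        cmd.append("-n")                    # normalize predicted parameters
    if virtual_edges is not None:
        cmd += ["-v", str(virtual_edges)]   # max shortest path length for virtual edges
    if layernorm:
        cmd.append("--ln")                  # layer norm before decoding
    if hypernet is not None:
        cmd += ["-H", hypernet]             # e.g. "mlp" instead of GatedGNN
    cmd += ["--name", name]                 # checkpoint directory name
    if dataset is not None:
        cmd += ["-d", dataset]
    if multigpu:
        cmd.append("--multigpu")
    return cmd


ghn2 = dict(meta_batch=8, normalize=True, virtual_edges=50, layernorm=True)
print(shlex.join(train_ghn_cmd("ghn1")))
print(shlex.join(train_ghn_cmd("ghn2", **ghn2)))
print(shlex.join(train_ghn_cmd("mlp", hypernet="mlp", **ghn2)))
print(shlex.join(train_ghn_cmd("ghn2_imagenet", dataset="imagenet", multigpu=True, **ghn2)))
```

GHN-1 uses the defaults, while GHN-2 and the MLP baseline share the same four options and differ only in -H.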

Evaluating GHNs

  • Evaluate GHN-2 on CIFAR-10 on all architectures of $split from DeepNets-1M: python experiments/eval_ghn.py --ckpt ./checkpoints/ghn2_cifar10.pt -d cifar10 --split $split

  • Evaluate GHN-2 on CIFAR-10 on a single architecture from DeepNets-1M: python experiments/eval_ghn.py --ckpt ./checkpoints/ghn2_cifar10.pt -d cifar10 --split $split --arch $ind

where $split is one of val, test, wide, deep, dense, bnfree, predefined, and $ind is the integer index of an architecture in that split.
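
Evaluating on every split in turn can be scripted. A minimal sketch that only builds the eval_ghn.py command lines shown above (the helper name is hypothetical); each list could be passed to subprocess.run(cmd, check=True) to launch the evaluation.

```python
import shlex
from typing import List, Optional

SPLITS = ("val", "test", "wide", "deep", "dense", "bnfree", "predefined")


def eval_ghn_cmd(split: str, ckpt: str = "./checkpoints/ghn2_cifar10.pt",
                 dataset: str = "cifar10", arch: Optional[int] = None) -> List[str]:
    """Build an eval_ghn.py command for one split (all architectures, or just `arch`)."""
    if split not in SPLITS:
        raise ValueError(f"unknown split {split!r}; expected one of {SPLITS}")
    cmd = ["python", "experiments/eval_ghn.py", "--ckpt", ckpt, "-d", dataset, "--split", split]
    if arch is not None:
        cmd += ["--arch", str(arch)]
    return cmd


for split in SPLITS:
    print(shlex.join(eval_ghn_cmd(split)))
```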

Training and evaluating SGD

  • Train architecture=0 from the test split of DeepNets-1M for 50 epochs on CIFAR-10: python experiments/sgd/train_net.py --split test --arch 0 --epochs 50

  • Train the best architecture from the DARTS paper for 50 epochs on CIFAR-10: python experiments/sgd/train_net.py --arch DARTS --epochs 50

  • Train architecture=0 from the wide split of DeepNets-1M for 1 epoch on ImageNet: python experiments/sgd/train_net.py --split wide --arch 0 --epochs 1 -d imagenet

Fine-tuning predicted parameters on other tasks

The parameters predicted by GHN-2 trained on ImageNet can be fine-tuned on other vision datasets, such as CIFAR-10.

100-shot CIFAR-10

  • Fine-tune ResNet-50 initialized with the parameters predicted by GHN-1-ImageNet: python experiments/sgd/train_net.py --split predefined --arch 0 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --ckpt ./checkpoints/ghn1_imagenet.pt

  • Fine-tune ResNet-50 initialized with the parameters predicted by GHN-2-ImageNet: python experiments/sgd/train_net.py --split predefined --arch 0 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --ckpt ./checkpoints/ghn2_imagenet.pt

  • Fine-tune ResNet-50 initialized randomly with Kaiming He's method: python experiments/sgd/train_net.py --split predefined --arch 0 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3

  • Fine-tune ResNet-50 pretrained on ImageNet: python experiments/sgd/train_net.py --split predefined --arch 0 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --pretrained

  • Fine-tune ViT initialized with the parameters predicted by GHN-2-ImageNet: python experiments/sgd/train_net.py --split predefined --arch 1 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --ckpt ./checkpoints/ghn2_imagenet.pt

  • Fine-tune DARTS initialized with the parameters predicted by GHN-2-ImageNet: python experiments/sgd/train_net.py --arch DARTS --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --init_channels 48 --layers 14 --ckpt ./checkpoints/ghn2_imagenet.pt

--wd 1e-3 generally worked best in these experiments. For the results reported in the paper, we also tuned the initial learning rate on 200 validation images taken from the 100-shot CIFAR-10 training set, so the results obtained with the scripts above may differ slightly from the reported ones.
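
For intuition about the 100-shot setting (100 training images per class, i.e. 1,000 images for CIFAR-10) and the 200-image validation subset, here is a minimal sketch. The class-balanced random sampling and the random hold-out are assumptions for illustration; the repository's --n_shots option may select images differently, and the helper name is hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_n_shot(labels: Sequence[int], n_shots: int, n_val: int = 0,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Pick n_shots indices per class, then hold out n_val of them for validation."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    chosen: List[int] = []
    for label in sorted(by_class):
        # raises ValueError if a class has fewer than n_shots examples
        chosen += rng.sample(by_class[label], n_shots)
    rng.shuffle(chosen)
    return chosen[n_val:], chosen[:n_val]  # (train indices, validation indices)


toy_labels = [i % 10 for i in range(5000)]  # hypothetical labels, 500 images per class
train_idx, val_idx = sample_n_shot(toy_labels, n_shots=100, n_val=200)
print(len(train_idx), len(val_idx))  # 800 200
```

The uniform hold-out is not stratified by class; stratifying it would be a reasonable alternative when tuning hyperparameters on so few images.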

Object detection

In the paper, we fine-tune on the Penn-Fudan object detection dataset. Our experiments are based on the PyTorch Object Detection Finetuning Tutorial.

Download the dataset and put it inside the ./data folder, i.e. ./data/PennFudanPed.

The commands to fine-tune/train networks for object detection are similar to those for 100-shot CIFAR-10 above, but use the experiments/sgd/detector/train_detector.py script and the hyperparameters from the tutorial. For example, to fine-tune DARTS initialized with the parameters predicted by GHN-2-ImageNet:

python experiments/sgd/detector/train_detector.py -d PennFudanPed --arch DARTS --ckpt ./checkpoints/ghn2_imagenet.pt --init_channels 48 --layers 14

Property prediction

To train and evaluate regression models on top of graph embeddings extracted using GHN-2-CIFAR-10:

python experiments/property_prediction.py cifar10 ./checkpoints/ghn2_cifar10.pt

The script will evaluate the four properties of architectures discussed in the paper: accuracy on the clean test set, accuracy on a corrupted test set, inference speed, and speed of convergence.

The extracted embeddings in the .npy format for each GHN are available in the checkpoints folder, but will be recomputed if they are missing.
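
A natural way to score such regression models is a rank correlation between predicted and true property values. The pure-Python Kendall's tau below is a sketch of one such metric, not necessarily the exact metric or implementation used by property_prediction.py; for real use, scipy.stats.kendalltau (which computes the tie-aware tau-b variant) is the standard choice.

```python
from itertools import combinations
from typing import Sequence


def kendall_tau(x: Sequence[float], y: Sequence[float]) -> float:
    """Kendall's tau-a: (concordant - discordant pairs) / total pairs.

    Tied pairs count as neither concordant nor discordant.
    """
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    score = 0
    for i, j in combinations(range(len(x)), 2):
        prod = (x[i] - x[j]) * (y[i] - y[j])
        score += (prod > 0) - (prod < 0)  # +1 concordant, -1 discordant, 0 tie
    n = len(x)
    return score / (n * (n - 1) / 2)


print(kendall_tau([1, 2, 3, 4], [10, 20, 30, 40]))  # 1.0
print(kendall_tau([1, 2, 3, 4], [40, 30, 20, 10]))  # -1.0
```

A rank-based score fits this use case because what matters for comparing architectures is whether the predicted ordering matches the true ordering, not the absolute values.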

NAS

To train the best architecture (here, in terms of accuracy) on CIFAR-10 with SGD for 600 epochs following the DARTS protocol:

python experiments/sgd/train_net.py --split search --arch 35133 --epochs 600 --cutout --drop_path_prob 0.2 --auxiliary

Architecture 35133 was found to be the best in the search split on CIFAR-10 using our GHN-2.
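
As a minimal sketch, assuming every architecture in the search split has already been given a score using its GHN-2-predicted parameters (e.g. its CIFAR-10 validation accuracy), selecting the best candidates reduces to taking the top-scoring indices. The function name and the scores below are hypothetical.

```python
from typing import Dict, List


def top_k_architectures(scores: Dict[int, float], k: int = 1) -> List[int]:
    """Return the k architecture indices with the highest predicted score."""
    return sorted(scores, key=scores.__getitem__, reverse=True)[:k]


# hypothetical scores obtained with GHN-predicted parameters
predicted = {0: 0.41, 7: 0.58, 12: 0.55, 3: 0.49}
print(top_k_architectures(predicted))       # [7]
print(top_k_architectures(predicted, k=2))  # [7, 12]
```

Only the selected architectures then need the expensive SGD training, such as the 600-epoch run above.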

Visualization

Example of visualizing the computational graph of ResNet-50.

import torchvision
from ppuda.deepnets1m.graph import Graph

Graph(torchvision.models.resnet50()).visualize(node_size=100)

Example of visualizing the computational graph of the best DARTS network.

from ppuda.deepnets1m.graph import Graph
from ppuda.deepnets1m.net import Network
from ppuda.deepnets1m.genotypes import DARTS

model = Network(C=48, num_classes=1000, genotype=DARTS, n_cells=14)
Graph(model).visualize(node_size=50)

See more examples for different architectures in examples/graph_visualization.ipynb.

License

The majority of PPUDA is licensed under the MIT license; however, portions of the project are available under separate license terms: DARTS is licensed under the Apache 2.0 license and NetworkX is licensed under the 3-Clause BSD license.

Contributions

Please submit a pull request or open a GitHub issue. Make sure to comply with our code of conduct.

Acknowledgements

We thank the Vector AI Engineering team (Gerald Shen, Maria Koshkina and Deval Pandya) for code review.

Citation

@inproceedings{knyazev2021parameter,
  title={Parameter Prediction for Unseen Deep Architectures},
  author={Knyazev, Boris and Drozdzal, Michal and Taylor, Graham W and Romero-Soriano, Adriana},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021}  
}
Owner: Facebook Research