DrNAS: Dirichlet Neural Architecture Search

About

Code accompanying the ICLR 2021 paper "DrNAS: Dirichlet Neural Architecture Search".
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating it as a distribution learning problem. We treat the continuously relaxed architecture mixing weights as random variables, modeled by a Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be easily optimized with a gradient-based optimizer in an end-to-end manner. This formulation improves the generalization ability and induces stochasticity that naturally encourages exploration in the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between the search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method. Specifically, we obtain a test error of 2.46% on CIFAR-10 and 23.7% on ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.
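
To make the relaxation concrete, here is a minimal standard-library sketch (not the repository's code) of one edge whose mixing weights are drawn from a Dirichlet distribution and used to blend candidate operations. The scalar stand-in operations, the initial concentration of 1.0, and the helper names `sample_dirichlet` and `mixed_output` are assumptions made for illustration. In an autodiff framework, a reparameterized sampler such as `torch.distributions.Dirichlet(...).rsample()` would take the place of the Gamma-based sampler so that gradients can reach the concentration parameters.

```python
import random


def sample_dirichlet(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma draws."""
    draws = [rng.gammavariate(beta, 1.0) for beta in concentration]
    total = sum(draws)
    return [d / total for d in draws]


def mixed_output(x, ops, weights):
    """Continuous relaxation of one edge: weighted sum of the candidate outputs."""
    return sum(w * op(x) for w, op in zip(weights, ops))


# Hypothetical scalar stand-ins for real candidate operations (conv, pooling, skip, ...).
CANDIDATE_OPS = [lambda x: 0.0, lambda x: x, lambda x: 2.0 * x, lambda x: x * x]

rng = random.Random(0)
beta = [1.0, 1.0, 1.0, 1.0]  # Dirichlet concentration for this edge (assumed initial value)
theta = sample_dirichlet(beta, rng)  # mixing weights: non-negative and summing to 1
y = mixed_output(3.0, CANDIDATE_OPS, theta)
```

Because the weights are resampled at every step, each forward pass sees a slightly different mixture, which is the source of the exploration described above.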

Results

On NAS-Bench-201

The table below shows the test accuracy (%) on the NAS-Bench-201 space. We achieve state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even reaches the global optimum with no variance!

| Method | CIFAR-10 (test) | CIFAR-100 (test) | ImageNet-16-120 (test) |
|---|---|---|---|
| ENAS | 54.30 ± 0.00 | 10.62 ± 0.27 | 16.32 ± 0.00 |
| DARTS | 54.30 ± 0.00 | 38.97 ± 0.00 | 18.41 ± 0.00 |
| SNAS | 92.77 ± 0.83 | 69.34 ± 1.98 | 43.16 ± 2.64 |
| PC-DARTS | 93.41 ± 0.30 | 67.48 ± 0.89 | 41.31 ± 0.22 |
| DrNAS (ours) | 94.36 ± 0.00 | 73.51 ± 0.00 | 46.34 ± 0.00 |
| optimal | 94.37 | 73.51 | 47.31 |

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the current architecture selected by Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts very wide but narrows gradually during the search phase. It indicates that DrNAS learns to encourage exploration at the early stages and then gradually reduces it towards the end as the algorithm becomes more and more confident of the current choice. Moreover, the performance of our architectures can consistently match the best performance of the sampled architectures, indicating the effectiveness of DrNAS.
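
The sampling procedure described above can be sketched as follows. This is an illustrative standard-library approximation, not the plotting code used for the figure: the concentration values are invented, the 6-edge by 5-operation cell shape is the NAS-Bench-201 layout, and picking the operation with the largest sampled weight on each edge is an assumed discretization rule. Looking up each sampled architecture's accuracy would need the benchmark API and is left out.

```python
import random


def sample_dirichlet(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma draws."""
    draws = [rng.gammavariate(b, 1.0) for b in concentration]
    total = sum(draws)
    return [d / total for d in draws]


def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])


def sample_architecture(edge_concentrations, rng):
    """On every edge, keep the operation with the largest sampled weight (assumed rule)."""
    return tuple(argmax(sample_dirichlet(c, rng)) for c in edge_concentrations)


def mean_architecture(edge_concentrations):
    """The Dirichlet mean is concentration / sum, so its argmax is the argmax of concentration."""
    return tuple(argmax(c) for c in edge_concentrations)


def count_distinct(edge_concentrations, n_samples=100, seed=0):
    """Number of different architectures among n_samples draws."""
    rng = random.Random(seed)
    return len({sample_architecture(edge_concentrations, rng) for _ in range(n_samples)})


# Hypothetical concentrations for a cell with 6 edges and 5 candidate operations.
early = [[1.0] * 5 for _ in range(6)]                    # flat: wide exploration
late = [[200.0, 1.0, 1.0, 1.0, 1.0] for _ in range(6)]   # peaked: confident choice

print("distinct architectures early:", count_distinct(early))
print("distinct architectures late: ", count_distinct(late))
print("architecture at Dirichlet mean:", mean_architecture(late))
```

A flat concentration yields many different architectures across the 100 draws, while a sharply peaked one collapses onto the architecture given by the Dirichlet mean, which mirrors the narrowing accuracy range described in the text.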

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking top amongst recent NAS results.

| Method | Test Error (%) | Params (M) | Search Cost (GPU days) |
|---|---|---|---|
| ENAS | 2.89 | 4.6 | 0.5 |
| DARTS | 2.76 ± 0.09 | 3.3 | 1.0 |
| SNAS | 2.85 ± 0.02 | 2.8 | 1.5 |
| PC-DARTS | 2.57 ± 0.07 | 3.6 | 0.1 |
| DrNAS (ours) | 2.46 ± 0.03 | 4.1 | 0.6 |

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

| Method | Top-1 Error (%) | Params (M) | Search Cost (GPU days) |
|---|---|---|---|
| DARTS* | 26.7 | 4.7 | 1.0 |
| SNAS* | 27.3 | 4.3 | 1.5 |
| PC-DARTS | 24.2 | 5.3 | 3.8 |
| DSNAS | 25.7 | - | - |
| DrNAS (ours) | 23.7 | 5.7 | 4.6 |

* not a direct search

Usage

Architecture Search

Search on NAS-Bench-201 Space: (3 datasets to choose from)

  • Data preparation: First download the NAS-Bench-201 benchmark file and set up the API as described in the NAS-Bench-201 repository.

  • cd 201-space && python train_search.py

  • With progressive pruning: cd 201-space && python train_search_progressive.py
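
This README does not spell out the pruning rule used by train_search_progressive.py. As a rough illustration only, one plausible reading is that each edge keeps the operations with the largest learned Dirichlet concentration and masks out the rest before the next search stage. The helper names `prune_edge` and `prune_cell`, the `keep` parameter, and the concentration values below are all hypothetical.

```python
def prune_edge(concentration, keep):
    """Return a 0/1 mask that keeps the `keep` operations with the largest concentration."""
    ranked = sorted(range(len(concentration)), key=lambda i: concentration[i], reverse=True)
    kept = set(ranked[:keep])
    return [1 if i in kept else 0 for i in range(len(concentration))]


def prune_cell(edge_concentrations, keep):
    """Apply the same top-k rule independently to every edge of a cell."""
    return [prune_edge(c, keep) for c in edge_concentrations]


# Made-up concentrations for two edges with five candidate operations each.
edges = [
    [0.5, 2.0, 1.2, 0.1, 3.0],
    [1.0, 0.2, 0.9, 4.0, 0.3],
]
masks = prune_cell(edges, keep=3)
for mask in masks:
    print(mask)
```

Masking rather than deleting operations keeps tensor shapes unchanged between stages. Whether the actual script does this is not stated here.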

Search on DARTS Space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS and sample 10% and 2.5% of the images from each class as the training and validation sets, respectively.

  • CIFAR-10: cd DARTS-space && python train_search.py

  • ImageNet: cd DARTS-space && python train_search_imagenet.py
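
The per-class ImageNet subsampling mentioned under data preparation can be sketched as below. This is an assumption-laden illustration, not the PC-DARTS script: the function name `split_per_class`, the fixed seed, the rounding rule, and the choice to make the two subsets disjoint are all assumptions, and the toy file names stand in for real image paths.

```python
import random
from collections import defaultdict


def split_per_class(samples, train_frac=0.10, val_frac=0.025, seed=0):
    """samples: list of (path, label). Returns disjoint (train, val) lists drawn per class."""
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append(path)

    rng = random.Random(seed)
    train, val = [], []
    for label in sorted(by_class):
        paths = sorted(by_class[label])  # sort first so the shuffle is reproducible
        rng.shuffle(paths)
        n_train = round(len(paths) * train_frac)
        n_val = round(len(paths) * val_frac)
        train += [(p, label) for p in paths[:n_train]]
        val += [(p, label) for p in paths[n_train:n_train + n_val]]
    return train, val


# Toy stand-in for a dataset: 3 classes with 1000 hypothetical file names each.
toy = [(f"class{c}/img{i}.jpg", c) for c in range(3) for i in range(1000)]
train, val = split_per_class(toy)
print(len(train), len(val))
```

Sampling within each class keeps the class balance of the full dataset in both subsets.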

Architecture Evaluation

  • CIFAR-10: cd DARTS-space && python train.py --cutout --auxiliary

  • ImageNet: cd DARTS-space && python train_imagenet.py --auxiliary

Reference

If you find this code useful in your research, please cite:

@inproceedings{chen2021drnas,
    title={{DrNAS}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}
