Official PyTorch implementation of PS-KD

Overview

Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD)

Accepted at ICCV 2021, oral presentation

  • Official PyTorch implementation of Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD).
    [Slides] [Paper] [Video]
  • Kyungyul Kim, ByeongMoon Ji, Doyoung Yoon and Sangheum Hwang

Abstract

The generalization capability of deep neural networks has been substantially improved by applying a wide spectrum of regularization methods, e.g., restricting function space, injecting randomness during training, augmenting data, etc. In this work, we propose a simple yet effective regularization method named progressive self-knowledge distillation (PS-KD), which progressively distills a model's own knowledge to soften hard targets (i.e., one-hot vectors) during training. Hence, it can be interpreted within a framework of knowledge distillation as a student becomes a teacher itself. Specifically, targets are adjusted adaptively by combining the ground-truth and past predictions from the model itself. Please refer to the paper for more details.
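
To make the target refinement concrete, here is a minimal PyTorch sketch of the idea described above. It is illustrative only (not the repository's implementation); the function names are made up here, and the linear schedule alpha_t = alpha_T * (t / T) follows the description in the paper, where alpha_T corresponds to the --alpha_T flag used below.

import torch
import torch.nn.functional as F

def pskd_soft_targets(hard_targets, past_probs, epoch, total_epochs, alpha_T=0.8):
    # hard_targets: (N, C) one-hot ground-truth vectors
    # past_probs:   (N, C) softmax outputs saved for these samples at the previous epoch
    # alpha_T:      mixing weight reached at the final epoch (the --alpha_T flag)
    alpha_t = alpha_T * (epoch / total_epochs)   # alpha grows linearly over training
    return alpha_t * past_probs + (1.0 - alpha_t) * hard_targets

def soft_target_cross_entropy(logits, soft_targets):
    # Cross-entropy against soft targets instead of hard labels
    return torch.mean(torch.sum(-soft_targets * F.log_softmax(logits, dim=1), dim=1))

# Example: targets = pskd_soft_targets(F.one_hot(labels, 100).float(), prev_probs, epoch, 300)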

Requirements

We have tested the code in the following environment:

  • Python 3.7.7 / PyTorch (>=1.6.0) / torchvision (>=0.7.0)

Datasets

Currently, only the CIFAR-100 and ImageNet datasets are supported.

To verify the effectiveness of PS-KD on object detection and machine translation tasks, we used:

  • For object detection: Pascal VOC
  • For machine translation: IWSLT 15 English-German / German-English, Multi30k.
  • (Please refer to the paper for more details)

How to Run

Single-node & Multi-GPU Training

To train a model on a single node with multiple GPUs, run a command such as:

$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<set your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 1 \
                  --multiprocessing_distributed True

Multi-node Training

To train a single model with 2 nodes, for instance, run the commands below in sequence:

# On node #0
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<set your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed
# On node #1
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<set your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 1 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed
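
The rank, world_size, and dist_url flags follow PyTorch's standard torch.distributed launch convention. As a hedged illustration (this is not the repository's actual main.py, and names such as worker and num_nodes are placeholders), a multi-node launch is typically wired roughly like this, with one process spawned per GPU:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(local_rank, node_rank, gpus_per_node, world_size, dist_url):
    # Global rank of this process: node index * GPUs per node + local GPU index.
    global_rank = node_rank * gpus_per_node + local_rank
    dist.init_process_group(backend='nccl', init_method=dist_url,
                            world_size=world_size, rank=global_rank)
    torch.cuda.set_device(local_rank)
    model = torch.nn.Linear(8, 8).cuda(local_rank)  # placeholder model
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    # ... build DataLoaders with DistributedSampler and run the training loop ...

if __name__ == '__main__':
    gpus_per_node = torch.cuda.device_count()
    num_nodes, node_rank = 2, 0                      # matches --world_size 2 / --rank 0 above
    dist_url = 'tcp://{master_ip}:{master_port}'     # replace with your master node's address
    mp.spawn(worker, nprocs=gpus_per_node,
             args=(node_rank, gpus_per_node, num_nodes * gpus_per_node, dist_url))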

Saving & Loading Checkpoints

Saved Filenames

  • save_dir is determined automatically (with sequential number suffixes) unless otherwise specified.
  • Model checkpoints are saved to ./{experiments_dir}/models/checkpoint_{epoch}.pth.
  • The best checkpoint is saved to ./{experiments_dir}/models/checkpoint_best.pth.

Loading Checkpoints (resume)

  • Pass the checkpoint path via the --resume argument (see the loading sketch below).
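
For reference, here is a minimal sketch of what resuming from a checkpoint usually looks like. The key names ('model', 'optimizer', 'epoch') and the paths are assumptions for illustration and may differ from what main.py actually saves.

import torch

def load_checkpoint(path, model, optimizer):
    # Key names such as 'model', 'optimizer', and 'epoch' are assumed here;
    # adjust them to whatever checkpoint_{epoch}.pth actually contains.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint.get('epoch', 0) + 1  # epoch to resume from

# Usage (paths are illustrative):
# start_epoch = load_checkpoint('./experiments/models/checkpoint_best.pth', model, optimizer)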

Experimental Results

Performance measures

  • Top-1 Error / Top-5 Error
  • Negative Log Likelihood (NLL)
  • Expected Calibration Error (ECE); see the sketch after this list
  • Area Under the Risk-coverage Curve (AURC)
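
As a reference point for the calibration numbers below, here is a compact, self-contained sketch of one common way ECE is computed (equal-width confidence bins, reported as a percentage). It is illustrative only and not taken from this repository.

import torch
import torch.nn.functional as F

def expected_calibration_error(logits, labels, n_bins=15):
    probs = F.softmax(logits, dim=1)
    confidences, predictions = probs.max(dim=1)
    accuracies = predictions.eq(labels).float()

    ece = torch.zeros(1)
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        prop = in_bin.float().mean()
        if prop > 0:
            # |accuracy - confidence| in this bin, weighted by the fraction of samples in it
            ece += torch.abs(accuracies[in_bin].mean() - confidences[in_bin].mean()) * prop
    return 100.0 * ece.item()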

Results on CIFAR-100

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|---|---|---|---|---|---|---|
| PreAct ResNet-18 (baseline) | CIFAR-100 | 24.18 | 6.90 | 1.10 | 11.84 | 67.65 |
| PreAct ResNet-18 + Label Smoothing | CIFAR-100 | 20.94 | 6.02 | 0.98 | 10.79 | 57.74 |
| PreAct ResNet-18 + CS-KD [CVPR'20] | CIFAR-100 | 21.30 | 5.70 | 0.88 | 6.24 | 56.56 |
| PreAct ResNet-18 + TF-KD [CVPR'20] | CIFAR-100 | 22.88 | 6.01 | 1.05 | 11.96 | 61.77 |
| PreAct ResNet-18 + PS-KD | CIFAR-100 | 20.82 | 5.10 | 0.76 | 1.77 | 52.10 |
| PreAct ResNet-101 (baseline) | CIFAR-100 | 20.75 | 5.28 | 0.89 | 10.02 | 55.45 |
| PreAct ResNet-101 + Label Smoothing | CIFAR-100 | 19.84 | 5.07 | 0.93 | 3.43 | 95.76 |
| PreAct ResNet-101 + CS-KD [CVPR'20] | CIFAR-100 | 20.76 | 5.62 | 1.02 | 12.18 | 64.44 |
| PreAct ResNet-101 + TF-KD [CVPR'20] | CIFAR-100 | 20.13 | 5.10 | 0.84 | 6.14 | 58.8 |
| PreAct ResNet-101 + PS-KD | CIFAR-100 | 19.43 | 4.30 | 0.74 | 6.92 | 49.01 |
| DenseNet-121 (baseline) | CIFAR-100 | 20.05 | 4.99 | 0.82 | 7.34 | 52.21 |
| DenseNet-121 + Label Smoothing | CIFAR-100 | 19.80 | 5.46 | 0.92 | 3.76 | 91.06 |
| DenseNet-121 + CS-KD [CVPR'20] | CIFAR-100 | 20.47 | 6.21 | 1.07 | 13.80 | 73.37 |
| DenseNet-121 + TF-KD [CVPR'20] | CIFAR-100 | 19.88 | 5.10 | 0.85 | 7.33 | 69.23 |
| DenseNet-121 + PS-KD | CIFAR-100 | 18.73 | 3.90 | 0.69 | 3.71 | 45.55 |
| ResNeXt-29 (baseline) | CIFAR-100 | 18.65 | 4.47 | 0.74 | 4.17 | 44.27 |
| ResNeXt-29 + Label Smoothing | CIFAR-100 | 17.60 | 4.23 | 1.05 | 22.14 | 41.92 |
| ResNeXt-29 + CS-KD [CVPR'20] | CIFAR-100 | 18.26 | 4.37 | 0.80 | 5.95 | 42.11 |
| ResNeXt-29 + TF-KD [CVPR'20] | CIFAR-100 | 17.33 | 3.87 | 0.74 | 6.73 | 40.34 |
| ResNeXt-29 + PS-KD | CIFAR-100 | 17.28 | 3.60 | 0.72 | 9.18 | 40.19 |
| PyramidNet-200 (baseline) | CIFAR-100 | 16.80 | 3.69 | 0.73 | 8.04 | 36.95 |
| PyramidNet-200 + Label Smoothing | CIFAR-100 | 17.82 | 4.72 | 0.89 | 3.46 | 105.02 |
| PyramidNet-200 + CS-KD [CVPR'20] | CIFAR-100 | 18.31 | 5.70 | 1.17 | 14.70 | 70.05 |
| PyramidNet-200 + TF-KD [CVPR'20] | CIFAR-100 | 16.48 | 3.37 | 0.79 | 10.48 | 37.04 |
| PyramidNet-200 + PS-KD | CIFAR-100 | 15.49 | 3.08 | 0.56 | 1.83 | 32.14 |

Results on ImageNet

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|---|---|---|---|---|---|---|
| DenseNet-264* | ImageNet | 22.15 | 6.12 | -- | -- | -- |
| ResNet-152 | ImageNet | 22.19 | 6.19 | 0.88 | 3.84 | 61.79 |
| ResNet-152 + Label Smoothing | ImageNet | 21.73 | 5.85 | 0.92 | 3.91 | 68.24 |
| ResNet-152 + CS-KD [CVPR'20] | ImageNet | 21.61 | 5.92 | 0.90 | 5.79 | 62.12 |
| ResNet-152 + TF-KD [CVPR'20] | ImageNet | 22.76 | 6.43 | 0.91 | 4.70 | 65.28 |
| ResNet-152 + PS-KD | ImageNet | 21.41 | 5.86 | 0.84 | 2.51 | 61.01 |

* denotes results reported in the original papers

Citation

If you find this repository useful, please consider giving it a star and citing PS-KD:

@InProceedings{Kim_2021_ICCV,
    author    = {Kim, Kyungyul and Ji, ByeongMoon and Yoon, Doyoung and Hwang, Sangheum},
    title     = {Self-Knowledge Distillation With Progressive Refinement of Targets},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {6567-6576}
}

Contact for Issues

License

Copyright (c) 2021-present LG CNS Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.