Fine-tune pretrained Convolutional Neural Networks with PyTorch

Overview

Fine-tune pretrained Convolutional Neural Networks with PyTorch.

PyPI CircleCI codecov.io

Features

  • Gives access to the most popular CNN architectures pretrained on ImageNet.
  • Automatically replaces classifier on top of the network, which allows you to train a network with a dataset that has a different number of classes.
  • Allows you to use images with any resolution (and not only the resolution that was used for training the original model on ImageNet).
  • Allows adding a Dropout layer or a custom pooling layer.

Supported architectures and models

From the torchvision package:

  • ResNet (resnet18, resnet34, resnet50, resnet101, resnet152)
  • ResNeXt (resnext50_32x4d, resnext101_32x8d)
  • DenseNet (densenet121, densenet169, densenet201, densenet161)
  • Inception v3 (inception_v3)
  • VGG (vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn)
  • SqueezeNet (squeezenet1_0, squeezenet1_1)
  • MobileNet V2 (mobilenet_v2)
  • ShuffleNet v2 (shufflenet_v2_x0_5, shufflenet_v2_x1_0)
  • AlexNet (alexnet)
  • GoogLeNet (googlenet)

From the Pretrained models for PyTorch package:

  • ResNeXt (resnext101_32x4d, resnext101_64x4d)
  • NASNet-A Large (nasnetalarge)
  • NASNet-A Mobile (nasnetamobile)
  • Inception-ResNet v2 (inceptionresnetv2)
  • Dual Path Networks (dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107)
  • Inception v4 (inception_v4)
  • Xception (xception)
  • Squeeze-and-Excitation Networks (senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d)
  • PNASNet-5-Large (pnasnet5large)
  • PolyNet (polynet)

Requirements

  • Python 3.5+
  • PyTorch 1.1+

Installation

pip install cnn_finetune

Major changes:

Version 0.4

  • Default value for pretrained argument in make_model is changed from False to True. Now call make_model('resnet18', num_classes=10) is equal to make_model('resnet18', num_classes=10, pretrained=True)

Example usage:

Make a model with ImageNet weights for 10 classes

from cnn_finetune import make_model

model = make_model('resnet18', num_classes=10, pretrained=True)

Make a model with Dropout

model = make_model('nasnetalarge', num_classes=10, pretrained=True, dropout_p=0.5)

Make a model with Global Max Pooling instead of Global Average Pooling

import torch.nn as nn

model = make_model('inceptionresnetv2', num_classes=10, pretrained=True, pool=nn.AdaptiveMaxPool2d(1))

Make a VGG16 model that takes images of size 256x256 pixels

VGG and AlexNet models use fully-connected layers, so you have to additionally pass the input size of images when constructing a new model. This information is needed to determine the input size of fully-connected layers.

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256))

Make a VGG16 model that takes images of size 256x256 pixels and uses a custom classifier

import torch.nn as nn

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)

Show preprocessing that was used to train the original model on ImageNet

>> model = make_model('resnext101_64x4d', num_classes=10, pretrained=True)
>> print(model.original_model_info)
ModelInfo(input_space='RGB', input_size=[3, 224, 224], input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
>> print(model.original_model_info.mean)
[0.485, 0.456, 0.406]

CIFAR10 Example

See examples/cifar10.py file (requires PyTorch 1.1+).

Owner
Alex Parinov
Computer Vision Architect
Alex Parinov
Breaching - Breaching privacy in federated learning scenarios for vision and text

Breaching - A Framework for Attacks against Privacy in Federated Learning This P

Jonas Geiping 139 Jan 03, 2023
An Open Source Machine Learning Framework for Everyone

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

170.1k Jan 05, 2023
Implementation of ICLR 2020 paper "Revisiting Self-Training for Neural Sequence Generation"

Self-Training for Neural Sequence Generation This repo includes instructions for running noisy self-training algorithms from the following paper: Revi

Junxian He 45 Dec 31, 2022
This repo includes the supplementary of our paper "CEMENT: Incomplete Multi-View Weak-Label Learning with Long-Tailed Labels"

Supplementary Materials for CEMENT: Incomplete Multi-View Weak-Label Learning with Long-Tailed Labels This repository includes all supplementary mater

Zhiwei Li 0 Jan 05, 2022
A curated list of automated deep learning (including neural architecture search and hyper-parameter optimization) resources.

Awesome AutoDL A curated list of automated deep learning related resources. Inspired by awesome-deep-vision, awesome-adversarial-machine-learning, awe

D-X-Y 2k Dec 30, 2022
BossNAS: Exploring Hybrid CNN-transformers with Block-wisely Self-supervised Neural Architecture Search

BossNAS This repository contains PyTorch evaluation code, retraining code and pretrained models of our paper: BossNAS: Exploring Hybrid CNN-transforme

Changlin Li 127 Dec 26, 2022
Home repository for the Regularized Greedy Forest (RGF) library. It includes original implementation from the paper and multithreaded one written in C++, along with various language-specific wrappers.

Regularized Greedy Forest Regularized Greedy Forest (RGF) is a tree ensemble machine learning method described in this paper. RGF can deliver better r

RGF-team 364 Dec 28, 2022
Transport Mode detection - can detect the mode of transport with the help of features such as acceeration,jerk etc

title emoji colorFrom colorTo sdk app_file pinned Transport_Mode_Detector 🚀 purple yellow gradio app.py false Configuration title: string Display tit

Nishant Rajadhyaksha 3 Jan 16, 2022
Official repository for "Restormer: Efficient Transformer for High-Resolution Image Restoration". SOTA for motion deblurring, image deraining, denoising (Gaussian/real data), and defocus deblurring.

Restormer: Efficient Transformer for High-Resolution Image Restoration Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan,

Syed Waqas Zamir 906 Dec 30, 2022
My solutions for Stanford University course CS224W: Machine Learning with Graphs Fall 2021 colabs (GNN, GAT, GraphSAGE, GCN)

machine-learning-with-graphs My solutions for Stanford University course CS224W: Machine Learning with Graphs Fall 2021 colabs Course materials can be

Marko Njegomir 7 Dec 14, 2022
A tool for making map images from OpenTTD save games

OpenTTD Surveyor A tool for making map images from OpenTTD save games. This is not part of the main OpenTTD codebase, nor is it ever intended to be pa

Aidan Randle-Conde 9 Feb 15, 2022
an implementation of Video Frame Interpolation via Adaptive Separable Convolution using PyTorch

This work has now been superseded by: https://github.com/sniklaus/revisiting-sepconv sepconv-slomo This is a reference implementation of Video Frame I

Simon Niklaus 985 Jan 08, 2023
Pytorch implementation of "Geometrically Adaptive Dictionary Attack on Face Recognition" (WACV 2022)

Geometrically Adaptive Dictionary Attack on Face Recognition This is the Pytorch code of our paper "Geometrically Adaptive Dictionary Attack on Face R

6 Nov 21, 2022
Source code for GNN-LSPE (Graph Neural Networks with Learnable Structural and Positional Representations)

Graph Neural Networks with Learnable Structural and Positional Representations Source code for the paper "Graph Neural Networks with Learnable Structu

Vijay Prakash Dwivedi 180 Dec 22, 2022
A comprehensive and up-to-date developer education platform for Urbit.

curriculum A comprehensive and up-to-date developer education platform for Urbit. This project organizes developer capabilities into a hierarchy of co

Sigilante 36 Oct 04, 2022
YOLO5Face: Why Reinventing a Face Detector (https://arxiv.org/abs/2105.12931)

Introduction Yolov5-face is a real-time,high accuracy face detection. Performance Single Scale Inference on VGA resolution(max side is equal to 640 an

DeepCam Shenzhen 1.4k Jan 07, 2023
The official implementation of NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021]. https://arxiv.org/pdf/2101.12378.pdf

NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021] Release Notes The offical PyTorch implementation of NeMo, p

Angtian Wang 76 Nov 23, 2022
🥇 LG-AI-Challenge 2022 1위 솔루션 입니다.

LG-AI-Challenge-for-Plant-Classification Dacon에서 진행된 농업 환경 변화에 따른 작물 병해 진단 AI 경진대회 에 대한 코드입니다. (colab directory에 코드가 잘 정리 되어있습니다.) Requirements python

siwooyong 10 Jun 30, 2022
MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc.

MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc. ⭐⭐⭐⭐⭐

568 Jan 04, 2023
Modified fork of Xuebin Qin's U-2-Net Repository. Used for demonstration purposes.

U^2-Net (U square net) Modified version of U2Net used for demonstation purposes. Paper: U^2-Net: Going Deeper with Nested U-Structure for Salient Obje

Shreyas Bhat Kera 13 Aug 28, 2022