Pre-trained NFNets with 99% of the accuracy of the official paper

Overview

NFNet PyTorch Implementation

This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper High-Performance Large-Scale Image Recognition Without Normalization. According to the paper, the small models match the accuracy of an EfficientNet-B7 while being up to 8.7 times faster to train, and the large models set a new state-of-the-art top-1 accuracy on ImageNet.

NFNet                                    F0     F1     F2     F3     F4     F5     F6+SAM
Top-1 accuracy (%), Brock et al.         83.6   84.7   85.1   85.7   85.9   86.0   86.5
Top-1 accuracy (%), this implementation  82.82  84.63  84.90  85.46  85.66  85.62  TBD

All credit goes to the authors of the original paper. This repo is heavily inspired by their JAX implementation in the official repository. See their repository for citation information.

Get started

git clone https://github.com/benjs/nfnets_pytorch.git
cd nfnets_pytorch
pip3 install -r requirements.txt

Download pretrained weights from the official repository and place them in the pretrained folder.

from pretrained import pretrained_nfnet
model_F0 = pretrained_nfnet('pretrained/F0_haiku.npz')
model_F1 = pretrained_nfnet('pretrained/F1_haiku.npz')
# ...

The model variant is automatically derived from the parameter count in the pretrained weights file.
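
As a rough illustration of that idea (not the repo's actual lookup code), the sketch below sums the element counts of all weight arrays and picks the variant whose parameter count is closest. The function name guess_variant and the table APPROX_PARAMS_MILLIONS are hypothetical, and the counts in the table are approximate figures assumed from the paper; check them before relying on them.

```python
from math import prod

# ASSUMPTION: approximate parameter counts in millions, as reported in the
# NFNet paper. These are illustrative values, not read from this repo.
APPROX_PARAMS_MILLIONS = {
    "F0": 71.5, "F1": 132.6, "F2": 193.8, "F3": 254.9,
    "F4": 316.1, "F5": 377.2, "F6": 438.4,
}


def guess_variant(shapes, table=APPROX_PARAMS_MILLIONS):
    """Return the variant whose parameter count is closest to the total.

    shapes maps a parameter name to its shape tuple, e.g. {"w": (3, 3, 16, 32)}.
    """
    # prod(()) is 1, so scalar parameters count as one element.
    total_millions = sum(prod(shape) for shape in shapes.values()) / 1e6
    return min(table, key=lambda name: abs(table[name] - total_millions))
```

The shapes mapping would be built from whatever arrays the weights file yields once loaded; how that file is structured is not assumed here.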

Validate yourself

python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset path/to/imagenet/valset/

You can download the ImageNet validation set from the ILSVRC2012 challenge site after asking for access with, for instance, your .edu mail address.
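
For orientation, top-1 accuracy is the percentage of images whose highest-scoring class matches the label. The following is a minimal sketch of that computation, not the contents of eval.py; the function name top1_accuracy and the batches argument (any iterable of (images, labels) pairs, such as a DataLoader) are assumptions made for illustration.

```python
import torch


@torch.no_grad()
def top1_accuracy(model, batches, device="cpu"):
    """Return top-1 accuracy in percent over an iterable of (images, labels)."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in batches:
        logits = model(images.to(device))
        # The predicted class is the index of the largest logit per sample.
        predictions = logits.argmax(dim=1).cpu()
        correct += (predictions == labels.cpu()).sum().item()
        total += labels.numel()
    return 100.0 * correct / total
```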

Scaled weight standardization convolutions in your own model

Simply replace all your nn.Conv2d with WSConv2D and all your nn.ReLU with VPReLU or VPGELU (variance preserving ReLU/GELU).

import torch.nn as nn
from model import WSConv2D, VPReLU, VPGELU

# Simply replace your nn.Conv2d layers
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
 
        self.activation = VPReLU(inplace=True) # or VPGELU
        self.conv0 = WSConv2D(in_channels=128, out_channels=256, kernel_size=1, ...)
        # ...

    def forward(self, x):
        out = self.activation(self.conv0(x))
        # ...

SGD with adaptive gradient clipping in your own model

Simply replace your SGD optimizer with SGD_AGC.

from optim import SGD_AGC

optimizer = SGD_AGC(
        named_params=model.named_parameters(), # Pass named parameters
        lr=1e-3,
        momentum=0.9,
        clipping=0.1, # New clipping parameter
        weight_decay=2e-5, 
        nesterov=True)

It is important to exclude certain layers from clipping or weight decay. The authors recommend excluding the final fully connected (linear) layer from clipping and the bias/gain parameters from weight decay:

import re

for group in optimizer.param_groups:
    name = group['name'] 
    
    # Exclude from weight decay
    if len(re.findall('stem.*(bias|gain)|conv.*(bias|gain)|skip_gain', name)) > 0:
        group['weight_decay'] = 0

    # Exclude from clipping
    if name.startswith('linear'):
        group['clipping'] = None
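
For readers who want to see what the clipping step itself amounts to, the sketch below follows the unit-wise rule described in the paper: a unit's gradient is rescaled whenever its norm exceeds clipping times the norm of the corresponding weights. This is an illustration under stated assumptions, not the code inside SGD_AGC; the function names and the eps default are hypothetical, and a "unit" is taken to be one slice along the first (output) dimension.

```python
import torch


def unitwise_norm(t):
    """L2 norm per output unit; a single norm for scalars and vectors."""
    if t.ndim <= 1:
        return t.pow(2).sum().sqrt()
    dims = tuple(range(1, t.ndim))
    return t.pow(2).sum(dim=dims, keepdim=True).sqrt()


def adaptive_grad_clip_(param, clipping=0.1, eps=1e-3):
    """Clip param.grad in place, unit by unit (hypothetical helper)."""
    if param.grad is None or clipping is None:
        return
    with torch.no_grad():
        # eps keeps zero-initialised weights from freezing their gradients.
        max_norm = unitwise_norm(param).clamp(min=eps) * clipping
        grad_norm = unitwise_norm(param.grad)
        clipped = param.grad * (max_norm / grad_norm.clamp(min=1e-6))
        # Only units whose gradient norm exceeds the threshold are rescaled.
        param.grad.copy_(torch.where(grad_norm > max_norm, clipped, param.grad))
```

Passing clipping=None leaves the gradient untouched, which mirrors how the excluded parameter groups above are marked.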

Train your own NFNet

Adjust your desired parameters in default_config.yaml and start training.

python3 train.py --dataset /path/to/imagenet/

There are still some parts missing for complete training from scratch:

  • Multi-GPU training
  • Data augmentations
  • FP16 activations and gradients

Contribute

The implementation is still at an early stage in terms of usability and testing. If you have an idea to improve this repo, open an issue, start a discussion, or submit a pull request.

Development status

  • Pre-trained NFNet Models
    • F0-F5
    • F6+SAM
    • Scaled weight standardization
    • Squeeze and excite
    • Stochastic depth
    • FP16 activations
  • SGD with unit adaptive gradient clipping (SGD-AGC)
    • Exclude certain layers from weight-decay, clipping
    • FP16 gradients
  • PyPI package
  • PyTorch hub submission
  • Label smoothing loss from Szegedy et al.
  • Training on ImageNet
  • Pre-trained weights
  • Tensorboard support
  • general usability improvements
  • Multi-GPU support
  • Data augmentation
  • Signal propagation plots (from first paper)
Comments
  • ModuleNotFoundError: No module named 'haiku'

    When I run "python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset ***" I get this error. Have you encountered it? How can I fix it?

    opened by Rianusr 2
  • Trained without data augmentation?

    Thanks for the great work on the PyTorch implementation of NFNet! The accuracies achieved by this implementation are impressive, and I am wondering whether these results were obtained with the training script alone, that is, without data augmentation.

    opened by nandi-zhang 2
  • from_pretrained_haiku

    https://github.com/benjs/nfnets_pytorch/blob/7b4d1cc701c7de4ee273ded01ce21cbdb1e60c48/nfnets/pretrained.py#L90

    model = from_pretrained_haiku(args.pretrained)

    Where is the 'from_pretrained_haiku' method defined?

    opened by vkmavani 0
  • About WSconv2d

    Looking at the authors' code, I see that their WSConv2D uses pad_mode 'same'. PyTorch's Conv2d doesn't have a pad_mode, so I would expect your padding to be greater than 0, but your padding is always 0. Why is that?

    Also, in your train.py the learning rate is constant. Why? Thank you!

    opened by fancyshun 3
  • AveragePool

    Hi, I noticed that the AveragePool ('pool' layer) is not used in the forward function. Instead, forward uses torch.mean, so removing the layer doesn't change the pooling behavior. I tried using this model as a feature extractor and was confused for a moment.

    opened by bogdankjastrzebski 1
Releases(v0.0.1)
Owner
Benjamin Schmidt
Engineering Student