Global Filter Networks for Image Classification

Overview

Global Filter Networks for Image Classification

Created by Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, Jie Zhou

This repository contains PyTorch implementation for GFNet.

Global Filter Networks is a transformer-style architecture that learns long-term spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform.

intro

Our code is based on pytorch-image-models and DeiT.

[Project Page] [arXiv]

Global Filter Layer

GFNet is a conceptually simple yet computationally efficient architecture, which consists of several stacking Global Filter Layers and Feedforward Networks (FFN). The Global Filter Layer mixes tokens with log-linear complexity benefiting from the highly efficient Fast Fourier Transform (FFT) algorithm. The layer is easy to implement:

import torch
import torch.nn as nn
import torch.fft

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x

Compared to self-attention and spatial MLP, our Global Filter Layer is much more efficient to process high-resolution feature maps:

efficiency

Model Zoo

We provide our GFNet models pretrained on ImageNet:

name arch Params FLOPs [email protected] [email protected] url
GFNet-Ti gfnet-ti 7M 1.3G 74.6 92.2 Tsinghua Cloud / Google Drive
GFNet-XS gfnet-xs 16M 2.8G 78.6 94.2 Tsinghua Cloud / Google Drive
GFNet-S gfnet-s 25M 4.5G 80.0 94.9 Tsinghua Cloud / Google Drive
GFNet-B gfnet-b 43M 7.9G 80.7 95.1 Tsinghua Cloud / Google Drive
GFNet-H-Ti gfnet-h-ti 15M 2.0G 80.1 95.1 Tsinghua Cloud / Google Drive
GFNet-H-S gfnet-h-s 32M 4.5G 81.5 95.6 Tsinghua Cloud / Google Drive
GFNet-H-B gfnet-h-b 54M 8.4G 82.9 96.2 Tsinghua Cloud / Google Drive

Usage

Requirements

  • torch>=1.8.1
  • torchvision
  • timm

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Evaluation

To evaluate a pre-trained GFNet model on the ImageNet validation set with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --path /path/to/model

Training

ImageNet

To train GFNet models on ImageNet from scratch, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs --arch gfnet-xs --batch-size 128 --data-path /path/to/ILSVRC2012/

To finetune a pre-trained model at higher resolution, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs-img384 --arch gfnet-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model

Transfer Learning Datasets

To finetune a pre-trained model on a transfer learning dataset, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet_transfer.py  --output_dir logs/gfnet-xs-cars --arch gfnet-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --dist-eval --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model 

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021global,
  title={Global Filter Networks for Image Classification},
  author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
  journal={arXiv preprint arXiv:2107.00645},
  year={2021}
}
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval (NeurIPS'21)

Baleen Baleen is a state-of-the-art model for multi-hop reasoning, enabling scalable multi-hop search over massive collections for knowledge-intensive

Stanford Future Data Systems 22 Dec 05, 2022
The official implementation of Theme Transformer

Theme Transformer This is the official implementation of Theme Transformer. Checkout our demo and paper : Demo | arXiv Environment: using python versi

Ian Shih 85 Dec 08, 2022
Unofficial PyTorch implementation of TokenLearner by Google AI

tokenlearner-pytorch Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf) Installation You can install TokenLear

Rishabh Anand 46 Dec 20, 2022
A Pythonic library for Nvidia Codec.

A Pythonic library for Nvidia Codec. The project is still in active development; expect breaking changes. Why another Python library for Nvidia Codec?

Zesen Qian 12 Dec 27, 2022
Python implementation of Project Fluent

Project Fluent This is a collection of Python packages to use the Fluent localization system. python-fluent consists of these packages: fluent.syntax

Project Fluent 155 Dec 28, 2022
PyTorch implementation of UPFlow (unsupervised optical flow learning)

UPFlow: Upsampling Pyramid for Unsupervised Optical Flow Learning By Kunming Luo, Chuan Wang, Shuaicheng Liu, Haoqiang Fan, Jue Wang, Jian Sun Megvii

kunming luo 87 Dec 20, 2022
Pose Transformers: Human Motion Prediction with Non-Autoregressive Transformers

Pose Transformers: Human Motion Prediction with Non-Autoregressive Transformers This is the repo used for human motion prediction with non-autoregress

Idiap Research Institute 26 Dec 14, 2022
Image-to-Image Translation with Conditional Adversarial Networks (Pix2pix) implementation in keras

pix2pix-keras Pix2pix implementation in keras. Original paper: Image-to-Image Translation with Conditional Adversarial Networks (pix2pix) Paper Author

William Falcon 141 Dec 30, 2022
x-transformers-paddle 2.x version

x-transformers-paddle x-transformers-paddle 2.x version paddle 2.x版本 https://github.com/lucidrains/x-transformers 。 requirements paddlepaddle-gpu==2.2

yujun 7 Dec 08, 2022
Cours d'Algorithmique Appliquée avec Python pour BTS SIO SISR

Course: Introduction to Applied Algorithms with Python (in French) This is the source code of the website for the Applied Algorithms with Python cours

Loic Yvonnet 0 Jan 27, 2022
Ratatoskr: Worcester Tech's conference scheduling system

Ratatoskr: Worcester Tech's conference scheduling system In Norse mythology, Ratatoskr is a squirrel who runs up and down the world tree Yggdrasil to

4 Dec 22, 2022
Ultra-lightweight human body posture key point CNN model. ModelSize:2.3MB HUAWEI P40 NCNN benchmark: 6ms/img,

Ultralight-SimplePose Support NCNN mobile terminal deployment Based on MXNET(=1.5.1) GLUON(=0.7.0) framework Top-down strategy: The input image is t

223 Dec 27, 2022
A high performance implementation of HDBSCAN clustering.

HDBSCAN HDBSCAN - Hierarchical Density-Based Spatial Clustering of Applications with Noise. Performs DBSCAN over varying epsilon values and integrates

2.3k Jan 02, 2023
A user-friendly research and development tool built to standardize RL competency assessment for custom agents and environments.

Built with ❤️ by Sam Showalter Contents Overview Installation Dependencies Usage Scripts Standard Execution Environment Development Environment Benchm

SRI-AIC 1 Nov 18, 2021
Background Matting: The World is Your Green Screen

Background Matting: The World is Your Green Screen By Soumyadip Sengupta, Vivek Jayaram, Brian Curless, Steve Seitz, and Ira Kemelmacher-Shlizerman Th

Soumyadip Sengupta 4.6k Jan 04, 2023
Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

coax is built on top of JAX, but it doesn't have an explicit dependence on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version.

128 Dec 27, 2022
Human motion synthesis using Unity3D

Human motion synthesis using Unity3D Prerequisite: Software: amc2bvh.exe, Unity 2017, Blender. Unity: RockVR (Video Capture), scenes, character models

Hao Xu 9 Jun 01, 2022
Data visualization app for H&M competition in kaggle

handm_data_visualize_app Data visualization app by streamlit for H&M competition in kaggle. competition page: https://www.kaggle.com/competitions/h-an

Kyohei Uto 12 Apr 30, 2022