Unofficial PyTorch implementation of MobileViT based on paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer".

Overview

MobileViT

RegNet

Unofficial PyTorch implementation of MobileViT based on paper MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER.


Table of Contents


Model Architecture

Trulli

MobileViT Architecture

Usage

Training

python main.py
optional arguments:
  -h, --help            show this help message and exit
  --gpu_device GPU_DEVICE
                        Select specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 20)
  --num-class N         Number of classes to classify (default: 10)
  --lr LR               Learning rate (default: 0.01)
  --weight-decay WD     Weight decay (default: 1e-5)
  --model-path PATH     Path to save the model

Citation

@InProceedings{Sachin2021,
  title = {MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER},
  author = {Sachin Mehta and Mohammad Rastegari},
  booktitle = {},
  year = {2021}
}

If this implement have any problem please let me know, thank you.

Comments
  • Training settings

    Training settings

    I really appreciate your efforts in implementing this model in pytorch. Here, I have one concern about the training settings. If what I understand is correct, you just trained the model for less than 5 epoches.

    In addition, the hyper-parameters you adopted is different from that in the original article. For instance, in the original manuscript, authors train mobilevit using AdamW optimizer, label smoothing cross-entry and multi-scale sampler. The training phase has a warmup stage.

    I also found that the classificaion accuracy provided here is much lower than that in the original version.

    I conjecture that the gab between accuracies are caused by different training settings.

    opened by hkzhang91 6
  • load pretrain weight failed

    load pretrain weight failed

    import torch
    import models
    
    model = models.MobileViT_S()
    PATH = "./MobileVit-S.pth.tar"
    weights = torch.load(PATH, map_location=lambda storage, loc: storage)
    model.load_state_dict(weights['state_dict'])
    model.eval()
    torch.save(model, './model.pt')
    
    • I try to load the pre-train weight to test one demo; but the network structure does not seem to match the weights, is there any solution?

    image

    opened by hererookie 2
  • model training hyperparameter

    model training hyperparameter

    A problem has been bothering me. the learning rate, optimizer, batch_size, L2 regularization, label smoothing and epochs are inconsistent with the paper. How should I modify the code?

    opened by Agino-ltp 1
  • Have you test MobileVit on cifar-10?

    Have you test MobileVit on cifar-10?

    Thanks for your wonderful work!

    I prepare to try MobileVit on small dataset, such as MNIST, and I need adjust the network structure. Before this work, I want to know if MobileVit has a better performance than other networks on small dataset.

    I notice "get_cifar10_dataset" in utils.py. Have you tested MobileVit on cifar-10? If you have, could you please show me the accuracy and inference time result?

    opened by Jerryme-xxm 1
  • Issues when loading MobileViT_S()

    Issues when loading MobileViT_S()

    I wanted to load the MobileViT_S() model and use the pre-trained weights, but I have got some errors in my code. To make it easier and help others, I will share my solution (in case there will be someone who is beginner like me):

    def load_mobilevit_weights(model_path):
      # Create an instance of the MobileViT model
      net = MobileViT_S()
      
      # Load the PyTorch state_dict
      state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
      
      # Since there is a problem in the names of layers, we will change the keys to meet the MobileViT model architecture
      for key in list(state_dict.keys()):
        state_dict[key.replace('module.', '')] = state_dict.pop(key)
      
      # Once the keys are fixed, we can modify the parameters of MobileViT
      net.load_state_dict(state_dict)
      
      return net
    
    net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
    
    opened by Sehaba95 4
Releases(weight)
Owner
Hong-Jia Chen
Master student at National Chung Cheng University, Taiwan. Interested in Deep Learning and Computer Vision.
Hong-Jia Chen
A hand tracking demo made with mediapipe where you can control lights with pinching your fingers and moving your hand up/down.

HandTrackingBrightnessControl A hand tracking demo made with mediapipe where you can control lights with pinching your fingers and moving your hand up

Teemu Laurila 19 Feb 12, 2022
A machine learning malware analysis framework for Android apps.

🕵️ A machine learning malware analysis framework for Android apps. ☢️ DroidDetective is a Python tool for analysing Android applications (APKs) for p

James Stevenson 77 Dec 27, 2022
🎃 Core identification module of AI powerful point reading system platform.

ppReader-Kernel Intro Core identification module of AI powerful point reading system platform. Usage 硬件: Windows10、GPU:nvdia GTX 1060 、普通RBG相机 软件: con

CrashKing 1 Jan 11, 2022
Assessing syntactic abilities of BERT

BERT-Syntax Assesing the syntactic abilities of BERT. What Evaluate Google's BERT-Base and BERT-Large models on the syntactic agreement datasets from

Yoav Goldberg 147 Aug 02, 2022
The ARCA23K baseline system

ARCA23K Baseline System This is the source code for the baseline system associated with the ARCA23K dataset. Details about ARCA23K and the baseline sy

4 Jul 02, 2022
PyTorch implementations of algorithms for density estimation

pytorch-flows A PyTorch implementations of Masked Autoregressive Flow and some other invertible transformations from Glow: Generative Flow with Invert

Ilya Kostrikov 546 Dec 05, 2022
ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... 모델의 개념이해를 돕기 위한 구현물로 현재 변수명을 상세히 적었고

BG Kim 3 Oct 06, 2022
CN24 is a complete semantic segmentation framework using fully convolutional networks

Build status: master (production branch): develop (development branch): Welcome to the CN24 GitHub repository! CN24 is a complete semantic segmentatio

Computer Vision Group Jena 123 Jul 14, 2022
[ICLR 2021] "CPT: Efficient Deep Neural Network Training via Cyclic Precision" by Yonggan Fu, Han Guo, Meng Li, Xin Yang, Yining Ding, Vikas Chandra, Yingyan Lin

CPT: Efficient Deep Neural Network Training via Cyclic Precision Yonggan Fu, Han Guo, Meng Li, Xin Yang, Yining Ding, Vikas Chandra, Yingyan Lin Accep

26 Oct 25, 2022
🛰️ List of earth observation companies and job sites

Earth Observation Companies & Jobs source Portals & Jobs Geospatial Geospatial jobs newsletter: ~biweekly newsletter with geospatial jobs by Ali Ahmad

Dahn 64 Dec 27, 2022
[CVPR'21] Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration

Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration This repository contains the implementation of our paper Locally Aware Pi

sfwang 70 Dec 19, 2022
Cockpit is a visual and statistical debugger specifically designed for deep learning.

Cockpit: A Practical Debugging Tool for Training Deep Neural Networks

Felix Dangel 421 Dec 29, 2022
The first public PyTorch implementation of Attentive Recurrent Comparators

arc-pytorch PyTorch implementation of Attentive Recurrent Comparators by Shyam et al. A blog explaining Attentive Recurrent Comparators Visualizing At

Sanyam Agarwal 150 Oct 14, 2022
Unofficial pytorch implementation of 'Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization'

pytorch-AdaIN This is an unofficial pytorch implementation of a paper, Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Hua

Naoto Inoue 873 Jan 06, 2023
Jetson Nano-based smart camera system that measures crowd face mask usage in real-time.

MaskCam MaskCam is a prototype reference design for a Jetson Nano-based smart camera system that measures crowd face mask usage in real-time, with all

BDTI 212 Dec 29, 2022
This is an implementation of Googles Yogi-Optimizer in Keras (tf.keras)

Yogi-Optimizer_Keras This is an implementation of Googles Yogi-Optimizer in Keras (tf.keras) The NeurIPS-Paper can be found here: http://papers.nips.c

14 Sep 13, 2022
TensorLight - A high-level framework for TensorFlow

TensorLight is a high-level framework for TensorFlow-based machine intelligence applications. It reduces boilerplate code and enables advanced feature

Benjamin Kan 10 Jul 31, 2022
I3-master-layout - Simple master and stack layout script

Simple master and stack layout script | ------ | ----- | | | | | Ma

Tobias S 18 Dec 05, 2022
This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis

This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis Install the package in the requirements.txt, the

108 Dec 23, 2022
vit for few-shot classification

Few-Shot ViT Requirements PyTorch (= 1.9) TorchVision timm (latest) einops tqdm numpy scikit-learn scipy argparse tensorboardx Pretrained Checkpoints

Martin Dong 26 Nov 30, 2022