Pretrained models for Jax/Haiku: MobileNet, ResNet, VGG, Xception.

Overview

Pre-trained image classification models for Jax/Haiku

Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning.

Available Models

  • MobileNetV1
  • ResNet, ResNetV2
  • VGG16, VGG19
  • Xception

Planned Releases

  • MobileNetV2, MobileNetV3
  • InceptionResNetV2, InceptionV3
  • EfficientNetV1, EfficientNetV2

Installation

Haikumodels requires Python 3.7 or later.

  1. The required libraries can be installed using "installation.txt", as shown below.
  2. If Jax GPU support is desired, Jax must be installed separately according to your system's requirements.
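
Assuming "installation.txt" is a standard pip requirements file, the install command would be:

pip install -r installation.txt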

Usage examples for image classification models

Classify ImageNet classes with ResNet50

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)


def _model(images, is_training):
  net = hm.ResNet50()
  return net(images, is_training)


model = hk.transform_with_state(_model)

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.resnet.preprocess_input(x)

params, state = model.init(rng, x, is_training=True)

preds, _ = model.apply(params, state, None, x, is_training=False)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print("Predicted:", hm.decode_predictions(preds, top=3)[0])
# Predicted:
# [('n02504013', 'Indian_elephant', 0.8784022),
# ('n01871265', 'tusker', 0.09620289),
# ('n02504458', 'African_elephant', 0.025362419)]
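
For repeated inference it can help to jit-compile the apply function. A minimal sketch (`is_training` is marked static because it switches Python-level behavior such as the batchnorm code path):

# jit-compile apply for faster repeated inference;
# `is_training` must be static since it selects the batchnorm code path
fast_apply = jax.jit(model.apply, static_argnames="is_training")
preds, _ = fast_apply(params, state, None, x, is_training=False)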

Extract features with VGG16

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)

# Haiku modules must be constructed inside a transformed function;
# VGG16 has no batchnorm, so `transform` + `without_apply_rng` suffices
def _model(images):
  net = hm.VGG16(include_top=False)
  return net(images)


model = hk.without_apply_rng(hk.transform(_model))

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.vgg.preprocess_input(x)

params = model.init(rng, x)

features = model.apply(params, x)
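
With include_top=False the output is a 4D feature map of shape (batch, height, width, channels). A common follow-up, sketched here, is global average pooling to obtain one feature vector per image:

# global-average-pool the spatial dimensions to get a
# single feature vector per image
feature_vectors = jnp.mean(features, axis=(1, 2))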

Fine-tune Xception on a new set of classes

from typing import Any, Callable, NamedTuple, Optional, Sequence

import optax
import haiku as hk
import jax
import jax.numpy as jnp

import haikumodels as hm

rng = jax.random.PRNGKey(42)


class Freezable_TrainState(NamedTuple):
  trainable_params: hk.Params
  non_trainable_params: hk.Params
  state: hk.State
  opt_state: optax.OptState


# create your custom top layers and include the desired pretrained model
class ft_xception(hk.Module):

  def __init__(
      self,
      classes: int,
      classifier_activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.softmax,
      with_bias: bool = True,
      w_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      b_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.classifier_activation = classifier_activation

    self.xception_no_top = hm.Xception(include_top=False)
    self.dense_layer = hk.Linear(
        output_size=1024,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_dense_layer",
    )
    self.top_layer = hk.Linear(
        output_size=classes,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_top_layer",
    )

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    out = self.xception_no_top(inputs, is_training)
    # global average pooling over the spatial dimensions
    out = jnp.mean(out, axis=(1, 2))
    out = self.dense_layer(out)
    out = jax.nn.relu(out)
    out = self.top_layer(out)
    out = self.classifier_activation(out)
    return out


# use `transform_with_state` if the model has batchnorm in it,
# otherwise use `transform` and then `without_apply_rng`
def _model(images, is_training):
  net = ft_xception(classes=200)
  return net(images, is_training)


model = hk.transform_with_state(_model)

# create your desired optimizer using Optax or alternatives
opt = optax.rmsprop(learning_rate=1e-4, momentum=0.90)


# this function will initialize params and state
# use the desired keyword to divide params into trainable and non_trainable
def initial_state(x_y, nonfreeze_key="trainable"):
  x, _ = x_y
  params, state = model.init(rng, x, is_training=True)

  # modules whose names contain `nonfreeze_key` (here the custom
  # "trainable_..." layers defined above) become the trainable set
  trainable_params, non_trainable_params = hk.data_structures.partition(
      lambda m, n, p: nonfreeze_key in m, params)

  # only the trainable params receive updates, so the optimizer
  # state is initialized for them alone
  opt_state = opt.init(trainable_params)

  return Freezable_TrainState(trainable_params, non_trainable_params, state,
                              opt_state)


# `gen_x_y` is assumed to be your data generator, yielding
# (images, one_hot_labels) batches
train_state = initial_state(next(gen_x_y))


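# `categorical_crossentropy` used below is not provided by haikumodels;
# this is a minimal example implementation for one-hot labels and
# softmax outputs (swap in your own loss as desired)
def categorical_crossentropy(y_true, y_pred):
  return jnp.mean(-jnp.sum(y_true * jnp.log(y_pred + 1e-7), axis=-1))

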
# create your own custom loss function as desired
def loss_function(trainable_params, non_trainable_params, state, x_y):
  x, y = x_y
  params = hk.data_structures.merge(trainable_params, non_trainable_params)
  y_, state = model.apply(params, state, None, x, is_training=True)

  cce = categorical_crossentropy(y, y_)

  return cce, state


# to update the params and optimizer state, a train_step function must be created
@jax.jit
def train_step(train_state: Freezable_TrainState, x_y):
  trainable_params, non_trainable_params, state, opt_state = train_state
  trainable_params_grads, _ = jax.grad(loss_function,
                                       has_aux=True)(trainable_params,
                                                     non_trainable_params,
                                                     state, x_y)

  updates, new_opt_state = opt.update(trainable_params_grads, opt_state)
  new_trainable_params = optax.apply_updates(trainable_params, updates)

  train_state = Freezable_TrainState(new_trainable_params, non_trainable_params,
                                     state, new_opt_state)
  return train_state


# train the model on the new data by calling `train_step`
# repeatedly (e.g. inside your epoch/batch loop)
train_state = train_step(train_state, next(gen_x_y))

# after training is complete, it is possible to merge the
# trainable and non_trainable params to use for prediction
trainable_params, non_trainable_params, state, _ = train_state
params = hk.data_structures.merge(trainable_params, non_trainable_params)
# (`x` here is any batch of preprocessed images to predict on)
preds, _ = model.apply(params, state, None, x, is_training=False)
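
After merging, you may want to persist the trained params and state for later use. A minimal sketch, assuming Python's built-in pickle is an acceptable format (haikumodels does not prescribe one):

# move arrays to host memory and pickle params and state together
import pickle

with open("ft_xception_params.pkl", "wb") as f:
  pickle.dump(jax.device_get((params, state)), f)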

Comments
  • Expected top-1 test accuracy

    Hi,

    This is a fantastic project! The released checkpoints are super helpful!

    I am wondering what top-1 test accuracy one should expect using the released ResNet-50 checkpoints. I am able to reach 0.749 using my own ImageNet dataloader implemented via TensorFlow Datasets. Is that number close to your results?

    BTW, it would also be very helpful if you could release your training and dataloading code for these models!

    Thanks,

    opened by xidulu 2
  • Fitting issue

    I was trying to use a few of your pre-trained models, in particular ResNet50 and VGG16 for feature extraction, but unfortunately I didn't manage to fit them on an Nvidia Titan X with 12 GB of VRAM. Which GPU did you use for training, and how much VRAM do I need to use these models?

    For the VGG16 the system was asking for 4 more GB, and for the ResNet50 about 20 more.

    Thanks.

    opened by mattiadutto 1
Owner

Alper Baris CELIK