OCRA (Object-Centric Recurrent Attention) source code


Hossein Adeli and Seoyoung Ahn

Please cite this article if you find this repository useful:


  • For data generation and loading

    1. stimuli_util.ipynb includes all the code and instructions for generating the datasets for the three tasks: MultiMNIST, MultiMNIST Cluttered and MultiSVHN.
    2. If the data files are not in the default location, update loaddata.py with their paths.
  • For training and testing the model:

    1. OCRA_demo.ipynb includes the code for building and training the model. A hyperparameter file must be specified in the first notebook cell. Parameter files are provided here (the different settings are discussed in the supplementary file).

    2. multimnist_params_10glimpse.txt and multimnist_params_3glimpse.txt set all the hyperparameters for the MultiMNIST task with 10 and 3 glimpses, respectively.

    OCRA_demo-MultiMNIST_3glimpse_training.ipynb shows how to load a parameter file and train the model.

    3. multimnist_cluttered_params_7glimpse.txt and multimnist_cluttered_params_5glimpse.txt set all the hyperparameters for the MultiMNIST Cluttered task with 7 and 5 glimpses, respectively.

    4. multisvhn_params.txt sets all the hyperparameters for the MultiSVHN task with 12 glimpses.

    5. OCRA_demo.ipynb also includes code for testing a trained model and for plotting the attention windows for sample images.

    OCRA_demo-cluttered_5steps_loadtrained.ipynb shows how to load a trained model and test it on the test dataset. Example pretrained models are included in the repository under the pretrained folder. Download all the pretrained models.
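
For plotting attention windows, one simple option is to draw the square spanned by a glimpse's grid of filter centers. The sketch below assumes a DRAW-style parameterization (grid center gx, gy in pixels, stride delta, and an n x n grid of filters); the function and parameter names are hypothetical and are not taken from the OCRA notebooks.

```python
from typing import Tuple


def attention_window(gx: float, gy: float, delta: float, n: int) -> Tuple[float, float, float, float]:
    """Return (x, y, width, height) of the square spanned by an n x n grid of filter centers.

    ASSUMPTION: DRAW-style parameters, i.e. grid center (gx, gy) in pixels and stride delta.
    """
    # Outermost filter centers sit (n - 1) / 2 strides away from the grid center.
    half_span = (n - 1) / 2.0 * delta
    return (gx - half_span, gy - half_span, 2 * half_span, 2 * half_span)


x, y, w, h = attention_window(gx=50.0, gy=40.0, delta=2.0, n=5)
print(x, y, w, h)  # 46.0 36.0 8.0 8.0
```

The returned tuple matches the (xy, width, height) arguments of matplotlib.patches.Rectangle, e.g. Rectangle((x, y), w, h, fill=False), so it can be overlaid directly on a sample image.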

Image-level error rate (%) averaged over 5 runs

Task (Model name)                        Error rate, % (SD)
MultiMNIST (OCRA-10glimpse)              5.08 (0.17)
Cluttered MultiMNIST (OCRA-7glimpse)     7.12 (1.05)
MultiSVHN (OCRA-12glimpse)               10.07 (0.53)

Validation losses during training

From MultiMNIST OCRA-10glimpse:

From Cluttered MultiMNIST OCRA-7glimpse:

Supplementary Results:

Object-centric behavior

Object-centric behavior is easiest to observe in the cluttered task. Because each glimpse is small relative to the image (covering less than 4 percent of it), the model must move its glimpses optimally and select the objects in order to recognize them accurately. Reducing the number of glimpses (we experimented with 3 and 5) has a similar effect: it forces the model to rely on its object-centric representation to find the objects without being distracted by the noise segments. We include many more examples of the model's behavior with both 3 and 5 glimpses to illustrate this.
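
As a rough illustration of the "less than 4 percent" figure, the snippet below computes the fraction of the image area covered by a single glimpse. The glimpse and image sizes used here are hypothetical placeholders, not values read from the OCRA parameter files.

```python
def glimpse_coverage(glimpse_h: int, glimpse_w: int, image_h: int, image_w: int) -> float:
    """Fraction of the image area covered by a single rectangular glimpse."""
    return (glimpse_h * glimpse_w) / (image_h * image_w)


# Hypothetical sizes for illustration only; substitute the values from your parameter file.
ratio = glimpse_coverage(18, 18, 100, 100)
print(f"one glimpse covers {ratio:.2%} of the image")  # one glimpse covers 3.24% of the image
# Upper bound on total coverage with 5 glimpses, assuming they never overlap.
print(f"5 glimpses cover at most {min(1.0, 5 * ratio):.2%} of the image")
```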

MultiMNIST Cluttered task with 5 glimpses






MultiMNIST Cluttered task with 3 glimpses





The Street View House Numbers Dataset

We train the model to "read" the digits from left to right by requiring the order of the predicted sequence to match the left-to-right order of the ground truth. The model makes 12 glimpses: the first two are unconstrained, and the capsule lengths from each subsequent pair of glimpses are read out to predict one output digit (e.g., the capsule lengths from the 3rd and 4th glimpses predict digit 1, the left-most digit, and so on). Below are sample behaviors from our model.
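
The readout schedule described above can be sketched in plain Python. This assumes, as in CapsNet-style models, that a class capsule's length (the Euclidean norm of its vector) is its class score. How the two glimpses in a pair are combined is not stated here, so summing their capsule lengths is our own assumption, and all helper names are hypothetical.

```python
import math
from typing import List, Sequence

# From the text: 12 glimpses, the first 2 unconstrained, then one digit per pair of glimpses.
NUM_GLIMPSES = 12
NUM_FREE_GLIMPSES = 2
GLIMPSES_PER_DIGIT = 2


def readout_schedule(num_glimpses: int = NUM_GLIMPSES,
                     num_free: int = NUM_FREE_GLIMPSES,
                     per_digit: int = GLIMPSES_PER_DIGIT) -> List[List[int]]:
    """For each output digit (left to right), return the 0-based glimpse indices read out for it."""
    return [list(range(start, start + per_digit))
            for start in range(num_free, num_glimpses, per_digit)]


def capsule_length(vector: Sequence[float]) -> float:
    """Euclidean norm of a capsule's output vector."""
    return math.sqrt(sum(v * v for v in vector))


def predict_digit(capsules: Sequence[Sequence[Sequence[float]]], glimpse_ids: Sequence[int]) -> int:
    """capsules[t][k] is the vector of class capsule k at glimpse t.

    ASSUMPTION: capsule lengths are summed over the glimpses in the pair,
    and the class with the largest total is predicted.
    """
    num_classes = len(capsules[glimpse_ids[0]])
    scores = [sum(capsule_length(capsules[t][k]) for t in glimpse_ids)
              for k in range(num_classes)]
    return max(range(num_classes), key=lambda k: scores[k])


schedule = readout_schedule()
print(schedule[0])    # [2, 3] -> the 3rd and 4th glimpses predict digit 1
print(len(schedule))  # 5 output digits
```

With 12 glimpses this yields five digit slots, which matches the five-position label vectors shown in the sample mistakes further down.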

The top five rows show the original images, and the bottom five rows show the reconstructions.

The generation of sample images across 12 glimpses

The generation in GIF format

The model learns to detect and reconstruct objects. It achieves an error rate of ~2.5 percent on individual digits and ~10 percent on whole sequences, which still lags state-of-the-art performance on this measure. We believe this gap is strongly related to our small two-layer convolutional backbone and expect better results with a deeper one, which we plan to explore next. Nevertheless, the model shows reasonable attention behavior on this task.

The figure below shows the model's read and write attention as it reads and reconstructs one image.
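
Read and write attention of this kind can be sketched with NumPy, assuming the Gaussian filterbank formulation of DRAW (Gregor et al., 2015): read extracts an n x n glimpse as gamma * F_Y x F_X^T, and write maps an n x n patch back onto the canvas as F_Y^T w F_X / gamma. This is a minimal sketch under that assumption; the function names and parameters (gx, gy, delta, sigma, gamma) follow the DRAW formulation and are not taken from the OCRA code.

```python
import numpy as np


def filterbank(center: float, delta: float, sigma: float, n: int, size: int) -> np.ndarray:
    """Return an (n, size) matrix of 1-D Gaussian filters whose centers are spaced `delta` apart."""
    offsets = np.arange(n) - (n - 1) / 2.0            # symmetric around 0
    mu = center + offsets * delta                      # filter centers, shape (n,)
    coords = np.arange(size)                           # pixel coordinates, shape (size,)
    f = np.exp(-((coords[None, :] - mu[:, None]) ** 2) / (2.0 * sigma ** 2))
    # Normalize each filter to sum to 1 (the floor avoids division by zero far off-image).
    return f / np.maximum(f.sum(axis=1, keepdims=True), 1e-8)


def read(image: np.ndarray, gx: float, gy: float, delta: float,
         sigma: float, gamma: float, n: int) -> np.ndarray:
    """Extract an n x n glimpse from a 2-D image: gamma * F_Y @ image @ F_X.T."""
    h, w = image.shape
    fx = filterbank(gx, delta, sigma, n, w)
    fy = filterbank(gy, delta, sigma, n, h)
    return gamma * fy @ image @ fx.T


def write(patch: np.ndarray, gx: float, gy: float, delta: float,
          sigma: float, gamma: float, h: int, w: int) -> np.ndarray:
    """Map an n x n patch back onto an h x w canvas: F_Y.T @ patch @ F_X / gamma."""
    n = patch.shape[0]
    fx = filterbank(gx, delta, sigma, n, w)
    fy = filterbank(gy, delta, sigma, n, h)
    return fy.T @ patch @ fx / gamma


# A single bright pixel at row 10, column 20, read with a 3 x 3 glimpse centered on it.
img = np.zeros((28, 28))
img[10, 20] = 1.0
glimpse = read(img, gx=20.0, gy=10.0, delta=1.0, sigma=0.5, gamma=1.0, n=3)
canvas = write(glimpse, gx=20.0, gy=10.0, delta=1.0, sigma=0.5, gamma=1.0, h=28, w=28)
```

Because each filter is normalized, reading a constant image returns that constant, and writing a patch preserves its total mass on the canvas; stride delta controls the window size and sigma its blur.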

Here are a few sample mistakes from our model:

SVHN_error1
ground truth [ 1, 10, 10, 10, 10]
prediction [ 0, 10, 10, 10, 10]

SVHN_error2
ground truth [ 2, 8, 10, 10, 10]
prediction [ 2, 9, 10, 10, 10]

SVHN_error3
ground truth [ 1, 2, 9, 10, 10]
prediction [ 1, 10, 10, 10, 10]

SVHN_error4
ground truth [ 5, 1, 10, 10, 10]
prediction [ 5, 7, 10, 10, 10]
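
The digit-level and sequence-level error rates quoted earlier can be computed from label vectors like these. This sketch uses the four sample mistakes above. Treating 10 as a padding label that can optionally be excluded from the digit-level count is our reading of the examples, not something the repository states, and the reported ~2.5 percent figure may use a different convention.

```python
from typing import Optional, Sequence, Tuple


def error_rates(ground_truth: Sequence[Sequence[int]],
                predictions: Sequence[Sequence[int]],
                skip_label: Optional[int] = None) -> Tuple[float, float]:
    """Return (digit_error, sequence_error) as fractions.

    digit_error counts wrong positions; positions whose ground-truth label equals
    `skip_label` (e.g. padding) are excluded when `skip_label` is given.
    sequence_error counts images with any wrong position, padding included.
    """
    wrong_digits = total_digits = wrong_sequences = 0
    for gt, pred in zip(ground_truth, predictions):
        wrong_sequences += any(g != p for g, p in zip(gt, pred))
        for g, p in zip(gt, pred):
            if skip_label is not None and g == skip_label:
                continue
            total_digits += 1
            wrong_digits += g != p
    return wrong_digits / total_digits, wrong_sequences / len(ground_truth)


# The four sample mistakes listed above.
ground_truth = [[1, 10, 10, 10, 10], [2, 8, 10, 10, 10], [1, 2, 9, 10, 10], [5, 1, 10, 10, 10]]
predictions = [[0, 10, 10, 10, 10], [2, 9, 10, 10, 10], [1, 10, 10, 10, 10], [5, 7, 10, 10, 10]]
print(error_rates(ground_truth, predictions))                 # (0.25, 1.0)
print(error_rates(ground_truth, predictions, skip_label=10))  # (0.625, 1.0)
```

All four cases are mistakes, so the sequence-level error is 1.0 by construction; the example mainly shows how much the digit-level figure depends on whether padding positions are counted.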


Some MNIST cluttered results

Testing the model on the MNIST cluttered dataset with three time steps


Code references:

  1. XifengGuo/CapsNet-Pytorch
  2. kamenbliznashki/generative_models
  3. pitsios-s/SVHN