An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" in PyTorch.

GLOM

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for the MNIST dataset. To understand this implementation, please watch Yannic Kilcher's GLOM video, then read this README.md, then read the code.

Running

Open the notebook in Jupyter to run it. The program expects an Nvidia graphics card for GPU acceleration. If you run out of GPU memory, decrease the batch_size variable. If the code fails to load when you view it on GitHub, try reloading the page several times.

Results

The best models, which have been posted under the best_models folder, reached an accuracy of about 91%.

Implementation details

Three Types of networks per layer of vectors

  1. Top-Down Network
  2. Bottom-up Network
  3. Attention on the same layer Network

Intro to State

The state is initialized once, and the outputs of all three network types are added to it after every time step. The bottom layer of the state is the input layer: it holds the MNIST pixel data and never has anything added to it, so the pixel data is retained. The top layer of the state is the output layer: the loss function is applied there, and it is trained to match the one-hot MNIST target vector.

Explanation of compute_all function

Each type of network sees a 3x3 grid of vectors surrounding its current input vector at the current layer. This lets information travel faster laterally, so more information can cross the image in fewer steps. The easy way to do this is to shift (or roll) every vector along the x and y axes and then concatenate the shifted copies on top of each other, so that every position in the state now contains the original vector together with its neighboring vectors in the same layer. Rolling also connects the edges of the image, so data can pass from one edge to the opposite edge, which reduces the maximum distance between any two pixels or vectors.
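The roll-and-concatenate step can be sketched as follows. This is a minimal illustration in NumPy, not the repository's code: the function name `gather_neighbors` and the `(height, width, dim)` layout are assumptions made for the example. In PyTorch, `torch.roll` with its `shifts` and `dims` arguments plays the same role as `np.roll` here.

```python
import numpy as np

def gather_neighbors(layer):
    """Hypothetical helper: (H, W, D) -> (H, W, 9 * D).

    Each output position holds its own vector plus the 8 surrounding
    vectors. np.roll wraps around, so opposite edges become neighbors.
    """
    shifted = [
        np.roll(layer, shift=(dy, dx), axis=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return np.concatenate(shifted, axis=-1)

# A 4x4 grid of 1-dimensional vectors whose values are 0..15.
layer = np.arange(16).reshape(4, 4, 1)
grid = gather_neighbors(layer)
print(grid.shape)                   # (4, 4, 9)
print(sorted(grid[0, 0].tolist()))  # [0, 1, 3, 4, 5, 7, 12, 13, 15]
```

The corner position (0, 0) picks up values from rows 3, 0, 1 and columns 3, 0, 1, which shows the wrap-around across the image edges.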

For a more complex dataset, this wrap-around could cause problems, since opposite edges of an image are generally not continuous, but for MNIST the problem doesn't arise.

These neighborhood vectors are then fed to each type of model. For every pixel, a model receives all the neighboring state vectors of one layer and outputs a single vector. There are three types of model per layer. In this example, every line drawn is a separate model that is reused for every pixel the process is applied to. After each model type has produced its output, the three lists of vectors are added together.

The result is a single list of vectors, which is added to the list of vectors at the same (x, y) coordinate in the original state.

Repeating this step for every list of vectors per x,y coordinate in the original state will yield the full new State value.
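The whole update can be sketched in a few lines. This is an illustration under stated assumptions, not the repository's compute_all function: the sizes, the names `neighbors`, `make_net`, and `step`, and the one-layer `tanh` stand-ins for the learned networks are all hypothetical. Which layer each network reads from (bottom-up from the layer below, top-down from the layer above, attention from the same layer) follows the general GLOM description and is also an assumption here. The "attention" stand-in is a plain network, not an attention mechanism.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_LAYERS, H, W, DIM = 4, 6, 6, 8  # hypothetical sizes

def neighbors(layer):
    """(H, W, DIM) -> (H, W, 9 * DIM): each vector plus its 8 wrapped neighbors."""
    return np.concatenate(
        [np.roll(layer, shift=(dy, dx), axis=(0, 1))
         for dy in (-1, 0, 1) for dx in (-1, 0, 1)],
        axis=-1,
    )

def make_net():
    """Stand-in for a learned network: one linear map followed by tanh."""
    weight = rng.normal(scale=0.1, size=(9 * DIM, DIM))
    return lambda grid: np.tanh(grid @ weight)

# One network of each type per layer; each is reused at every (x, y) position.
nets = [
    {"bottom_up": make_net(), "top_down": make_net(), "attention": make_net()}
    for _ in range(NUM_LAYERS)
]

def step(state):
    """state: (NUM_LAYERS, H, W, DIM). Returns the state after one time step."""
    grids = [neighbors(layer) for layer in state]
    new_state = state.copy()
    for l in range(1, NUM_LAYERS):  # layer 0 keeps the pixel data untouched
        update = nets[l]["attention"](grids[l])
        update = update + nets[l]["bottom_up"](grids[l - 1])
        if l + 1 < NUM_LAYERS:  # the top layer has nothing above it
            update = update + nets[l]["top_down"](grids[l + 1])
        new_state[l] = state[l] + update  # outputs are added to the state
    return new_state

state = rng.normal(size=(NUM_LAYERS, H, W, DIM))
next_state = step(state)
```

Because every network is applied to all positions at once through a single matrix product, the per-pixel loop described above never has to be written explicitly.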

Since each network only sees a 3x3 grid and not larger image patches, this technique can be used for images of any size and is easily parallelizable.

If I had more compute

My 2080Ti runs into memory errors if the batch size is above roughly 30, so here are my implementation ideas if I had more compute.

  1. Increase batch_size. This probably won't affect the training, but it would make testing the accuracy faster.
  2. Save more states throughout the steps taken and add them together. This would allow gradients to flow back to the original state, similar to how ResNet can train very large models because gradients pass backwards more easily. This has already been implemented to a smaller degree and showed substantial accuracy improvements.
  3. Perform some kind of evolutionary parameter search by mutating the model parameters while also using backprop. This has been shown to improve the accuracy of image classifiers and other models, but it would take a ton of compute.

Yannic Kilcher's Attention

This has been pushed to GitHub because, during testing and hyperparameter tuning, a better model than the previous one was found. More testing needs to be done, and I'm working on the visual explanation for it now. Previous versions of this code don't have the attention seen in the current version and will have similar performance.

Other Ideas behind the paper implementation

This is basically a neural cellular automaton, as in the paper Growing Neural Cellular Automata, with some inspiration from the follow-up paper Self-classifying MNIST Digits. The difference is that instead of a single list of numbers (one vector) per pixel, there are several vectors per pixel in each image. The Growing Neural Cellular Automata model was also very difficult to train because of its long gradient chains, so the added model complexity of GLOM makes training even harder. The neural cellular automata papers are also the reason the MSE loss function is used and random noise is added to the state during training.
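A minimal sketch of those two ingredients, MSE against a one-hot target and noise added to the state, is shown below in NumPy. The helper names, the noise scale of 0.02, and the state shape are assumptions chosen for illustration. How the repository reads a 10-way prediction out of the top layer, and where exactly it injects noise, is not shown here.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Integer labels of shape (N,) -> one-hot targets of shape (N, num_classes)."""
    return np.eye(num_classes)[labels]

def add_noise(state, std, rng):
    """Return the state with Gaussian noise added; std is a hypothetical setting."""
    return state + rng.normal(scale=std, size=state.shape)

def mse_loss(output, target):
    """Mean squared error over every element."""
    return float(np.mean((output - target) ** 2))

rng = np.random.default_rng(0)

# Hypothetical state: (layers, height, width, dim).
state = np.zeros((4, 6, 6, 8))
noisy_state = add_noise(state, std=0.02, rng=rng)

# A prediction of class 5 scored against a target of class 3:
# two of the ten entries differ by 1, so the loss is 2 / 10.
loss = mse_loss(one_hot(np.array([5])), one_hot(np.array([3])))
print(loss)  # 0.2
```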

To do

  1. Generate the explanation for Yannic Kilcher's version of attention that is implemented here.
  2. See if part-whole hierarchies are being found.
  3. Keep testing hyperparameters to push accuracy higher.
  4. Test different state initializations.
  5. Train on harder datasets.

If you find any issues, please feel free to contact me.

Owner
Just a random coder