My implementation of DeepMind's Perceiver

DeepMind Perceiver (in PyTorch)

Disclaimer: This is not official and I'm not affiliated with DeepMind.

This is my implementation of the model described in the paper "Perceiver: General Perception with Iterative Attention". You can read more about the model on DeepMind's website.

I trained an MNIST model, which you can find in models/mnist.pkl or load with perceiver.load_mnist_model(). It reaches 96.02% accuracy on the test data.

Getting started

To run this you need PyTorch installed:

pip3 install torch

The perceiver module exports two model classes: Perceiver and PerceiverLogits.

Use Perceiver as follows (or see examples.ipynb):

from perceiver import Perceiver

model = Perceiver(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How many latent self-attention blocks follow each cross-attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)
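To make the layers and latent_blocks parameters more concrete, here is a small pure-Python sketch of the attention schedule they imply. It is not code from this repository: the function name layer_schedule and the group labels "first" and "shared" are made up for illustration, and the sketch assumes that each layer is one cross-attention with the input followed by latent_blocks latent self-attention blocks, as the parameter comments suggest.

```python
def layer_schedule(layers=8, latent_blocks=6):
    """Return the sequence of (operation, weight_group) steps in one forward pass.

    Assumption: every layer is one cross-attention with the input followed by
    `latent_blocks` latent self-attention blocks, and layers 2..N reuse a
    single set of weights (hypothetical labels "first" and "shared").
    """
    steps = []
    for layer in range(layers):
        group = "first" if layer == 0 else "shared"
        steps.append(("cross_attention", group))
        steps.extend(("latent_self_attention", group) for _ in range(latent_blocks))
    return steps


steps = layer_schedule(layers=8, latent_blocks=6)
unique_groups = sorted({group for _, group in steps})

print(len(steps))     # 8 * (1 + 6) = 56 attention steps
print(unique_groups)  # ['first', 'shared']

# Attention score entries per head for one MNIST image, showing why the
# latent bottleneck is cheaper than self-attention over all pixels.
pixels = 28 * 28                  # 784 input positions
latents = 64
cross_scores = latents * pixels   # latents attend to the input
latent_scores = latents * latents # latents attend to each other
full_scores = pixels * pixels     # plain self-attention over the pixels
print(cross_scores, latent_scores, full_scores)  # 50176 4096 614656
```

With the default values, the forward pass runs 56 attention steps but only needs two sets of weights, and each cross-attention scores 64 x 784 pairs instead of the 784 x 784 pairs that self-attention over the raw pixels would need.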

The Perceiver model above returns the latents after the final layer. If you want logits instead, use PerceiverLogits:

from perceiver import PerceiverLogits

model = PerceiverLogits(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    output_features, # <- How many different classes? E.g. 10 for MNIST.
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How many latent self-attention blocks follow each cross-attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)
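The fourier_bands parameter controls the positional encoding that is attached to every input position. The sketch below shows the general idea in pure Python for a grid of any number of dimensions. It is only an illustration under stated assumptions: the function names fourier_features and grid_encoding are hypothetical, positions are scaled to [-1, 1], and the frequencies are assumed to be powers of two times pi. The encoding in this repository may use a different frequency spacing.

```python
import itertools
import math


def fourier_features(x, bands=4):
    """Encode one coordinate x in [-1, 1] as [x, sin, cos, sin, cos, ...].

    Assumption: band k uses the frequency 2**k * pi.
    """
    feats = [x]
    for k in range(bands):
        freq = (2 ** k) * math.pi
        feats.append(math.sin(freq * x))
        feats.append(math.cos(freq * x))
    return feats


def grid_encoding(shape, bands=4):
    """Encode every position of an n-dimensional grid, one row per position."""
    # Scale the indices of each axis to [-1, 1]; a size-1 axis maps to 0.
    axes = [
        [-1.0 + 2.0 * i / (n - 1) if n > 1 else 0.0 for i in range(n)]
        for n in shape
    ]
    return [
        [f for coord in coords for f in fourier_features(coord, bands)]
        for coords in itertools.product(*axes)
    ]


enc = grid_encoding((28, 28), bands=4)
print(len(enc))     # 784 positions
print(len(enc[0]))  # 2 axes * (1 + 2 * 4) = 18 features per position
```

Each position gets 1 + 2 * fourier_bands values per axis, so the encoding width grows with both the number of bands and the number of input dimensions.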

To use my pre-trained MNIST model (not very good):

from perceiver import load_mnist_model

model = load_mnist_model()

TODO:

  • Positional embedding generalized to n dimensions (with fourier features)
  • Train other models (like CIFAR-100 or something not in the image domain)
  • Type hints
  • Unit tests for components of model
  • Package

Owner

Louis Arge
Experienced full-stack developer. Self-studying machine learning.