PyTorch implementation of Value Iteration Networks (VIN): Clean, Simple and Modular. Visualization in Visdom.

Overview

VIN: Value Iteration Networks

This is an implementation of Value Iteration Networks (VIN) in PyTorch that reproduces the results of the paper. (A TensorFlow version is also available.)

Architecture of Value Iteration Network

Key idea

  • A fully differentiable neural network with a 'planning' sub-module.
  • Value Iteration = Conv Layer + Channel-wise Max Pooling
  • Generalizes better than reactive policies to new, unseen tasks.
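
The second point can be written out as a pair of equations. This is a sketch in the notation of the VIN paper, not something defined elsewhere in this README: \bar{R} is the reward image, \bar{V} the value image, [\bar{R};\bar{V}] their channel-wise stack, and W^{\bar{a}} the convolution kernel for action channel \bar{a}. The first line is a convolution, the second is a max over channels.

```latex
\begin{align}
\bar{Q}_{\bar{a},i',j'} &= \sum_{l,i,j} W^{\bar{a}}_{l,i,j}\,[\bar{R};\bar{V}]_{l,\,i'-i,\,j'-j} \\
\bar{V}_{i',j'} &= \max_{\bar{a}} \bar{Q}_{\bar{a},i',j'}
\end{align}
```

Repeating these two steps k times corresponds to k iterations of value iteration.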

Learned Reward Image and Its Value Images for each VI Iteration

[Figure: for each grid world size (8x8, 16x16, 28x28), the grid world, its reward image, and its value images.]

Dependencies

This repository requires the following packages:

  • Python >= 3.6
  • Numpy >= 1.12.1
  • PyTorch >= 0.1.10
  • SciPy >= 0.19.0
  • visdom >= 0.1

Datasets

Each data sample consists of the (x, y) coordinates of the current state in the grid world, followed by an obstacle image and a goal image.

Dataset size    8x8      16x16     28x28
Train set       77760    776440    4510695
Test set        12960    129440    751905
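
To make the sample layout concrete, here is a minimal NumPy sketch that builds one sample in memory. It does not read the repository's .npz files, and the function name `make_sample`, the argument names, and the (row, column) index order are assumptions made for illustration. The cell values (0: free, 1: obstacle, 10: goal) follow the FAQ in this README.

```python
import numpy as np


def make_sample(imsize, state_xy, goal_xy, obstacle_cells):
    """Build one hypothetical sample: a state position plus a 2-channel observation."""
    # Channel 0: obstacle image (0: free, 1: obstacle).
    obstacles = np.zeros((imsize, imsize), dtype=np.float32)
    for row, col in obstacle_cells:
        obstacles[row, col] = 1.0

    # Channel 1: goal image (0: free, 10: goal).
    goal = np.zeros((imsize, imsize), dtype=np.float32)
    goal[goal_xy[0], goal_xy[1]] = 10.0

    # Stack the two images into an observation of shape (2, imsize, imsize).
    observation = np.stack([obstacles, goal], axis=0)
    state = np.array(state_xy, dtype=np.int64)
    return state, observation


state, observation = make_sample(
    imsize=8, state_xy=(2, 3), goal_xy=(5, 6), obstacle_cells=[(0, 0), (4, 4)]
)
print(state.shape, observation.shape)
```

A batch of 128 such observations would have shape (128, 2, 8, 8), which matches the input shape described in the FAQ.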

Running Experiment: Training

Grid world 8x8

python run.py --datafile data/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

Grid world 16x16

python run.py --datafile data/gridworld_16x16.npz --imsize 16 --lr 0.008 --epochs 30 --k 20 --batch_size 128

Grid world 28x28

python run.py --datafile data/gridworld_28x28.npz --imsize 28 --lr 0.003 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: The path to the data files.
  • imsize: The size of input images. One of: [8, 16, 28]
  • lr: Learning rate for the RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • ch_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • ch_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • ch_q: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.
  • batch_size: Batch size. Default: 128
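
As an illustration of how these flags fit together, here is a minimal argparse sketch. It is not the repository's run.py. The defaults for epochs, ch_i, ch_h, ch_q and batch_size are the ones listed above; making datafile, imsize, lr and k required is an assumption, since no defaults are stated for them.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags listed in this README."""
    parser = argparse.ArgumentParser(description="VIN training flags (sketch)")
    # No defaults are documented for these four flags, so they are required here.
    parser.add_argument("--datafile", type=str, required=True, help="Path to the data file")
    parser.add_argument("--imsize", type=int, choices=[8, 16, 28], required=True,
                        help="Size of input images")
    parser.add_argument("--lr", type=float, required=True, help="Learning rate")
    parser.add_argument("--k", type=int, required=True, help="Number of value iterations")
    # Documented defaults.
    parser.add_argument("--epochs", type=int, default=30, help="Number of epochs to train")
    parser.add_argument("--ch_i", type=int, default=2, help="Channels in the input layer")
    parser.add_argument("--ch_h", type=int, default=150, help="Channels in the first conv layer")
    parser.add_argument("--ch_q", type=int, default=10, help="Channels in the q layer")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    return parser


# Parse the 8x8 command line from above, passed as an explicit list.
args = build_parser().parse_args(
    ["--datafile", "data/gridworld_8x8.npz", "--imsize", "8", "--lr", "0.005", "--k", "10"]
)
print(args.imsize, args.k, args.epochs, args.batch_size)
```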

Visualization with Visdom

We visualize the learned reward image and its corresponding value images for each VI iteration using Visdom.

First, start the server:

python -m visdom.server

Open Visdom in your browser at http://localhost:8097

Then run the following to visualize the learned reward and value images:

python vis.py --datafile learned_rewards_values_28x28.npz

NOTE: If you would like to produce a GIF animation of the value images on your own, the following ImageMagick command might be useful.

convert -delay 20 -loop 0 *.png value_function.gif

Benchmarks

GPU: TITAN X

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate of rollouts of the learned policy in the environment.

Test Accuracy   8x8      16x16     28x28
PyTorch         99.16%   92.44%    88.20%
TensorFlow      99.03%   90.2%     82%

Speed with GPU

Speed per epoch   8x8   16x16   28x28
PyTorch           3s    15s     100s
TensorFlow        4s    25s     165s

Frequently Asked Questions

  • Q: How is the reward image obtained from the observation?

    • A: The observation image has 2 channels. The first channel is the obstacle image (0: free, 1: obstacle). The second channel is the goal image (0: free, 10: goal). For example, in the 8x8 grid world, an input tensor with batch size 128 has shape [128, 2, 8, 8]. It is fed into a convolutional layer with a [3, 3] filter and 150 feature maps, followed by another convolutional layer with a [3, 3] filter and 1 feature map. The output tensor has shape [128, 1, 8, 8]. This is the reward image.
  • Q: What exactly is the transition model, and how does the VI module obtain the value image from the reward image?

    • A: Assume a batch size of 128 in the 8x8 grid world. Once we have the reward image with shape [128, 1, 8, 8], we apply a convolutional layer to produce the q layers in the VI module. The [3, 3] filters represent the transition probabilities. There are 10 filters, each generating one feature map in the q layers, and each feature map corresponds to an "action". Note that this is more than the 8 actions actually available. We then apply channel-wise max pooling to obtain the value image with shape [128, 1, 8, 8]. Finally, we stack this value image with the reward image as input to the next VI iteration.
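
The two answers above can be sketched in PyTorch as follows. This is a minimal illustration and not the repository's model code: the class name `VIModuleSketch`, the use of padding=1 to keep the 8x8 spatial size, the absence of bias terms in the reward and q layers, and starting from an all-zero value image are assumptions. The policy head that maps q values at the current state to an action is left out.

```python
import torch
import torch.nn as nn


class VIModuleSketch(nn.Module):
    """Hypothetical reward + value-iteration block following the FAQ description."""

    def __init__(self, ch_i=2, ch_h=150, ch_q=10):
        super().__init__()
        # Observation -> hidden feature maps -> single-channel reward image.
        self.h = nn.Conv2d(ch_i, ch_h, kernel_size=3, padding=1)
        self.r = nn.Conv2d(ch_h, 1, kernel_size=3, padding=1, bias=False)
        # Stacked [reward, value] -> ch_q "action" channels (the transition filters).
        self.q = nn.Conv2d(2, ch_q, kernel_size=3, padding=1, bias=False)

    def forward(self, obs, k):
        reward = self.r(self.h(obs))            # [B, 1, H, W]
        value = torch.zeros_like(reward)        # assumed initial value image
        for _ in range(k):
            q = self.q(torch.cat([reward, value], dim=1))   # [B, ch_q, H, W]
            value, _ = torch.max(q, dim=1, keepdim=True)    # channel-wise max pooling
        return reward, value


model = VIModuleSketch()
obs = torch.zeros(128, 2, 8, 8)     # batch of 128 observations in the 8x8 grid world
reward, value = model(obs, k=10)
print(reward.shape, value.shape)
```

Because the q layer has no bias and the value image starts at zero, the first pass through the loop depends on the reward image only, which matches the description of applying the q convolution to the reward image before any stacking.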

Owner
Xingdong Zuo