PyTorch implementation of Value Iteration Networks (NIPS 2016 best paper)

Overview

VIN: Value Iteration Networks

[Figure: Architecture of Value Iteration Network]

A quick thank you

A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:

Why another VIN implementation?

  1. The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than the others I have found (both TensorFlow and PyTorch).
  2. This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
  3. It provides a more extensible research base for others to build on, without requiring access to MATLAB (which may be behind a paywall).

Installation

This repository requires the following packages:

Use pip to install the necessary dependencies:

pip install -U -r requirements.txt 

Note that, at the time of writing, PyTorch could not be installed directly from PyPI; refer to http://pytorch.org/ for installation instructions specific to your needs.

How to train

8x8 gridworld

python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

16x16 gridworld

python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128

28x28 gridworld

python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: Path to the dataset file.
  • imsize: Size of the input images. One of: [8, 16, 28]
  • lr: Learning rate for the RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of value iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • l_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • l_q: Number of channels in the q layer (~actions) of the VI module. Default: 10, as described in the paper.
  • batch_size: Batch size. Default: 128
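The k, l_i, l_h and l_q flags correspond to the value iteration (VI) module: an input with l_i channels is mapped to l_h hidden channels and then to a reward map r, and each of the k iterations computes l_q Q channels from r and the current value map v, then sets v to the maximum over those channels. The NumPy sketch below illustrates that recurrence under simplifying assumptions; it is not the repository's model.py. All kernels are 3x3 with no biases, the weights are random placeholders rather than trained parameters, and the names conv2d_same and vi_module are hypothetical.

```python
import numpy as np


def conv2d_same(x, w):
    """3x3 cross-correlation with zero padding of 1 (like a padded conv layer).

    x: (C_in, H, W), w: (C_out, C_in, 3, 3) -> (C_out, H, W)
    """
    _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)), mode="constant")
    out = np.zeros((w.shape[0], h, wd))
    for i in range(3):
        for j in range(3):
            # shifted window of every input channel, weighted and summed over C_in
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + h, j:j + wd])
    return out


def vi_module(x, k, l_h=150, l_q=10, seed=0):
    """Hypothetical sketch of the VI recurrence; weights are random placeholders."""
    rng = np.random.default_rng(seed)
    l_i = x.shape[0]
    w_h = rng.normal(scale=0.1, size=(l_h, l_i, 3, 3))  # input -> hidden
    w_r = rng.normal(scale=0.1, size=(1, l_h, 3, 3))    # hidden -> reward
    w_q = rng.normal(scale=0.1, size=(l_q, 1, 3, 3))    # reward -> Q
    w_v = rng.normal(scale=0.1, size=(l_q, 1, 3, 3))    # value -> Q
    w_qv = np.concatenate([w_q, w_v], axis=1)           # (l_q, 2, 3, 3)

    h = conv2d_same(x, w_h)             # (l_h, H, W) hidden features
    r = conv2d_same(h, w_r)             # (1, H, W) reward map
    q = conv2d_same(r, w_q)             # (l_q, H, W) one channel per "action"
    v = q.max(axis=0, keepdims=True)    # (1, H, W) V = max over Q channels
    for _ in range(k):
        # conv over [r, v] with stacked kernels == conv(r, w_q) + conv(v, w_v)
        q = conv2d_same(np.concatenate([r, v], axis=0), w_qv)
        v = q.max(axis=0, keepdims=True)
    return q, v


obstacles = np.zeros((8, 8))
obstacles[3, 2:6] = 1.0
goal = np.zeros((8, 8))
goal[7, 7] = 1.0
x = np.stack([obstacles, goal])         # l_i = 2 input channels
q, v = vi_module(x, k=10)
print(q.shape, v.shape)                 # (10, 8, 8) (1, 8, 8)
```

Stacking [r, v] as two input channels and concatenating the kernels is equivalent to adding two separate convolutions, so each iteration needs a single convolution call. The readout from the Q values at the agent's position to action scores is omitted here.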

How to test / visualize paths (requires training first)

8x8 gridworld

python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10

16x16 gridworld

python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20

28x28 gridworld

python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36

To visualize the optimal and predicted paths, simply pass:

--plot

Flags:

  • weights: Path to the trained weights.
  • imsize: Size of the input images. One of: [8, 16, 28]
  • plot: If supplied, the optimal and predicted paths will be plotted.
  • k: Number of value iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • l_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • l_q: Number of channels in the q layer (~actions) of the VI module. Default: 10, as described in the paper.

Results

[Figure: sample results (Sample One, Sample Two) for the 8x8, 16x16, and 28x28 gridworlds]

Datasets

Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.

Dataset size 8x8 16x16 28x28
Train set 81337 456309 1529584
Test set 13846 77203 251755

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced with the dataset/make_training_data.py script. Note that this script is not optimized: it runs rather slowly (and uses a lot of memory :D).
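How the expert paths are produced is not spelled out above. As a rough sketch, assuming the expert labels come from shortest paths to the goal, a breadth-first search outward from the goal gives every free cell its distance to the goal, and an expert action at any state is a move to a neighbour that is one step closer. The helper names, the 4-connected unit-cost moves and the grid encoding (0 = free, 1 = obstacle) are all assumptions of this sketch, not the behaviour of dataset/make_training_data.py.

```python
from collections import deque

# Assumption: 4-connected moves with unit cost (a simplification).
MOVES = {"N": (-1, 0), "S": (1, 0), "W": (0, -1), "E": (0, 1)}


def distance_to_goal(obstacles, goal):
    """BFS from the goal over free cells; returns {cell: steps to goal}."""
    h, w = len(obstacles), len(obstacles[0])
    dist = {goal: 0}
    queue = deque([goal])
    while queue:
        r, c = queue.popleft()
        for dr, dc in MOVES.values():
            nr, nc = r + dr, c + dc
            inside = 0 <= nr < h and 0 <= nc < w
            if inside and not obstacles[nr][nc] and (nr, nc) not in dist:
                dist[(nr, nc)] = dist[(r, c)] + 1
                queue.append((nr, nc))
    return dist


def expert_action(dist, state):
    """A move that gets one step closer to the goal (None at the goal or if unreachable)."""
    if state not in dist or dist[state] == 0:
        return None
    r, c = state
    for name, (dr, dc) in MOVES.items():
        if dist.get((r + dr, c + dc), float("inf")) < dist[state]:
            return name
    return None


grid = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
dist = distance_to_goal(grid, goal=(0, 3))
print(dist[(2, 0)], expert_action(dist, (2, 0)))  # 5 N
```

Searching once from the goal labels every free cell in a single pass, which is cheaper than searching from each start state separately.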

Performance: Success Rate

This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success Rate 8x8 16x16 28x28
PyTorch 99.69% 96.99% 91.07%
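As a hedged illustration of how a rollout success rate is computed (not the logic of test.py), the sketch below follows a policy from a start state on each domain and reports the fraction of domains where the goal is reached. The function names, the 4-connected moves, the step limit and the rule that leaving the grid or stepping onto an obstacle counts as a failure are all assumptions; `greedy` is a deliberately naive placeholder policy, not a VIN.

```python
MOVES = {"N": (-1, 0), "S": (1, 0), "W": (0, -1), "E": (0, 1)}


def rollout(policy, obstacles, goal, start, max_steps):
    """Follow the policy from start; True if the goal is reached within max_steps."""
    h, w = len(obstacles), len(obstacles[0])
    state = start
    for _ in range(max_steps):
        if state == goal:
            return True
        dr, dc = MOVES[policy(obstacles, goal, state)]
        r, c = state[0] + dr, state[1] + dc
        if not (0 <= r < h and 0 <= c < w) or obstacles[r][c]:
            return False  # invalid move ends the episode as a failure
        state = (r, c)
    return state == goal


def success_rate(policy, domains, max_steps):
    """domains: list of (obstacles, goal, start) triples."""
    wins = sum(rollout(policy, o, g, s, max_steps) for o, g, s in domains)
    return wins / len(domains)


def greedy(obstacles, goal, state):
    """Placeholder policy: move toward the goal, ignoring obstacles."""
    if state[0] != goal[0]:
        return "S" if goal[0] > state[0] else "N"
    return "E" if goal[1] > state[1] else "W"


free = [[0] * 4 for _ in range(4)]
walled = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]]
domains = [(free, (3, 3), (0, 0)), (walled, (3, 0), (0, 0))]
print(success_rate(greedy, domains, max_steps=20))  # 0.5
```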

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.

Test Accuracy 8x8 16x16 28x28
PyTorch 99.83% 94.84% 88.54%
Comments
  • testing accuracy fairly low

    I just tried to follow the instructions in the repo and tested the trained models, but got fairly low accuracy. I'm using PyTorch 0.1.12_1. Is there anything I should pay attention to?

    opened by xinleipan 10
  • Prebuilt Dataset Generation

    Hello,

    I was wondering how you generated the prebuilt datasets that are downloaded when running download_weights_and_datasets.sh, i.e. what were the max_obs and max_obs_size parameters?

    Did you follow this file in the original repo? https://github.com/avivt/VIN/blob/master/scripts/make_data_gridworld_nips.m

    Thanks, Emilio

    opened by eparisotto 5
  • the rollout accuracy in test script is lower than the test accuracy in train script.

    Hello!

    I have a question. Does the rollout accuracy indicate the success rate? If so, why is it lower than the prediction accuracy? In Aviv's implementation, the success rate on the 8x8 gridworld was as high as 99.6%. Why is the success rate in your experiment relatively low?

    Thanks!

    opened by albzni 4
  • RUN ERROR

    When I run 'python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128', it works, but when I then run 'python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128', the following error occurs:

    ~/pytorch-value-iteration-networks$ python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 10 --k 20 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 58, in _process
        images = images.astype(np.float32)
    MemoryError

    opened by N-Kingsley 3
  • Problem of running the test script

    Hello,

    I downloaded the data with the .sh download script you provided, and I also got an nps weights file after training. When I ran the testing command, I got the following error:

    Traceback (most recent call last):
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 158, in <module>
        main(config)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 85, in main
        _, predictions = vin(X_in, S1_in, S2_in, config)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/model.py", line 64, in forward
        return logits, self.sm(logits)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 352, in __call__
        for hook in self._forward_pre_hooks.values():
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 398, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Softmax' object has no attribute '_forward_pre_hooks'

    Thanks for helping!

    opened by YantianZha 3
  • Improved readability of the VIN model, in addition to minor changes

    My main modification is in the forward method of the model, where q_out is extracted from the q values without repeating q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse to the dataset creation script and changing .cuda() to .to(device) in test.py.

    opened by shuishida 2
  • Inconsistent tensor sizes when starting training

    Hey there. I'm trying to run

    python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    

    But I get the following error

    Number of Train Samples: 103926
    Number of Test Samples: 17434
         Epoch | Train Loss | Train Error | Epoch Time
    Traceback (most recent call last):
      File "train.py", line 147, in <module>
        train(net, trainloader, config, criterion, optimizer, use_GPU)
      File "train.py", line 40, in train
        outputs, predictions = net(X, S1, S2, config)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/media/user_home2/j1k1000o/j1k/VINs/pytorch-value-iteration-networks/model.py", line 44, in forward
        q = F.conv2d(torch.cat([r, v], 1), 
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 897, in cat
        return Concat.apply(dim, *iterable)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 317, in forward
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorMath.cu:141
    

    I've executed

    ./download_weights_and_datasets.sh
    

    as well as

    python ./dataset/make_training_data.py
    

    And I'm running it on an Ubuntu 16.04, python 3.6 and with all the requirements installed.

    Can you help me out?

    opened by juancprzs 2
  • Don't understand VIN last step

        slice_s1 = S1.long().expand(config.imsize, 1, config.l_q, q.size(0))
        slice_s1 = slice_s1.permute(3, 2, 1, 0)
        q_out = q.gather(2, slice_s1).squeeze(2)
    

    What do these 3 lines do?

    opened by QiXuanWang 1
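For readers with the same question: assuming q has shape (batch, l_q, imsize, imsize) and S1 is a 1-D tensor holding each sample's index along the first spatial axis, the three quoted lines pick out that slice of every Q channel, giving a (batch, l_q, imsize) result; a matching step with S2 (not quoted above) would then select the second coordinate, leaving the l_q Q-values at the agent's current state. The NumPy sketch below reproduces the same expand/permute/gather pattern with np.broadcast_to, np.transpose and np.take_along_axis and checks it against plain advanced indexing. The shapes and the S2 step are assumptions, and the function names are hypothetical.

```python
import numpy as np


def gather_rows_like_torch(q, s1):
    """Mirror of the quoted lines: q (B, A, H, W), s1 (B,) -> (B, A, W)."""
    b, a, h, w = q.shape
    # expand(imsize, 1, l_q, B): broadcast s1 along new leading axes
    idx = np.broadcast_to(s1.astype(np.int64), (w, 1, a, b))
    # permute(3, 2, 1, 0) -> shape (B, A, 1, W), idx[b, a, 0, j] == s1[b]
    idx = np.transpose(idx, (3, 2, 1, 0))
    # gather along dim 2: out[b, a, 0, j] = q[b, a, s1[b], j]; then squeeze dim 2
    return np.take_along_axis(q, idx, axis=2).squeeze(2)


def q_at_state(q, s1, s2):
    """Assumed follow-up step: select s2 from the remaining spatial axis -> (B, A)."""
    rows = gather_rows_like_torch(q, s1)  # (B, A, W)
    idx = np.broadcast_to(s2.astype(np.int64)[:, None, None], rows.shape[:2] + (1,))
    return np.take_along_axis(rows, idx, axis=2).squeeze(2)


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 10, 8, 8))  # batch=4, l_q=10, imsize=8
s1 = np.array([0, 3, 7, 5])
s2 = np.array([1, 1, 6, 0])
q_out = q_at_state(q, s1, s2)
print(q_out.shape)  # (4, 10)
print(np.allclose(q_out, q[np.arange(4), :, s1, s2]))  # True
```

Advanced indexing gives the same result in one expression; the gather form mirrors what older PyTorch code often used instead.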
  • KeyError: 'arr_1 is not a file in the archive'

    python3 train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 49, in _process
        S1 = f['arr_1']
      File "/home/user/miniconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 255, in __getitem__
        raise KeyError("%s is not a file in the archive" % key)
    KeyError: 'arr_1 is not a file in the archive'

    I got this error, could you please

    opened by derelearnro 1
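The traceback shows dataset.py reading f['arr_1'], and NumPy raises this KeyError when an .npz archive has no array stored under that name. One plausible cause (an assumption, not a confirmed diagnosis) is a dataset file written in a different layout than the loader expects: np.savez stores positional arrays as arr_0, arr_1, ..., but keyword arrays under their own names. A small hypothetical helper that lists an archive's keys makes this easy to check:

```python
import io

import numpy as np


def npz_keys(file):
    """Return the array names stored in an .npz archive (path or file object)."""
    with np.load(file) as f:
        return list(f.files)


# Positional arguments are stored as arr_0, arr_1, ...
positional = io.BytesIO()
np.savez(positional, np.zeros((2, 8, 8)), np.arange(3))
positional.seek(0)
print(npz_keys(positional))  # ['arr_0', 'arr_1']

# Keyword arguments keep their names, so looking up 'arr_1' would raise KeyError.
named = io.BytesIO()
np.savez(named, images=np.zeros((2, 8, 8)), s1=np.arange(3))
named.seek(0)
print(npz_keys(named))  # ['images', 's1']
```

The same helper can be pointed at a dataset file from this repository, e.g. npz_keys("dataset/gridworld_8x8.npz"), to see which keys it actually contains.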
  • Problem of running dataset/make_training_data.py script

    Hi

    When I tried to run the make_training_data.py script to generate the gridworld.npz file, I got the following error:

    FileNotFoundError: [Errno 2] No such file or directory: 'dataset/gridworld_28x28.npz'
    

    And I found that line 101 should be modified as follows:

    save_path = "gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
    
    opened by ruqing00 0
Owner
Kent Sommer
Software Engineer @ Toyota Research Institute (SF Bay Area)