PyTorch implementation for the visual prior component (i.e. perception module) of the Visually Grounded Physics Learner [Li et al., 2020].

Overview

VGPL-Visual-Prior

PyTorch implementation for the visual prior component (i.e. perception module) of the Visually Grounded Physics Learner (VGPL). Given visual obseravtions, the visual prior proposes their corresponding particle representations, in the form of particle positions and groupings. Please see the following paper for more details.

Visual Grounding of Learned Physical Models

Yunzhu Li, Toru Lin*, Kexin Yi*, Daniel M. Bear, Daniel L. K. Yamins, Jiajun Wu, Joshua B. Tenenbaum, and Antonio Torralba

ICML 2020 [website] [paper] [video]

Demo

Input RGB videos and predictions from our learned model

Prerequisites

  • Python 3
  • PyTorch 1.0 or higher, with NVIDIA CUDA Support
  • Other required packages in requirements.txt

Code overview

Helper files

config.py contains all configurations used for model training, model evaluation and output generation.

dataset.py contains helper functions for loading and standardizing data and related variables. Note that paths to data directories is specified in the _DATA_DIR variable in this file, not in config.py.

loss.py contains helper functions for calculating Chamfer loss in different settings (e.g. in a single frame, across a time sequence, etc.).

model.py implements the neural network model used for prediction.

Main files

The following files can be run directly; see "Training and evaluation" section for more details.

train.py trains a model that could convert input observations into their particle representations.

eval.py evaluates a trained model by visualizing its predictions, and/or stores the output predictions in .h5 format.

Training and evaluation

Download the training and evaluation data from the following links, and put them in data folder. Optionally, download our trained model checkpoints and put them in dump folder.

To train a model:

python train.py --set loss_type l2 dataset RigidFall

To debug (by overfitting model on small batch of data):

python train.py --set loss_type l2 dataset RigidFall debug True

To evaluate a trained model and generate outputs using our provided checkpoints:

python eval.py --set loss_type l2 dataset RigidFall n_frames 4 n_frames_eval 30 load_path dump/rigid_fall_4frame_l2.pth
python eval.py --set loss_type l2 dataset MassRope n_frames 4 n_frames_eval 30 load_path dump/mass_rope_4frame_l2.pth

See config.py for more details on customizable configurations.

Citing VGPL

If you find this codebase useful in your research, please consider citing:

@inproceedings{li2020visual,
    Title={Visual Grounding of Learned Physical Models},
    Author={Li, Yunzhu and Lin, Toru and Yi, Kexin and Bear, Daniel and Yamins, Daniel L.K. and Wu, Jiajun and Tenenbaum, Joshua B. and Torralba, Antonio},
    Booktitle={ICML},
    Year={2020}
}

@inproceedings{li2019learning,
    Title={Learning Particle Dynamics for Manipulating Rigid Bodies, Deformable Objects, and Fluids},
    Author={Li, Yunzhu and Wu, Jiajun and Tedrake, Russ and Tenenbaum, Joshua B and Torralba, Antonio},
    Booktitle={ICLR},
    Year={2019}
}
Owner
Toru
Toru
1st place solution in CCF BDCI 2021 ULSEG challenge

1st place solution in CCF BDCI 2021 ULSEG challenge This is the source code of the 1st place solution for ultrasound image angioma segmentation task (

Chenxu Peng 30 Nov 22, 2022
generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search

generate-2D-quadrilateral-mesh-with-neural-networks-and-tree-search This repository contains single-threaded TreeMesh code. I'm Hua Tong, a senior stu

Hua Tong 18 Sep 21, 2022
The repository is for safe reinforcement learning baselines.

Safe-Reinforcement-Learning-Baseline The repository is for Safe Reinforcement Learning (RL) research, in which we investigate various safe RL baseline

172 Dec 19, 2022
A script that trains a model to recognize handwritten digits using the MNIST data set.

handwritten-digits-recognition A script that trains a model to recognize handwritten digits using the MNIST data set. Then it loads external files and

Hamza Sayih 1 Oct 30, 2021
Python SDK for building, training, and deploying ML models

Overview of Kubeflow Fairing Kubeflow Fairing is a Python package that streamlines the process of building, training, and deploying machine learning (

Kubeflow 325 Dec 13, 2022
Brain tumor detection using Convolution-Neural Network (CNN)

Detect and Classify Brain Tumor using CNN. A system performing detection and classification by using Deep Learning Algorithms using Convolution-Neural Network (CNN).

assia 1 Feb 07, 2022
Implementation of ICLR 2020 paper "Revisiting Self-Training for Neural Sequence Generation"

Self-Training for Neural Sequence Generation This repo includes instructions for running noisy self-training algorithms from the following paper: Revi

Junxian He 45 Dec 31, 2022
joint detection and semantic segmentation, based on ultralytics/yolov5,

Multi YOLO V5——Detection and Semantic Segmentation Overeview This is my undergraduate graduation project which based on ultralytics YOLO V5 tag v5.0.

477 Jan 06, 2023
Image Super-Resolution Using Very Deep Residual Channel Attention Networks

Image Super-Resolution Using Very Deep Residual Channel Attention Networks

kongdebug 14 Oct 14, 2022
Code for the paper "TadGAN: Time Series Anomaly Detection Using Generative Adversarial Networks"

TadGAN: Time Series Anomaly Detection Using Generative Adversarial Networks This is a Python3 / Pytorch implementation of TadGAN paper. The associated

Arun 92 Dec 03, 2022
CAR-API: Cityscapes Attributes Recognition API

CAR-API: Cityscapes Attributes Recognition API This is the official api to download and fetch attributes annotations for Cityscapes Dataset. Content I

Kareem Metwaly 5 Dec 22, 2022
For IBM Quantum Challenge Africa 2021, 9 September (07:00 UTC) - 20 September (23:00 UTC).

IBM Quantum Challenge Africa 2021 To ensure Africa is able to apply quantum computing to solve problems relevant to the continent, the IBM Research La

Qiskit Community 48 Dec 25, 2022
This thesis is mainly concerned with state-space methods for a class of deep Gaussian process (DGP) regression problems

Doctoral dissertation of Zheng Zhao This thesis is mainly concerned with state-space methods for a class of deep Gaussian process (DGP) regression pro

Zheng Zhao 21 Nov 14, 2022
Collection of sports betting AI tools.

sports-betting sports-betting is a collection of tools that makes it easy to create machine learning models for sports betting and evaluate their perf

George Douzas 109 Dec 31, 2022
A Simulation Environment to train Robots in Large Realistic Interactive Scenes

iGibson: A Simulation Environment to train Robots in Large Realistic Interactive Scenes iGibson is a simulation environment providing fast visual rend

Stanford Vision and Learning Lab 493 Jan 04, 2023
Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo"

dblmahmc Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo" Requirements: https://github.com

1 Dec 17, 2021
Code for the paper "Balancing Training for Multilingual Neural Machine Translation, ACL 2020"

Balancing Training for Multilingual Neural Machine Translation Implementation of the paper Balancing Training for Multilingual Neural Machine Translat

Xinyi Wang 21 May 18, 2022
Unsupervised Learning of Video Representations using LSTMs

Unsupervised Learning of Video Representations using LSTMs Code for paper Unsupervised Learning of Video Representations using LSTMs by Nitish Srivast

Elman Mansimov 341 Dec 20, 2022
Unofficial Pytorch Lightning implementation of Contrastive Syn-to-Real Generalization (ICLR, 2021)

Unofficial Pytorch Lightning implementation of Contrastive Syn-to-Real Generalization (ICLR, 2021)

Gyeongjae Choi 17 Sep 23, 2021