Code for the paper "Reinforcement Learning as One Big Sequence Modeling Problem"

Overview

Trajectory Transformer

Code release for Reinforcement Learning as One Big Sequence Modeling Problem.

Installation

All python dependencies are in environment.yml. Install with:

conda env create -f environment.yml
conda activate trajectory
pip install -e .

For reproducibility, we have also included system requirements in a Dockerfile (see installation instructions), but the conda installation should work on most standard Linux machines.

Usage

Train a transformer with: python scripts/train.py --dataset halfcheetah-medium-v2

To reproduce the offline RL results: python scripts/plan.py --dataset halfcheetah-medium-v2

By default, these commands will use the hyperparameters in config/offline.py. You can override them with runtime flags:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--horizon 5 --beam_width 32
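The override pattern can be sketched with Python's argparse: every default becomes a flag, and any flag given on the command line replaces its default. This is only an illustration. The helper name parse_config and the placeholder default values are hypothetical, and this is not the parser the repository actually uses for config/offline.py.

```python
import argparse


def parse_config(argv, defaults):
    """Return `defaults` with any values overridden by command-line flags.

    Hypothetical sketch of the override pattern; not the repository's parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    for key, value in defaults.items():
        # Each default becomes a --key flag whose type matches the default's type.
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return vars(parser.parse_args(argv))


# Placeholder defaults, NOT the actual values in config/offline.py.
defaults = {"horizon": 10, "beam_width": 64}

config = parse_config(
    ["--dataset", "halfcheetah-medium-v2", "--horizon", "5", "--beam_width", "32"],
    defaults,
)
print(config)
```

Flags that are not passed keep their default, so a run with only --dataset uses every value from the defaults dictionary.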

A few hyperparameters differ from those listed in the paper because of changes to the discretization strategy. These hyperparameters will be updated in the next arXiv version to match the current codebase.
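The paper plans with beam search, and the horizon and beam_width flags control how far ahead the search looks and how many candidate plans it keeps at each step. The sketch below is a generic beam search over a toy deterministic model, meant only to illustrate those two knobs. It is not the repository's planner, which searches over the transformer's discretized token predictions; the function names and the toy model are hypothetical.

```python
import heapq
from typing import Callable, List, Sequence, Tuple


def beam_search(
    start: float,
    actions: Sequence[float],
    step: Callable[[float, float], Tuple[float, float]],
    horizon: int,
    beam_width: int,
) -> Tuple[float, List[float]]:
    """Generic beam search over action sequences of length `horizon`.

    `step(state, action)` returns (next_state, reward). At every step each
    beam is extended by every action, and only the `beam_width` partial
    plans with the highest cumulative reward are kept.
    """
    beams = [(0.0, start, [])]  # (cumulative reward, state, actions so far)
    for _ in range(horizon):
        candidates = []
        for total, state, plan in beams:
            for action in actions:
                next_state, reward = step(state, action)
                candidates.append((total + reward, next_state, plan + [action]))
        # Prune to the highest-scoring partial plans.
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
    best_total, _, best_plan = max(beams, key=lambda c: c[0])
    return best_total, best_plan


# Hypothetical toy model: the state moves by `action`; reward is -|state - 3|.
def toy_step(state: float, action: float) -> Tuple[float, float]:
    next_state = state + action
    return next_state, -abs(next_state - 3.0)


total, plan = beam_search(0.0, [-1.0, 0.0, 1.0], toy_step, horizon=5, beam_width=4)
print(total, plan)  # -3.0 [1.0, 1.0, 1.0, 0.0, 0.0]
```

A larger beam_width keeps more alternatives alive at the cost of more model evaluations per step; a longer horizon scores plans over more future steps.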

Pretrained models

We have provided pretrained models for 16 datasets: {halfcheetah, hopper, walker2d, ant}-{expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}. Download them with ./pretrained.sh
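The brace notation above denotes a Cartesian product of environments and dataset splits, which is where the count of 16 comes from. A short sketch that expands it into the full list of dataset names:

```python
from itertools import product

environments = ["halfcheetah", "hopper", "walker2d", "ant"]
splits = ["expert-v2", "medium-expert-v2", "medium-v2", "medium-replay-v2"]

# {env, ...}-{split, ...} expands to every env-split pair: 4 x 4 = 16 names.
datasets = [f"{env}-{split}" for env, split in product(environments, splits)]
print(len(datasets))  # 16
print(datasets[:2])  # ['halfcheetah-expert-v2', 'halfcheetah-medium-expert-v2']
```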

The models will be saved in logs/$DATASET/gpt/pretrained. To plan with these models, refer to them using the gpt_loadpath flag:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--gpt_loadpath gpt/pretrained

pretrained.sh will also download 15 plans from each model, saved to logs/$DATASET/plans/pretrained. Read them with python plotting/read_results.py.
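The directory layout described above can be summarized with a small helper. The function names are hypothetical, and reading --gpt_loadpath as a path relative to logs/$DATASET is an assumption inferred from the example command (gpt/pretrained versus logs/$DATASET/gpt/pretrained), not something the scripts are documented to do.

```python
import os


def pretrained_dirs(dataset, logbase="logs"):
    """Hypothetical helper: directories that pretrained.sh populates for `dataset`."""
    model_dir = os.path.join(logbase, dataset, "gpt", "pretrained")
    plans_dir = os.path.join(logbase, dataset, "plans", "pretrained")
    return model_dir, plans_dir


def resolve_loadpath(dataset, gpt_loadpath, logbase="logs"):
    """Assumption: --gpt_loadpath is interpreted relative to logs/$DATASET."""
    return os.path.normpath(os.path.join(logbase, dataset, gpt_loadpath))


model_dir, plans_dir = pretrained_dirs("halfcheetah-medium-v2")
print(model_dir)  # logs/halfcheetah-medium-v2/gpt/pretrained (on POSIX)
print(resolve_loadpath("halfcheetah-medium-v2", "gpt/pretrained") == model_dir)  # True
```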

To create the table of offline RL results from the paper, run python plotting/table.py. This will print a table that can be copied into a LaTeX document:
\begin{table*}[h]
\centering
\small
\begin{tabular}{llrrrrrr}
\toprule
\multicolumn{1}{c}{\bf Dataset} & \multicolumn{1}{c}{\bf Environment} & \multicolumn{1}{c}{\bf BC} & \multicolumn{1}{c}{\bf MBOP} & \multicolumn{1}{c}{\bf BRAC} & \multicolumn{1}{c}{\bf CQL} & \multicolumn{1}{c}{\bf DT} & \multicolumn{1}{c}{\bf TT (Ours)} \\ 
\midrule
Medium-Expert & HalfCheetah & $59.9$ & $105.9$ & $41.9$ & $62.4$ & $86.8$ & $95.0$ \scriptsize{\raisebox{1pt}{$\pm 0.2$}} \\ 
Medium-Expert & Hopper & $79.6$ & $55.1$ & $0.9$ & $111.0$ & $107.6$ & $110.0$ \scriptsize{\raisebox{1pt}{$\pm 2.7$}} \\ 
Medium-Expert & Walker2d & $36.6$ & $70.2$ & $81.6$ & $98.7$ & $108.1$ & $101.9$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 
Medium-Expert & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $116.1$ \scriptsize{\raisebox{1pt}{$\pm 9.0$}} \\ 
\midrule
Medium & HalfCheetah & $43.1$ & $44.6$ & $46.3$ & $44.4$ & $42.6$ & $46.9$ \scriptsize{\raisebox{1pt}{$\pm 0.4$}} \\ 
Medium & Hopper & $63.9$ & $48.8$ & $31.3$ & $58.0$ & $67.6$ & $61.1$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 
Medium & Walker2d & $77.3$ & $41.0$ & $81.1$ & $79.2$ & $74.0$ & $79.0$ \scriptsize{\raisebox{1pt}{$\pm 2.8$}} \\ 
Medium & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $83.1$ \scriptsize{\raisebox{1pt}{$\pm 7.3$}} \\ 
\midrule
Medium-Replay & HalfCheetah & $4.3$ & $42.3$ & $47.7$ & $46.2$ & $36.6$ & $41.9$ \scriptsize{\raisebox{1pt}{$\pm 2.5$}} \\ 
Medium-Replay & Hopper & $27.6$ & $12.4$ & $0.6$ & $48.6$ & $82.7$ & $91.5$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 
Medium-Replay & Walker2d & $36.9$ & $9.7$ & $0.9$ & $26.7$ & $66.6$ & $82.6$ \scriptsize{\raisebox{1pt}{$\pm 6.9$}} \\ 
Medium-Replay & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $77.0$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 
\midrule
\multicolumn{2}{c}{\bf Average (without Ant)} & 47.7 & 47.8 & 36.9 & 63.9 & 74.7 & 78.9 \hspace{.6cm} \\ 
\multicolumn{2}{c}{\bf Average (all settings)} & $-$ & $-$ & $-$ & $-$ & $-$ & 82.2 \hspace{.6cm} \\ 
\bottomrule
\end{tabular}
\label{table:d4rl}
\end{table*}
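The two Average rows are consistent with unweighted means over the per-setting scores. A quick check using the TT (Ours) column copied from the table (standard errors omitted):

```python
# TT (Ours) scores copied from the table, keyed by (dataset, environment).
tt_scores = {
    ("Medium-Expert", "HalfCheetah"): 95.0,
    ("Medium-Expert", "Hopper"): 110.0,
    ("Medium-Expert", "Walker2d"): 101.9,
    ("Medium-Expert", "Ant"): 116.1,
    ("Medium", "HalfCheetah"): 46.9,
    ("Medium", "Hopper"): 61.1,
    ("Medium", "Walker2d"): 79.0,
    ("Medium", "Ant"): 83.1,
    ("Medium-Replay", "HalfCheetah"): 41.9,
    ("Medium-Replay", "Hopper"): 91.5,
    ("Medium-Replay", "Walker2d"): 82.6,
    ("Medium-Replay", "Ant"): 77.0,
}


def mean(values):
    values = list(values)
    return sum(values) / len(values)


# "Average (without Ant)" uses the 9 non-Ant settings; "all settings" uses all 12.
avg_without_ant = mean(s for (_, env), s in tt_scores.items() if env != "Ant")
avg_all = mean(tt_scores.values())
print(f"{avg_without_ant:.1f} {avg_all:.1f}")  # 78.9 82.2
```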

To create the average performance plot, run python plotting/plot.py.

Docker

Copy your MuJoCo key to the Docker build context and build the container:

cp ~/.mujoco/mjkey.txt azure/files/
docker build -f azure/Dockerfile . -t trajectory

Test the container:

docker run -it --rm --gpus all \
	--mount type=bind,source=$PWD,target=/home/code \
	--mount type=bind,source=$HOME/.d4rl,target=/root/.d4rl \
	trajectory \
	bash -c \
	"export PYTHONPATH=$PYTHONPATH:/home/code && \
	python /home/code/scripts/train.py --dataset hopper-medium-expert-v2 --exp_name docker/"

Running on Azure

Setup

  1. Launching jobs on Azure requires one more python dependency:
pip install git+https://github.com/JannerM/[email protected]
  2. Tag the image built in the previous section and push it to Docker Hub:
export DOCKER_USERNAME=$(docker info | sed '/Username:/!d;s/.* //')
docker tag trajectory ${DOCKER_USERNAME}/trajectory:latest
docker image push ${DOCKER_USERNAME}/trajectory
  3. Update azure/config.py, either by modifying the file directly or setting the relevant environment variables. To set the AZURE_STORAGE_CONNECTION variable, navigate to the Access keys section of your storage account. Click Show keys and copy the Connection string.

  4. Download azcopy: ./azure/download.sh

Usage

Launch training jobs with python azure/launch_train.py and planning jobs with python azure/launch_plan.py.

These scripts do not take runtime arguments. Instead, they run the corresponding scripts (scripts/train.py and scripts/plan.py, respectively) using the Cartesian product of the parameters in params_to_sweep.
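A sweep over a Cartesian product can be sketched as follows. The structure of params_to_sweep shown here (a dictionary mapping flag names to lists of values) and the helper names are assumptions for illustration; the real launch scripts may define the sweep differently, and job submission to Azure is not shown.

```python
from itertools import product

# Hypothetical sweep; the real params_to_sweep in the launch scripts may differ.
params_to_sweep = {
    "dataset": ["halfcheetah-medium-v2", "hopper-medium-v2"],
    "horizon": [5, 10],
}


def expand_sweep(params):
    """Return one config dict per element of the Cartesian product of the value lists."""
    keys = list(params)
    return [dict(zip(keys, values)) for values in product(*params.values())]


def to_command(script, config):
    """Render a config as a command line with one --flag per key."""
    flags = " ".join(f"--{key} {value}" for key, value in config.items())
    return f"python {script} {flags}"


for config in expand_sweep(params_to_sweep):
    print(to_command("scripts/plan.py", config))
# python scripts/plan.py --dataset halfcheetah-medium-v2 --horizon 5
# python scripts/plan.py --dataset halfcheetah-medium-v2 --horizon 10
# python scripts/plan.py --dataset hopper-medium-v2 --horizon 5
# python scripts/plan.py --dataset hopper-medium-v2 --horizon 10
```

With 2 datasets and 2 horizons the sweep launches 4 jobs; each additional parameter multiplies the job count by the length of its value list.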

Viewing results

To rsync the results from the Azure storage container, run ./azure/sync.sh.

To mount the storage container:

  1. Create a blobfuse config with ./azure/make_fuse_config.sh
  2. Run ./azure/mount.sh to mount the storage container to ~/azure_mount

To unmount the container, run sudo umount -f ~/azure_mount; rm -r ~/azure_mount

Reference

@article{janner2021sequence,
  title={Reinforcement Learning as One Big Sequence Modeling Problem},
  author={Michael Janner and Qiyang Li and Sergey Levine},
  journal={arXiv preprint arXiv:2106.02039},
  year={2021},
}

Acknowledgements

The GPT implementation is from Andrej Karpathy's minGPT repo.
