Code for the paper "Offline Reinforcement Learning as One Big Sequence Modeling Problem"

Overview

Trajectory Transformer

Code release for the paper "Offline Reinforcement Learning as One Big Sequence Modeling Problem".

Installation

All Python dependencies are listed in environment.yml. Install them with:

conda env create -f environment.yml
conda activate trajectory
pip install -e .

For reproducibility, we have also included system requirements in a Dockerfile (see the Docker section), but the conda installation should work on most standard Linux machines.

Usage

Train a transformer with: python scripts/train.py --dataset halfcheetah-medium-v2

To reproduce the offline RL results: python scripts/plan.py --dataset halfcheetah-medium-v2

By default, these commands will use the hyperparameters in config/offline.py. You can override them with runtime flags:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--horizon 5 --beam_width 32
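
The override behaviour described above can be sketched with the standard argparse module. This is only an illustration of "defaults plus runtime flags" and may not match how the repository's own parser is implemented; the default values below are placeholders, not the contents of config/offline.py, and parse_overrides is a hypothetical helper.

```python
import argparse

# Placeholder defaults for illustration only; the real values live in config/offline.py.
DEFAULTS = {"horizon": 15, "beam_width": 128}


def parse_overrides(argv, defaults):
    """Hypothetical helper: start from `defaults`, let command-line flags override them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    for name, value in defaults.items():
        # Each default becomes an optional flag of the same type.
        parser.add_argument(f"--{name}", type=type(value), default=value)
    return vars(parser.parse_args(argv))


# Flags given at runtime replace the defaults.
overridden = parse_overrides(
    ["--dataset", "halfcheetah-medium-v2", "--horizon", "5", "--beam_width", "32"],
    DEFAULTS,
)
# Flags that are omitted fall back to the defaults.
untouched = parse_overrides(["--dataset", "halfcheetah-medium-v2"], DEFAULTS)

print(overridden)
print(untouched)
```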

A few hyperparameters differ from those listed in the paper because of changes to the discretization strategy. They will be updated in the next arXiv version to match what is currently in the codebase.

Pretrained models

We have provided pretrained models for 16 datasets: {halfcheetah, hopper, walker2d, ant}-{expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}. Download them with ./pretrained.sh
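
The brace notation above stands for every pairing of the four environments with the four dataset variants. A minimal Python sketch of that expansion (the variable names are illustrative and not taken from the repository):

```python
from itertools import product

# Environments and dataset variants exactly as listed in the text above.
environments = ["halfcheetah", "hopper", "walker2d", "ant"]
variants = ["expert-v2", "medium-expert-v2", "medium-v2", "medium-replay-v2"]

# Pair every environment with every variant: 4 x 4 = 16 names.
dataset_names = [f"{env}-{variant}" for env, variant in product(environments, variants)]

for name in dataset_names:
    print(name)
```

Any of these names can be passed to the --dataset flag, for example halfcheetah-medium-v2.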

The models will be saved in logs/$DATASET/gpt/pretrained. To plan with these models, refer to them using the gpt_loadpath flag:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--gpt_loadpath gpt/pretrained
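
Judging from the paths above, the value given to gpt_loadpath appears to be interpreted relative to logs/$DATASET. The sketch below only illustrates that path arithmetic; pretrained_model_dir is a hypothetical helper, and the assumption that the loader joins the parts this way is not confirmed by the repository.

```python
from pathlib import PurePosixPath


def pretrained_model_dir(dataset, gpt_loadpath="gpt/pretrained", logbase="logs"):
    """Hypothetical helper: join the log root, the dataset name, and the load path."""
    return PurePosixPath(logbase) / dataset / gpt_loadpath


model_dir = pretrained_model_dir("halfcheetah-medium-v2")
print(model_dir)
```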

pretrained.sh will also download 15 plans from each model, saved to logs/$DATASET/plans/pretrained. Read them with python plotting/read_results.py.

To create the table of offline RL results from the paper, run python plotting/table.py. This will print a table that can be copied into a LaTeX document:

\begin{table*}[h]
\centering
\small
\begin{tabular}{llrrrrrr}
\toprule
\multicolumn{1}{c}{\bf Dataset} & \multicolumn{1}{c}{\bf Environment} & \multicolumn{1}{c}{\bf BC} & \multicolumn{1}{c}{\bf MBOP} & \multicolumn{1}{c}{\bf BRAC} & \multicolumn{1}{c}{\bf CQL} & \multicolumn{1}{c}{\bf DT} & \multicolumn{1}{c}{\bf TT (Ours)} \\
\midrule
Medium-Expert & HalfCheetah & $59.9$ & $105.9$ & $41.9$ & $91.6$ & $86.8$ & $95.0$ \scriptsize{\raisebox{1pt}{$\pm 0.2$}} \\
Medium-Expert & Hopper & $79.6$ & $55.1$ & $0.9$ & $105.4$ & $107.6$ & $110.0$ \scriptsize{\raisebox{1pt}{$\pm 2.7$}} \\
Medium-Expert & Walker2d & $36.6$ & $70.2$ & $81.6$ & $108.8$ & $108.1$ & $101.9$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\
Medium-Expert & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $116.1$ \scriptsize{\raisebox{1pt}{$\pm 9.0$}} \\
\midrule
Medium & HalfCheetah & $43.1$ & $44.6$ & $46.3$ & $44.0$ & $42.6$ & $46.9$ \scriptsize{\raisebox{1pt}{$\pm 0.4$}} \\
Medium & Hopper & $63.9$ & $48.8$ & $31.3$ & $58.5$ & $67.6$ & $61.1$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\
Medium & Walker2d & $77.3$ & $41.0$ & $81.1$ & $72.5$ & $74.0$ & $79.0$ \scriptsize{\raisebox{1pt}{$\pm 2.8$}} \\
Medium & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $83.1$ \scriptsize{\raisebox{1pt}{$\pm 7.3$}} \\
\midrule
Medium-Replay & HalfCheetah & $4.3$ & $42.3$ & $47.7$ & $45.5$ & $36.6$ & $41.9$ \scriptsize{\raisebox{1pt}{$\pm 2.5$}} \\
Medium-Replay & Hopper & $27.6$ & $12.4$ & $0.6$ & $95.0$ & $82.7$ & $91.5$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\
Medium-Replay & Walker2d & $36.9$ & $9.7$ & $0.9$ & $77.2$ & $66.6$ & $82.6$ \scriptsize{\raisebox{1pt}{$\pm 6.9$}} \\
Medium-Replay & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $77.0$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\
\midrule
\multicolumn{2}{c}{\bf Average (without Ant)} & 47.7 & 47.8 & 36.9 & 77.6 & 74.7 & 78.9 \hspace{.6cm} \\
\multicolumn{2}{c}{\bf Average (all settings)} & $-$ & $-$ & $-$ & $-$ & $-$ & 82.2 \hspace{.6cm} \\
\bottomrule
\end{tabular}
\label{table:d4rl}
\end{table*}
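
The two average rows for the TT column can be checked by hand from the per-dataset scores. The sketch below copies the TT means from the table above and recomputes both averages; it is a sanity check written for this README and is not part of plotting/table.py.

```python
# TT (Ours) mean scores copied from the table above, keyed by (dataset, environment).
tt_scores = {
    ("Medium-Expert", "HalfCheetah"): 95.0,
    ("Medium-Expert", "Hopper"): 110.0,
    ("Medium-Expert", "Walker2d"): 101.9,
    ("Medium-Expert", "Ant"): 116.1,
    ("Medium", "HalfCheetah"): 46.9,
    ("Medium", "Hopper"): 61.1,
    ("Medium", "Walker2d"): 79.0,
    ("Medium", "Ant"): 83.1,
    ("Medium-Replay", "HalfCheetah"): 41.9,
    ("Medium-Replay", "Hopper"): 91.5,
    ("Medium-Replay", "Walker2d"): 82.6,
    ("Medium-Replay", "Ant"): 77.0,
}


def mean(values):
    """Arithmetic mean of an iterable of numbers."""
    values = list(values)
    return sum(values) / len(values)


# Nine settings remain once the three Ant rows are excluded.
without_ant = mean(score for (_, env), score in tt_scores.items() if env != "Ant")
# All twelve settings.
all_settings = mean(tt_scores.values())

print(f"Average (without Ant): {without_ant:.1f}")   # 78.9
print(f"Average (all settings): {all_settings:.1f}")  # 82.2
```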

To create the average performance plot, run python plotting/plot.py.

Docker

Copy your MuJoCo key to the Docker build context and build the container:

cp ~/.mujoco/mjkey.txt azure/files/
docker build -f azure/Dockerfile . -t trajectory

Test the container:

docker run -it --rm --gpus all \
	--mount type=bind,source=$PWD,target=/home/code \
	--mount type=bind,source=$HOME/.d4rl,target=/root/.d4rl \
	trajectory \
	bash -c \
	"export PYTHONPATH=$PYTHONPATH:/home/code && \
	python /home/code/scripts/train.py --dataset hopper-medium-expert-v2 --exp_name docker/"

Running on Azure

Setup

  1. Launching jobs on Azure requires one more Python dependency:
pip install git+https://github.com/JannerM/[email protected]
  2. Tag the image built in the previous section and push it to Docker Hub:
export DOCKER_USERNAME=$(docker info | sed '/Username:/!d;s/.* //')
docker tag trajectory ${DOCKER_USERNAME}/trajectory:latest
docker image push ${DOCKER_USERNAME}/trajectory
  3. Update azure/config.py, either by modifying the file directly or by setting the relevant environment variables. To set the AZURE_STORAGE_CONNECTION variable, navigate to the Access keys section of your storage account, click Show keys, and copy the Connection string.
  4. Download azcopy: ./azure/download.sh
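
The "edit the file or set an environment variable" choice in the configuration step is commonly implemented as an environment lookup with a fallback value. The sketch below shows that general pattern only; get_setting is a hypothetical helper, the fallback string is a placeholder, and azure/config.py may read its values differently.

```python
import os


def get_setting(name, default, environ=os.environ):
    """Return the environment variable `name` if it is set, otherwise `default`."""
    return environ.get(name, default)


# Placeholder fallback; this is not a real connection string.
connection = get_setting("AZURE_STORAGE_CONNECTION", "<connection string>")

# The variable name is printed, not the value, to avoid leaking credentials.
print("AZURE_STORAGE_CONNECTION is", "set" if connection != "<connection string>" else "not set")
```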

Usage

Launch training jobs with python azure/launch_train.py and planning jobs with python azure/launch_plan.py.

These scripts do not take runtime arguments. Instead, they run the corresponding scripts (scripts/train.py and scripts/plan.py, respectively) using the Cartesian product of the parameters in params_to_sweep.
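
To make "Cartesian product of the parameters" concrete: every value of each swept parameter is combined with every value of the others, and one job is launched per combination. The sketch below uses made-up sweep values and a hypothetical expand_sweep helper; the actual params_to_sweep structure in the launch scripts may differ.

```python
from itertools import product

# Hypothetical sweep for illustration; the real one is defined in the launch scripts.
params_to_sweep = {
    "dataset": ["halfcheetah-medium-v2", "hopper-medium-v2"],
    "horizon": [5, 15],
}


def expand_sweep(params):
    """Return one dict per combination of the swept parameter values."""
    keys = list(params)
    return [dict(zip(keys, values)) for values in product(*(params[k] for k in keys))]


jobs = expand_sweep(params_to_sweep)

# Two datasets x two horizons = four jobs.
for job in jobs:
    flags = " ".join(f"--{key} {value}" for key, value in job.items())
    print(f"python scripts/plan.py {flags}")
```

The number of jobs is the product of the list lengths, so adding a third parameter with three values would triple the job count.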

Viewing results

To rsync the results from the Azure storage container, run ./azure/sync.sh.

To mount the storage container:

  1. Create a blobfuse config with ./azure/make_fuse_config.sh
  2. Run ./azure/mount.sh to mount the storage container to ~/azure_mount

To unmount the container, run sudo umount -f ~/azure_mount; rm -r ~/azure_mount

Reference

@inproceedings{janner2021sequence,
  title = {Offline Reinforcement Learning as One Big Sequence Modeling Problem},
  author = {Michael Janner and Qiyang Li and Sergey Levine},
  booktitle = {Advances in Neural Information Processing Systems},
  year = {2021},
}

Acknowledgements

The GPT implementation is from Andrej Karpathy's minGPT repo.
