Generalized Decision Transformer for Offline Hindsight Information Matching

Overview

Generalized Decision Transformer for Offline Hindsight Information Matching

[arxiv]

If you use this codebase for your research, please cite the paper:

@article{furuta2021generalized,
  title={Generalized Decision Transformer for Offline Hindsight Information Matching},
  author={Hiroki Furuta and Yutaka Matsuo and Shixiang Shane Gu},
  journal={arXiv preprint arXiv:2111.10364},
  year={2021}
}

Installation

Experiments require MuJoCo. Follow the instructions in the mujoco-py repo to install. Then, dependencies can be installed with the following command:

conda env create -f conda_env.yml

Downloading datasets

Datasets are stored in the data directory. Install the D4RL repo, following the instructions there. Then, run the following script in order to download the datasets and save them in our format:

python download_d4rl_datasets.py

Run experiments

Run train_cdt.py to train Categorical DT:

python train_cdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --condition 'reward' --save_model True

python train_cdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --condition 'xvel' --save_model True

Run eval_cdt.py to eval CDT using saved weights:

python eval_cdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --condition 'reward' --save_rollout True
python eval_cdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --condition 'xvel' --save_rollout True

For Bi-directional DT, run train_bdt.py & eval_bdtf.py

python train_bdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --z_dim 16 --save_model True
python eval_bdt.py --env halfcheetah --dataset medium-expert --gpu 0 --seed 0 --dist_dim 30 --n_bins 31 --z_dim 16 --save_rollout True

Reference

This repository is developed on top of original Decision Transformer.

Owner
Hiroki Furuta
The University of Tokyo/Reinforcement Learning
Hiroki Furuta
ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation

ADSPM: Attribute-Driven Spontaneous Motion in Unpaired Image Translation This repository provides a PyTorch implementation of ADSPM. Requirements Pyth

24 Jul 24, 2022
Self-Supervised Pre-Training for Transformer-Based Person Re-Identification

Self-Supervised Pre-Training for Transformer-Based Person Re-Identification [pdf] The official repository for Self-Supervised Pre-Training for Transfo

Hao Luo 116 Jan 04, 2023
PyTorch implementation for Convolutional Networks with Adaptive Inference Graphs

Convolutional Networks with Adaptive Inference Graphs (ConvNet-AIG) This repository contains a PyTorch implementation of the paper Convolutional Netwo

Andreas Veit 176 Dec 07, 2022
Videocaptioning.pytorch - A simple implementation of video captioning

pytorch implementation of video captioning recommend installing pytorch and pyth

Yiyu Wang 2 Jan 01, 2022
Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks

Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks This repository contains a TensorFlow implementation of "

Jingwei Zheng 5 Jan 08, 2023
Pneumonia Detection using machine learning - with PyTorch

Pneumonia Detection Pneumonia Detection using machine learning. Training was done in colab: DEMO: Result (Confusion Matrix): Data I uploaded my datase

Wilhelm Berghammer 12 Jul 07, 2022
This is project is the implementation of the DeepShift: Towards Multiplication-Less Neural Networks paper

DeepShift This is project is the implementation of the DeepShift: Towards Multiplication-Less Neural Networks paper, that aims to replace multiplicati

Mostafa Elhoushi 88 Dec 23, 2022
🌎 The Modern Declarative Data Flow Framework for the AI Empowered Generation.

🌎 JSONClasses JSONClasses is a declarative data flow pipeline and data graph framework. Official Website: https://www.jsonclasses.com Official Docume

Fillmula Inc. 53 Dec 09, 2022
Faune proche - Retrieval of Faune-France data near a google maps location

faune_proche Récupération des données de Faune-France près d'un lieu google maps

4 Feb 15, 2022
Tensorflow implementation of the paper "HumanGPS: Geodesic PreServing Feature for Dense Human Correspondences", CVPR 2021.

HumanGPS: Geodesic PreServing Feature for Dense Human Correspondences Tensorflow implementation of the paper "HumanGPS: Geodesic PreServing Feature fo

Google Interns 50 Dec 21, 2022
Code for paper Adaptively Aligned Image Captioning via Adaptive Attention Time

Adaptively Aligned Image Captioning via Adaptive Attention Time This repository includes the implementation for Adaptively Aligned Image Captioning vi

Lun Huang 45 Aug 27, 2022
Experimenting with computer vision techniques to generate annotated image datasets from gameplay recordings automatically.

Experimenting with computer vision techniques to generate annotated image datasets from gameplay recordings automatically. The collected data will then be used to train a deep neural network that can

Martin Valchev 3 Apr 24, 2022
Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification

This repo holds the codes of our paper: Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification, which is ac

Feng Gao 17 Dec 28, 2022
This is the official code of our paper "Diversity-based Trajectory and Goal Selection with Hindsight Experience Relay" (PRICAI 2021)

Diversity-based Trajectory and Goal Selection with Hindsight Experience Replay This is the official implementation of our paper "Diversity-based Traje

Tianhong Dai 6 Jul 18, 2022
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intenti

NVIDIA Corporation 6.9k Jan 03, 2023
This is the official pytorch implementation of AutoDebias, an automatic debiasing method for recommendation.

AutoDebias This is the official pytorch implementation of AutoDebias, a debiasing method for recommendation system. AutoDebias is proposed in the pape

Dong Hande 77 Nov 25, 2022
Contenido del curso Bases de datos del DCC PUC versión 2021-2

IIC2413 - Bases de Datos Tabla de contenidos Equipo Profesores Ayudantes Contenidos Calendario Evaluaciones Resumen de notas Foro Política de integrid

54 Nov 23, 2022
Classify the disease status of a plant given an image of a passion fruit

Passion Fruit Disease Detection I tried to create an accurate machine learning models capable of localizing and identifying multiple Passion Fruits in

3 Nov 09, 2021
Air Quality Prediction Using LSTM

AirQualityPredictionUsingLSTM In this Repo, i present to you the winning solution of smart gujarat hackathon 2019 where the task was to predict the qu

Deepak Nandwani 2 Dec 13, 2022