Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

Related tags

Deep Learningcql-jax
Overview

CQL-JAX

This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on top of the SAC base of JAX-RL.

Usage

Install Dependencies-

pip install -r requirements.txt
pip install "jax[cuda111]<=0.21.1" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run CQL-

python train_offline.py --env_name=hopper-expert-v0 --min_q_weight=5

Please use the following values of min_q_weight on MuJoCo tasks to reproduce CQL results from IQL paper-

Domain medium medium-replay medium-expert
walker 10 1 10
hopper 5 5 1
cheetah 90 80 100

For antmaze tasks min_q_weight=10 is found to work best.

In case of Out-Of Memory errors in JAX, try running with the following env variables-

XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python ...

Performance & Runtime

Returns are more or less same as the torch implementation and comparable to IQL-

Task CQL(PyTorch) CQL(JAX) IQL
hopper-medium-v2 58.5 74.6 66.3
hopper-medium-replay-v2 95.0 92.1 94.7
hopper-medium-expert-v2 105.4 83.2 91.5
antmaze-umaze-v0 74.0 69.5 87.5
antmaze-umaze-diverse-v0 84.0 78.7 62.2
antmaze-medium-play-v0 61.2 14.2 71.2
antmaze-medium-diverse-v0 53.7 10.7 70.2
antmaze-large-play-v0 15.8 0.0 39.6
antmaze-large-diverse-v0 14.9 0.0 47.5

Wall-clock time averages to ~50 mins, improving over IQL paper's 80 min CQL and closing the gap with IQL's 20 min.

Task CQL(JAX) IQL
hopper-medium-v2 52 27
hopper-medium-replay-v2 54 30
hopper-medium-expert-v2 57 29

Time efficiency over the original torch implementation is more than 4 times.

For more offline RL algorithm implementations, check out the JAX-RL, IQL and rlkit repositories.

Citation

In case you use CQL-JAX for your research, please cite the following-

@misc{cqljax,
  author = {Suri, Karush},
  title = {{Conservative Q Learning in JAX.}},
  url = {https://github.com/karush17/cql-jax},
  year = {2021}
}

References

Owner
Karush Suri
Deep Learning Researcher at Huawei Noah's Ark Lab, Toronto.
Karush Suri
PyTorch Autoencoders - Implementing a Variational Autoencoder (VAE) Series in Pytorch.

PyTorch Autoencoders Implementing a Variational Autoencoder (VAE) Series in Pytorch. Inspired by this repository Model List check model paper conferen

Subin An 8 Nov 21, 2022
Fair Recommendation in Two-Sided Platforms

Fair Recommendation in Two-Sided Platforms

gourabgggg 1 Nov 10, 2021
A python script to lookup Passport Index Dataset

visa-cli A python script to lookup Passport Index Dataset Installation pip install visa-cli Usage usage: visa-cli [-h] [-d DESTINATION_COUNTRY] [-f]

rand-net 16 Oct 18, 2022
STEAL - Learning Semantic Boundaries from Noisy Annotations (CVPR 2019)

STEAL This is the official inference code for: Devil Is in the Edges: Learning Semantic Boundaries from Noisy Annotations David Acuna, Amlan Kar, Sanj

469 Dec 26, 2022
Codes for NeurIPS 2021 paper "On the Equivalence between Neural Network and Support Vector Machine".

On the Equivalence between Neural Network and Support Vector Machine Codes for NeurIPS 2021 paper "On the Equivalence between Neural Network and Suppo

Leslie 8 Oct 25, 2022
Official code release for "GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis"

GRAF This repository contains official code for the paper GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis. You can find detailed usage i

349 Dec 29, 2022
🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

xmu-xiaoma66 7.7k Jan 05, 2023
Synthesize photos from PhotoDNA using machine learning 🌱

Ribosome Synthesize photos from PhotoDNA. See the blog post for more information. Installation Dependencies You can install Python dependencies using

Anish Athalye 112 Nov 23, 2022
Source codes for Improved Few-Shot Visual Classification (CVPR 2020), Enhancing Few-Shot Image Classification with Unlabelled Examples

Source codes for Improved Few-Shot Visual Classification (CVPR 2020), Enhancing Few-Shot Image Classification with Unlabelled Examples (WACV 2022) and Beyond Simple Meta-Learning: Multi-Purpose Model

PLAI Group at UBC 42 Dec 06, 2022
Robotics environments

Robotics environments Details and documentation on these robotics environments are available in OpenAI's blog post and the accompanying technical repo

Farama Foundation 121 Dec 28, 2022
This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021.

Off-Belief Learning Introduction This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021. Environment Setup

Facebook Research 32 Jan 05, 2023
SOTA easy to use PyTorch-based DL training library

Easily train or fine-tune SOTA computer vision models from one training repository. SuperGradients Introduction Welcome to SuperGradients, a free open

619 Jan 03, 2023
ICCV2021 Paper: AutoShape: Real-Time Shape-Aware Monocular 3D Object Detection

ICCV2021 Paper: AutoShape: Real-Time Shape-Aware Monocular 3D Object Detection

Zongdai 107 Dec 20, 2022
Official PyTorch implementation of BlobGAN: Spatially Disentangled Scene Representations

BlobGAN: Spatially Disentangled Scene Representations Official PyTorch Implementation Paper | Project Page | Video | Interactive Demo BlobGAN.mp4 This

148 Dec 29, 2022
Code related to the manuscript "Averting A Crisis In Simulation-Based Inference"

Abstract We present extensive empirical evidence showing that current Bayesian simulation-based inference algorithms are inadequate for the falsificat

Montefiore Artificial Intelligence Research 3 Nov 14, 2022
Gym environment for FLIPIT: The Game of "Stealthy Takeover"

gym-flipit Gym environment for FLIPIT: The Game of "Stealthy Takeover" invented by Marten van Dijk, Ari Juels, Alina Oprea, and Ronald L. Rivest. Desi

Lisa Oakley 2 Dec 15, 2021
Code for the paper BERT might be Overkill: A Tiny but Effective Biomedical Entity Linker based on Residual Convolutional Neural Networks

Biomedical Entity Linking This repo provides the code for the paper BERT might be Overkill: A Tiny but Effective Biomedical Entity Linker based on Res

Tuan Manh Lai 24 Oct 24, 2022
AniGAN: Style-Guided Generative Adversarial Networks for Unsupervised Anime Face Generation

AniGAN: Style-Guided Generative Adversarial Networks for Unsupervised Anime Face Generation AniGAN: Style-Guided Generative Adversarial Networks for U

Bing Li 81 Dec 14, 2022
[SDM 2022] Towards Similarity-Aware Time-Series Classification

SimTSC This is the PyTorch implementation of SDM2022 paper Towards Similarity-Aware Time-Series Classification. We propose Similarity-Aware Time-Serie

Daochen Zha 49 Dec 27, 2022
A PyTorch implementation of Learning to learn by gradient descent by gradient descent

Intro PyTorch implementation of Learning to learn by gradient descent by gradient descent. Run python main.py TODO Initial implementation Toy data LST

Ilya Kostrikov 300 Dec 11, 2022