PyTorch implementation of Advantage Actor Critic (A2C), Proximal Policy Optimization (PPO), Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR) and Generative Adversarial Imitation Learning (GAIL).

pytorch-a2c-ppo-acktr

Update (April 12th, 2021)

PPO is great, but Soft Actor Critic can be better for many continuous control tasks. Please check out my new RL repository in jax.

Please use the hyperparameters from this README. With other hyperparameters things might not work (it's RL after all)!

This is a PyTorch implementation of

  • Advantage Actor Critic (A2C), a synchronous deterministic version of A3C
  • Proximal Policy Optimization (PPO)
  • Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR)
  • Generative Adversarial Imitation Learning (GAIL)

Also see the OpenAI posts: A2C/ACKTR and PPO for more information.

This implementation is inspired by the OpenAI baselines for A2C, ACKTR and PPO. It uses the same hyperparameters and the same model, since they were well tuned for Atari games.

Please use this bibtex if you want to cite this repository in your publications:

@misc{pytorchrl,
  author = {Kostrikov, Ilya},
  title = {PyTorch Implementations of Reinforcement Learning Algorithms},
  year = {2018},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail}},
}

Supported (and tested) environments (via OpenAI Gym)

I highly recommend PyBullet as a free open source alternative to MuJoCo for continuous control tasks.

All environments are operated using exactly the same Gym interface. See their documentation for a comprehensive list.

To use the DeepMind Control Suite environments, set the flag --env-name dm.<domain_name>.<task_name>, where domain_name and task_name are the name of a domain (e.g. hopper) and a task within that domain (e.g. stand) from the DeepMind Control Suite. Refer to their repo and their tech report for a full list of available domains and tasks. Other than setting the task, the API for interacting with the environment is exactly the same as for all the Gym environments thanks to dm_control2gym.
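As a rough illustration of the naming scheme above, the sketch below splits a dm.<domain_name>.<task_name> string into its two parts. The helper name parse_dm_env_name is hypothetical and is not taken from this repository; the sketch only shows the string convention and does not create an environment.

```python
def parse_dm_env_name(env_name):
    """Split 'dm.<domain_name>.<task_name>' into (domain_name, task_name).

    Hypothetical helper for illustration only; not part of this repository.
    """
    parts = env_name.split(".")
    if len(parts) != 3 or parts[0] != "dm" or not all(parts[1:]):
        raise ValueError(
            f"expected 'dm.<domain_name>.<task_name>', got {env_name!r}"
        )
    return parts[1], parts[2]


if __name__ == "__main__":
    domain_name, task_name = parse_dm_env_name("dm.hopper.stand")
    print(domain_name, task_name)
```

Any name that does not follow the dm. prefix convention, such as a regular Gym environment id, is rejected by this sketch with a ValueError.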

Requirements

To install the requirements, run:

# PyTorch
conda install pytorch torchvision -c soumith

# Other requirements
pip install -r requirements.txt

Contributions

Contributions are very welcome. If you know how to make this code better, please open an issue. If you want to submit a pull request, please open an issue first. Also see the TODO list below.

Also I'm searching for volunteers to run all experiments on Atari and MuJoCo (with multiple random seeds).

Disclaimer

It's extremely difficult to reproduce results for Reinforcement Learning methods. See "Deep Reinforcement Learning that Matters" for more information. I tried to reproduce OpenAI results as closely as possible. However, major differences in performance can be caused even by minor differences between the TensorFlow and PyTorch libraries.

TODO

  • Improve this README file. Rearrange images.
  • Improve performance of KFAC, see kfac.py for more information
  • Run evaluation for all games and algorithms

Visualization

To visualize the results, use visualize.ipynb.

Training

Atari

A2C

python main.py --env-name "PongNoFrameskip-v4"

PPO

python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01
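To make the numbers in the PPO command above more concrete, here is a small sketch of the arithmetic they imply. It assumes that one update collects num_processes * num_steps transitions, that this batch is split evenly into num_mini_batch mini-batches, and that --use-linear-lr-decay anneals the learning rate linearly to zero over training. The function names are hypothetical, and the exact schedule in the code may differ.

```python
def rollout_sizes(num_processes, num_steps, num_mini_batch):
    """Return (transitions per update, transitions per mini-batch)."""
    batch_size = num_processes * num_steps
    mini_batch_size = batch_size // num_mini_batch
    return batch_size, mini_batch_size


def linear_lr(initial_lr, update, total_updates):
    """Learning rate after `update` of `total_updates` updates (assumed schedule)."""
    return initial_lr * (1.0 - update / total_updates)


if __name__ == "__main__":
    # Values from the Atari PPO command: 8 processes, 128 steps, 4 mini-batches.
    print(rollout_sizes(num_processes=8, num_steps=128, num_mini_batch=4))
    # Learning rate halfway through training, starting from 2.5e-4.
    print(linear_lr(2.5e-4, update=50, total_updates=100))
```

Under these assumptions, each Atari PPO update uses 1024 transitions in mini-batches of 256.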

ACKTR

python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 --num-steps 20

MuJoCo

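Please always try to use the --use-proper-time-limits flag. It properly handles partial trajectories, that is, episodes that were cut off by the environment's time limit rather than ended by a true terminal state (see https://github.com/sfujim/TD3/blob/master/main.py#L123).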
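The idea behind handling partial trajectories can be sketched as follows. This is a minimal illustration and not the repository's code: it assumes each step carries two separate flags, terminated (a true terminal state) and truncated (cut off by the time limit), and that a truncated step uses the critic's own value estimate as its return target instead of treating the cutoff as a real episode end. All names are hypothetical.

```python
def compute_returns(rewards, values, next_value, terminated, truncated, gamma=0.99):
    """Discounted return targets that distinguish time-limit cutoffs from terminals.

    rewards, values, terminated, truncated: per-step lists of equal length.
    next_value: value estimate for the state after the last step.
    """
    returns = [0.0] * len(rewards)
    running = next_value
    for t in reversed(range(len(rewards))):
        if truncated[t]:
            # Time limit hit: the true future return is unknown, so fall back
            # to the critic's estimate for this step (assumed convention).
            returns[t] = values[t]
        elif terminated[t]:
            # Genuine terminal state: no future reward after this step.
            returns[t] = rewards[t]
        else:
            returns[t] = rewards[t] + gamma * running
        running = returns[t]
    return returns


if __name__ == "__main__":
    print(compute_returns(
        rewards=[1.0, 1.0, 1.0],
        values=[0.5, 0.5, 0.5],
        next_value=2.0,
        terminated=[False, False, False],
        truncated=[False, True, False],
        gamma=0.5,
    ))
```

Without this distinction, a time-limit cutoff would be treated like a failure state with zero future value, which biases the value targets for tasks where the episode could have continued.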

A2C

python main.py --env-name "Reacher-v2" --num-env-steps 1000000

PPO

python main.py --env-name "Reacher-v2" --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 1000000 --use-linear-lr-decay --use-proper-time-limits
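The --use-gae, --gamma and --gae-lambda flags in the command above refer to Generalized Advantage Estimation. A minimal pure-Python sketch of the standard GAE recursion is shown below; it is written for a single environment and is not the repository's implementation. The function name and argument layout are assumptions made for illustration.

```python
def gae_advantages(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one trajectory segment.

    rewards, values, dones: per-step lists of equal length.
    next_value: value estimate for the state after the last step.
    """
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    for t in reversed(range(n)):
        v_next = next_value if t == n - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; the bootstrap term is dropped at episode ends.
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode ends.
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    return advantages


if __name__ == "__main__":
    adv = gae_advantages(
        rewards=[1.0, 1.0],
        values=[0.0, 0.0],
        next_value=0.0,
        dones=[False, False],
    )
    print(adv)
```

With lam=1 this reduces to the discounted return minus the value baseline, and with lam=0 to the one-step TD error; the value 0.95 used in the command sits between the two.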

ACKTR

ACKTR requires some modifications to be made specifically for MuJoCo. But at the moment, I want to keep this code as unified as possible. Thus, I'm looking for better ways to integrate it into the codebase.

Enjoy

Load a pretrained model from my Google Drive.

Pretrained models for other games are also available on request. Send me an email or create an issue, and I will upload them.

Disclaimer: I might have used different hyperparameters to train these models.

Atari

python enjoy.py --load-dir trained_models/a2c --env-name "PongNoFrameskip-v4"

MuJoCo

python enjoy.py --load-dir trained_models/ppo --env-name "Reacher-v2"

Results

A2C

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

BeamRiderNoFrameskip-v4

PPO

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

BeamRiderNoFrameskip-v4

ACKTR

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

BeamRiderNoFrameskip-v4
