PyTorch implementation of Advantage Actor Critic (A2C), Proximal Policy Optimization (PPO), Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR) and Generative Adversarial Imitation Learning (GAIL).

Overview

pytorch-a2c-ppo-acktr

Update (April 12th, 2021)

PPO is great, but Soft Actor Critic can be better for many continuous control tasks. Please check out my new RL repository in jax.

Please use hyper parameters from this readme. With other hyper parameters things might not work (it's RL after all)!

This is a PyTorch implementation of

  • Advantage Actor Critic (A2C), a synchronous deterministic version of A3C
  • Proximal Policy Optimization PPO
  • Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation ACKTR
  • Generative Adversarial Imitation Learning GAIL

Also see the OpenAI posts: A2C/ACKTR and PPO for more information.

This implementation is inspired by the OpenAI baselines for A2C, ACKTR and PPO. It uses the same hyper parameters and the model since they were well tuned for Atari games.

Please use this bibtex if you want to cite this repository in your publications:

@misc{pytorchrl,
  author = {Kostrikov, Ilya},
  title = {PyTorch Implementations of Reinforcement Learning Algorithms},
  year = {2018},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail}},
}

Supported (and tested) environments (via OpenAI Gym)

I highly recommend PyBullet as a free open source alternative to MuJoCo for continuous control tasks.

All environments are operated using exactly the same Gym interface. See their documentations for a comprehensive list.

To use the DeepMind Control Suite environments, set the flag --env-name dm.<domain_name>.<task_name>, where domain_name and task_name are the name of a domain (e.g. hopper) and a task within that domain (e.g. stand) from the DeepMind Control Suite. Refer to their repo and their tech report for a full list of available domains and tasks. Other than setting the task, the API for interacting with the environment is exactly the same as for all the Gym environments thanks to dm_control2gym.

Requirements

In order to install requirements, follow:

# PyTorch
conda install pytorch torchvision -c soumith

# Other requirements
pip install -r requirements.txt

Contributions

Contributions are very welcome. If you know how to make this code better, please open an issue. If you want to submit a pull request, please open an issue first. Also see a todo list below.

Also I'm searching for volunteers to run all experiments on Atari and MuJoCo (with multiple random seeds).

Disclaimer

It's extremely difficult to reproduce results for Reinforcement Learning methods. See "Deep Reinforcement Learning that Matters" for more information. I tried to reproduce OpenAI results as closely as possible. However, majors differences in performance can be caused even by minor differences in TensorFlow and PyTorch libraries.

TODO

  • Improve this README file. Rearrange images.
  • Improve performance of KFAC, see kfac.py for more information
  • Run evaluation for all games and algorithms

Visualization

In order to visualize the results use visualize.ipynb.

Training

Atari

A2C

python main.py --env-name "PongNoFrameskip-v4"

PPO

python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01

ACKTR

python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 --num-steps 20

MuJoCo

Please always try to use --use-proper-time-limits flag. It properly handles partial trajectories (see https://github.com/sfujim/TD3/blob/master/main.py#L123).

A2C

python main.py --env-name "Reacher-v2" --num-env-steps 1000000

PPO

python main.py --env-name "Reacher-v2" --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 1000000 --use-linear-lr-decay --use-proper-time-limits

ACKTR

ACKTR requires some modifications to be made specifically for MuJoCo. But at the moment, I want to keep this code as unified as possible. Thus, I'm going for better ways to integrate it into the codebase.

Enjoy

Load a pretrained model from my Google Drive.

Also pretrained models for other games are available on request. Send me an email or create an issue, and I will upload it.

Disclaimer: I might have used different hyper-parameters to train these models.

Atari

python enjoy.py --load-dir trained_models/a2c --env-name "PongNoFrameskip-v4"

MuJoCo

python enjoy.py --load-dir trained_models/ppo --env-name "Reacher-v2"

Results

A2C

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

beamriderNoFrameskip-v4

PPO

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

beamriderNoFrameskip-v4

ACKTR

BreakoutNoFrameskip-v4

SeaquestNoFrameskip-v4

QbertNoFrameskip-v4

beamriderNoFrameskip-v4

Owner
Ilya Kostrikov
Post doc
Ilya Kostrikov
PyTorch Implementation of Vector Quantized Variational AutoEncoders.

Pytorch implementation of VQVAE. This paper combines 2 tricks: Vector Quantization (check out this amazing blog for better understanding.) Straight-Th

Vrushank Changawala 2 Oct 06, 2021
This program writes christmas wish programmatically. It is using turtle as a pen pointer draw christmas trees and stars.

Introduction This is a simple program is written in python and turtle library. The objective of this program is to wish merry Christmas programmatical

Gunarakulan Gunaretnam 1 Dec 25, 2021
Overview of architecture and implementation of TEDS-Net, as described in MICCAI 2021: "TEDS-Net: Enforcing Diffeomorphisms in Spatial Transformers to Guarantee TopologyPreservation in Segmentations"

TEDS-Net Overview of architecture and implementation of TEDS-Net, as described in MICCAI 2021: "TEDS-Net: Enforcing Diffeomorphisms in Spatial Transfo

Madeleine K Wyburd 14 Jan 04, 2023
OOD Dataset Curator and Benchmark for AI-aided Drug Discovery

🔥 DrugOOD 🔥 : OOD Dataset Curator and Benchmark for AI Aided Drug Discovery This is the official implementation of the DrugOOD project, this is the

108 Dec 17, 2022
Simple Python project using Opencv and datetime package to recognise faces and log attendance data in a csv file.

Attendance-System-based-on-Facial-recognition-Attendance-data-stored-in-csv-file- Simple Python project using Opencv and datetime package to recognise

3 Aug 09, 2022
Real time Human Detection Counting

In this python project, we are going to build the Human Detection and Counting System through Webcam or you can give your own video or images. This is a deep learning project on computer vision, whic

Mir Nawaz Ahmad 2 Jun 17, 2022
A Fast Monotone Rotating Shallow Water model

pyRSW A Fast Monotone Rotating Shallow Water model How fast? As fast as a sustained 2 Gflop/s per core on a 2.5 GHz cpu (or 2048 Gflop/s with 1024 cor

Guillaume Roullet 13 Sep 28, 2022
Modifications of the official PyTorch implementation of StyleGAN3. Let's easily generate images and videos with StyleGAN2/2-ADA/3!

Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation of the NeurIPS 2021 paper Alias-Free Generative Adversarial Net

Diego Porres 185 Dec 24, 2022
A simple, unofficial implementation of MAE using pytorch-lightning

Masked Autoencoders in PyTorch A simple, unofficial implementation of MAE (Masked Autoencoders are Scalable Vision Learners) using pytorch-lightning.

Connor Anderson 20 Dec 03, 2022
Gesture recognition on Event Data

Event based Gesture Recognition Gesture recognition on Event Data usually involv

2 Feb 14, 2022
Crowd-Kit is a powerful Python library that implements commonly-used aggregation methods for crowdsourced annotation and offers the relevant metrics and datasets

Crowd-Kit: Computational Quality Control for Crowdsourcing Documentation Crowd-Kit is a powerful Python library that implements commonly-used aggregat

Toloka 125 Dec 30, 2022
Builds a LoRa radio frequency fingerprint identification (RFFI) system based on deep learning techiniques

This project builds a LoRa radio frequency fingerprint identification (RFFI) system based on deep learning techiniques.

20 Dec 30, 2022
An End-to-End Machine Learning Library to Optimize AUC (AUROC, AUPRC).

Logo by Zhuoning Yuan LibAUC: A Machine Learning Library for AUC Optimization Website | Updates | Installation | Tutorial | Research | Github LibAUC a

Optimization for AI 176 Jan 07, 2023
NeuroFind - A solution to the to the Task given by the Oberseminar of Messtechnik Institute of TU Dresden in 2021

NeuroFind A solution to the to the Task given by the Oberseminar of Messtechnik

1 Jan 20, 2022
Object detection on multiple datasets with an automatically learned unified label space.

Simple multi-dataset detection An object detector trained on multiple large-scale datasets with a unified label space; Winning solution of E

Xingyi Zhou 407 Dec 30, 2022
Source code for "Pack Together: Entity and Relation Extraction with Levitated Marker"

PL-Marker Source code for Pack Together: Entity and Relation Extraction with Levitated Marker. Quick links Overview Setup Install Dependencies Data Pr

THUNLP 173 Dec 30, 2022
Simple helper library to convert a collection of numpy data to tfrecord, and build a tensorflow dataset from the tfrecord.

numpy2tfrecord Simple helper library to convert a collection of numpy data to tfrecord, and build a tensorflow dataset from the tfrecord. Installation

Ryo Yonetani 2 Jan 16, 2022
MGFN: Multi-Graph Fusion Networks for Urban Region Embedding was accepted by IJCAI-2022.

Multi-Graph Fusion Networks for Urban Region Embedding (IJCAI-22) This is the implementation of Multi-Graph Fusion Networks for Urban Region Embedding

202 Nov 18, 2022
DaReCzech is a dataset for text relevance ranking in Czech

Dataset DaReCzech is a dataset for text relevance ranking in Czech. The dataset consists of more than 1.6M annotated query-documents pairs,

Seznam.cz a.s. 8 Jul 26, 2022
Rust bindings for the C++ api of PyTorch.

tch-rs Rust bindings for the C++ api of PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorc

Laurent Mazare 2.3k Dec 30, 2022