Deep Reinforcement Learning with pytorch & visdom

Overview

Deep Reinforcement Learning with

pytorch & visdom


  • Sample testings of trained agents (DQN on Breakout, A3C on Pong, DoubleDQN on CartPole, continuous A3C on InvertedPendulum(MuJoCo)):
  • Sample on-line plotting while training an A3C agent on Pong (with 16 learner processes): a3c_pong_plot

  • Sample loggings while training a DQN agent on CartPole (we use WARNING as the logging level currently to get rid of the INFO printouts from visdom):

[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO    ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State  Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start  Training @ Step: 501
[WARNING ] (MainProcess) Reporting       @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972
[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0
[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348
[WARNING ] (MainProcess) Training Stats:   nepisodes:        115
[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating      @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved  Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500
...

What is included?

This repo currently contains the following agents:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Dueling network DQN (Dueling DQN) [4]
  • Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
  • Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support (Truncated Importance Sampling, 1st Order TRPO)) [7], [8]

Work in progress:

  • Testing ACER

Future Plans:

  • Deep Deterministic Policy Gradient (DDPG) [9], [10]
  • Continuous DQN (CDQN or NAF) [11]

Code structure & Naming conventions:

NOTE: we follow the exact code structure as pytorch-dnc so as to make the code easily transplantable.

  • ./utils/factory.py

We suggest the users refer to ./utils/factory.py, where we list all the integrated Env, Model, Memory, Agent into Dict's. All of those four core classes are implemented in ./core/. The factory pattern in ./utils/factory.py makes the code super clean, as no matter what type of Agent you want to train, or which type of Env you want to train on, all you need to do is to simply modify some parameters in ./utils/options.py, then the ./main.py will do it all (NOTE: this ./main.py file never needs to be modified).

  • namings

To make the code more clean and readable, we name the variables using the following pattern (mainly in inherited Agent's):

  • *_vb: torch.autograd.Variable's or a list of such objects
  • *_ts: torch.Tensor's or a list of such objects
  • otherwise: normal python datatypes

Dependencies


How to run:

You only need to modify some parameters in ./utils/options.py to train a new configuration.

  • Configure your training in ./utils/options.py:
  • line 14: add an entry into CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
  • line 33: choose the entry you just added
  • line 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP), the corresponding model file and the log file of this training will be saved under this signature (./models/MACHINE_TIMESTAMP.pth & ./logs/MACHINE_TIMESTAMP.log respectively). Also the visdom visualization will be displayed under this signature (first activate the visdom server by type in bash: python -m visdom.server &, then open this address in your browser: http://localhost:8097/env/MACHINE_TIMESTAMP)
  • line 32: to train a model, set mode=1 (training visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model of this current training, all you need to do is to set mode=2 (testing visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP_test).
  • Run:

python main.py


Bonus Scripts :)

We also provide 2 additional scripts for quickly evaluating your results after training. (Dependecies: lmj-plot)

  • plot.sh (e.g., plot from log file: logs/machine1_17080801.log)
  • ./plot.sh machine1 17080801
  • the generated figures will be saved into figs/machine1_17080801/
  • plot_compare.sh (e.g., compare log files: logs/machine1_17080801.log,logs/machine2_17080802.log)

./plot.sh 00 machine1 17080801 machine2 17080802

  • the generated figures will be saved into figs/compare_00/
  • the color coding will be in the order of: red green blue magenta yellow cyan

Repos we referred to during the development of this repo:


Citation

If you find this library useful and would like to cite it, the following would be appropriate:

@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}
Owner
Jingwei Zhang
Jingwei Zhang
PyTorch code for: Learning to Generate Grounded Visual Captions without Localization Supervision

Learning to Generate Grounded Visual Captions without Localization Supervision This is the PyTorch implementation of our paper: Learning to Generate G

Chih-Yao Ma 41 Nov 17, 2022
An implementation of the methods presented in Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

An implementation of the methods presented in Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

Andrew Jesson 9 Apr 04, 2022
A collection of educational notebooks on multi-view geometry and computer vision.

Multiview notebooks This is a collection of educational notebooks on multi-view geometry and computer vision. Subjects covered in these notebooks incl

Max 65 Dec 09, 2022
face_recognization (FaceNet) + TFHE (HNP) + hand_face_detection (Mediapipe)

SuperControlSystem Face_Recognization (FaceNet) 面部识别 (FaceNet) Fully Homomorphic Encryption over the Torus (HNP) 环面全同态加密 (TFHE) Hand_Face_Detection (M

liziyu0104 2 Dec 30, 2021
Api for getting bin info and getting encrypted card details for adyen.

Bin Info And Adyen Cse Enc Python api for getting bin info and getting encrypted

Roldex Stark 8 Dec 30, 2022
Reinforcement learning algorithms in RLlib

raylab Reinforcement learning algorithms in RLlib and PyTorch. Installation pip install raylab Quickstart Raylab provides agents and environments to b

Ângelo 50 Sep 08, 2022
Source code and dataset of the paper "Contrastive Adaptive Propagation Graph Neural Networks forEfficient Graph Learning"

CAPGNN Source code and dataset of the paper "Contrastive Adaptive Propagation Graph Neural Networks forEfficient Graph Learning" Paper URL: https://ar

1 Mar 12, 2022
Fast Neural Style for Image Style Transform by Pytorch

FastNeuralStyle by Pytorch Fast Neural Style for Image Style Transform by Pytorch This is famous Fast Neural Style of Paper Perceptual Losses for Real

Bengxy 81 Sep 03, 2022
Tello Drone Trajectory Tracking

With this library you can track the trajectory of your tello drone or swarm of drones in real time.

Kamran Asgarov 2 Oct 12, 2022
Visualizing Yolov5's layers using GradCam

YOLO-V5 GRADCAM I constantly desired to know to which part of an object the object-detection models pay more attention. So I searched for it, but I di

Pooya Mohammadi Kazaj 200 Jan 01, 2023
UMEC: Unified Model and Embedding Compression for Efficient Recommendation Systems

[ICLR 2021] "UMEC: Unified Model and Embedding Compression for Efficient Recommendation Systems" by Jiayi Shen, Haotao Wang*, Shupeng Gui*, Jianchao Tan, Zhangyang Wang, and Ji Liu

VITA 39 Dec 03, 2022
OBBDetection: an oriented object detection toolbox modified from MMdetection

OBBDetection note: If you have questions or good suggestions, feel free to propose issues and contact me. introduction OBBDetection is an oriented obj

MIXIAOXIN_HO 3 Nov 11, 2022
Editing a Conditional Radiance Field

Editing Conditional Radiance Fields Project | Paper | Video | Demo Editing Conditional Radiance Fields Steven Liu, Xiuming Zhang, Zhoutong Zhang, Rich

Steven Liu 216 Dec 30, 2022
RuDOLPH: One Hyper-Modal Transformer can be creative as DALL-E and smart as CLIP

[Paper] [Хабр] [Model Card] [Colab] [Kaggle] RuDOLPH 🦌 🎄 ☃️ One Hyper-Modal Tr

Sber AI 230 Dec 31, 2022
Differentiable rasterization applied to 3D model simplification tasks

nvdiffmodeling Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Automatic 3D Model

NVIDIA Research Projects 336 Dec 30, 2022
This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack".

Generative Dynamic Patch Attack This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack". Requirements PyTo

Xiang Li 8 Nov 17, 2022
Vision-Language Transformer and Query Generation for Referring Segmentation (ICCV 2021)

Vision-Language Transformer and Query Generation for Referring Segmentation Please consider citing our paper in your publications if the project helps

Henghui Ding 143 Dec 23, 2022
Notebook and code to synthesize complex and highly dimensional datasets using Gretel APIs.

Gretel Trainer This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code w

Gretel.ai 24 Nov 03, 2022
[ICLR2021] Unlearnable Examples: Making Personal Data Unexploitable

Unlearnable Examples Code for ICLR2021 Spotlight Paper "Unlearnable Examples: Making Personal Data Unexploitable " by Hanxun Huang, Xingjun Ma, Sarah

Hanxun Huang 98 Dec 07, 2022
Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation

OSCAR Project Page | Paper This repository contains the codebase used in OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Ma

NVIDIA Research Projects 74 Dec 22, 2022