PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

Overview

neural-combinatorial-rl-pytorch

PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

I have implemented the basic RL pretraining model with greedy decoding from the paper. An implementation of the supervised learning baseline model is available here. Instead of a critic network, I got my results below on TSP from using an exponential moving average critic. The critic network is simply commented out in my code right now. From correspondence with a few others, it was determined that the exponential moving average critic significantly helped improve results.

My implementation uses a stochastic decoding policy in the pointer network, realized via PyTorch's torch.multinomial(), during training, and beam search (not yet finished, only supports 1 beam a.k.a. greedy) for decoding when testing the model.

Currently, there is support for a sorting task and the planar symmetric Euclidean TSP.

See main.sh for an example of how to run the code.

Use the --load_path $LOAD_PATH and --is_train False flags to load a saved model.

To load a saved model and view the pointer network's attention layer, also use the --plot_attention True flag.

Please, feel free to notify me if you encounter any errors, or if you'd like to submit a pull request to improve this implementation.

Adding other tasks

This implementation can be extended to support other combinatorial optimization problems. See sorting_task.py and tsp_task.py for examples on how to add. The key thing is to provide a dataset class and a reward function that takes in a sample solution, selected by the pointer network from the input, and returns a scalar reward. For the sorting task, the agent received a reward proportional to the length of the longest strictly increasing subsequence in the decoded output (e.g., [1, 3, 5, 2, 4] -> 3/5 = 0.6).

Dependencies

  • Python=3.6 (should be OK with v >= 3.4)
  • PyTorch=0.2 and 0.3
  • tqdm
  • matplotlib
  • tensorboard_logger

PyTorch 0.4 compatibility is available on branch pytorch-0.4.

TSP Results

Results for 1 random seed over 50 epochs (each epoch is 10,000 batches of size 128). After each epoch, I validated performance on 1000 held out graphs. I used the same hyperparameters from the paper, as can be seen in main.sh. The dashed line shows the value indicated in Table 2 of Bello, et. al for comparison. The log scale x axis for the training reward is used to show how the tour length drops early on.

TSP 20 Train TSP 20 Val TSP 50 Train TSP 50 Val

Sort Results

I trained a model on sort10 for 4 epochs of 1,000,000 randomly generated samples. I tested it on a dataset of size 10,000. Then, I tested the same model on sort15 and sort20 to test the generalization capabilities.

Test results on 10,000 samples (A reward of 1.0 means the network perfectly sorted the input):

task average reward variance
sort10 0.9966 0.0005
sort15 0.7484 0.0177
sort20 0.5586 0.0060

Example prediction on sort10:

input: [4, 7, 5, 0, 3, 2, 6, 8, 9, 1]
output: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Attention visualization

Plot the pointer network's attention layer with the argument --plot_attention True

TODO

  • Add RL pretraining-Sampling
  • Add RL pretraining-Active Search
  • Active Search
  • Asynchronous training a la A3C
  • Refactor USE_CUDA variable
  • Finish implementing beam search decoding to support > 1 beam
  • Add support for variable length inputs

Acknowledgements

Special thanks to the repos devsisters/neural-combinatorial-rl-tensorflow and MaximumEntropy/Seq2Seq-PyTorch for getting me started, and @ricgama for figuring out that weird bug with clone()

Owner
Patrick E.
Machine Learning PhD Candidate at Univ. of Florida. Deep generative models | object-centric representation learning | RL | transportation
Patrick E.
Course on computational design, non-linear optimization, and dynamics of soft systems at UIUC.

Computational Design and Dynamics of Soft Systems · This is a repository that contains the source code for generating the lecture notes, handouts, exe

Tejaswin Parthasarathy 4 Jul 21, 2022
[ICML 2020] DrRepair: Learning to Repair Programs from Error Messages

DrRepair: Learning to Repair Programs from Error Messages This repo provides the source code & data of our paper: Graph-based, Self-Supervised Program

Michihiro Yasunaga 155 Jan 08, 2023
This is the pytorch code for the paper Curious Representation Learning for Embodied Intelligence.

Curious Representation Learning for Embodied Intelligence This is the pytorch code for the paper Curious Representation Learning for Embodied Intellig

19 Oct 19, 2022
RANZCR-CLiP 7th Place Solution

RANZCR-CLiP 7th Place Solution This repository is WIP. (18 Mar 2021) Installation git clone https://github.com/analokmaus/kaggle-ranzcr-clip-public.gi

Hiroshechka Y 21 Oct 22, 2022
Libtorch yolov3 deepsort

Overview It is for my undergrad thesis in Tsinghua University. There are four modules in the project: Detection: YOLOv3 Tracking: SORT and DeepSORT Pr

Xu Wei 226 Dec 13, 2022
PyGCL: A PyTorch Library for Graph Contrastive Learning

PyGCL is a PyTorch-based open-source Graph Contrastive Learning (GCL) library, which features modularized GCL components from published papers, standa

PyGCL 588 Dec 31, 2022
[ICCV 2021] Official Tensorflow Implementation for "Single Image Defocus Deblurring Using Kernel-Sharing Parallel Atrous Convolutions"

KPAC: Kernel-Sharing Parallel Atrous Convolutional block This repository contains the official Tensorflow implementation of the following paper: Singl

Hyeongseok Son 50 Dec 29, 2022
PyTorch-Multi-Style-Transfer - Neural Style and MSG-Net

PyTorch-Style-Transfer This repo provides PyTorch Implementation of MSG-Net (ours) and Neural Style (Gatys et al. CVPR 2016), which has been included

Hang Zhang 906 Jan 04, 2023
SparseInst: Sparse Instance Activation for Real-Time Instance Segmentation, CVPR 2022

SparseInst 🚀 A simple framework for real-time instance segmentation, CVPR 2022 by Tianheng Cheng, Xinggang Wang†, Shaoyu Chen, Wenqiang Zhang, Qian Z

Hust Visual Learning Team 458 Jan 05, 2023
A Python library that enables ML teams to share, load, and transform data in a collaborative, flexible, and efficient way :chestnut:

Squirrel Core Share, load, and transform data in a collaborative, flexible, and efficient way What is Squirrel? Squirrel is a Python library that enab

Merantix Momentum 249 Dec 07, 2022
Learning from History: Modeling Temporal Knowledge Graphs with Sequential Copy-Generation Networks

CyGNet This repository reproduces the AAAI'21 paper “Learning from History: Modeling Temporal Knowledge Graphs with Sequential Copy-Generation Network

CunchaoZ 89 Jan 03, 2023
Official implementation of EdiTTS: Score-based Editing for Controllable Text-to-Speech

EdiTTS: Score-based Editing for Controllable Text-to-Speech Official implementation of EdiTTS: Score-based Editing for Controllable Text-to-Speech. Au

Neosapience 98 Dec 25, 2022
AutoML library for deep learning

Official Website: autokeras.com AutoKeras: An AutoML system based on Keras. It is developed by DATA Lab at Texas A&M University. The goal of AutoKeras

Keras 8.7k Jan 08, 2023
This is a classifier which basically predicts whether there is a gun law in a state or not, depending on various things like murder rates etc.

Gun-Laws-Classifier This is a classifier which basically predicts whether there is a gun law in a state or not, depending on various things like murde

Awais Saleem 1 Jan 20, 2022
Json2Xml tool will help you convert from json COCO format to VOC xml format in Object Detection Problem.

JSON 2 XML All codes assume running from root directory. Please update the sys path at the beginning of the codes before running. Over View Json2Xml t

Nguyễn Trường Lâu 6 Aug 22, 2022
Self-supervised spatio-spectro-temporal represenation learning for EEG analysis

EEG-Oriented Self-Supervised Learning and Cluster-Aware Adaptation This repository provides a tensorflow implementation of a submitted paper: EEG-Orie

Wonjun Ko 4 Jun 09, 2022
Aspect-Sentiment-Multiple-Opinion Triplet Extraction (NLPCC 2021)

The code and data for the paper "Aspect-Sentiment-Multiple-Opinion Triplet Extraction" Requirements Python 3.6.8 torch==1.2.0 pytorch-transformers==1.

慢半拍 5 Jul 02, 2022
Implementation of UNet on the Joey ML framework

Independent Research Project - Code Joey can be cloned from here https://github.com/devitocodes/joey/. Devito and other dependencies such as PyTorch a

Navjot Kukreja 1 Oct 21, 2021
A Simple LSTM-Based Solution for "Heartbeat Signal Classification and Prediction" in Tianchi

LSTM-Time-Series-Prediction A Simple LSTM-Based Solution for "Heartbeat Signal Classification and Prediction" in Tianchi Contest. The Link of the Cont

KevinCHEN 1 Jun 13, 2022