A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.


BoMb-OT

Python 3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

A scalable implementation of the batch of mini-batches (BoMb) scheme and of the conventional averaging scheme for several mini-batch transportation types, namely optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport, applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
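The two schemes differ only in how the OT costs between pairs of mini-batches are combined. The following is a minimal NumPy/SciPy sketch of that difference, not the repository's implementation (which uses PyTorch); the function names are hypothetical. It assumes a squared Euclidean ground cost and equal-size mini-batches with uniform weights, so that exact OT reduces to an assignment problem. It also assumes that the averaging scheme (m-OT) weights all k² pairs of mini-batches equally, while BoMb-OT solves a second OT problem between the k source and k target mini-batches that uses the matrix of mini-batch OT costs as its ground cost.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def exact_ot(x, y):
    """Exact OT cost (squared Euclidean) between two uniform empirical measures
    with the same number of points; an optimal plan is then a permutation."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()


def minibatch_cost_matrix(xs, ys):
    """C[i, j] = OT cost between source mini-batch i and target mini-batch j."""
    return np.array([[exact_ot(x, y) for y in ys] for x in xs])


def m_ot(xs, ys):
    """Averaging scheme: uniform weight 1/k^2 on every pair of mini-batches."""
    return minibatch_cost_matrix(xs, ys).mean()


def bomb_ot(xs, ys):
    """Batch of mini-batches: an outer OT problem between the k source and the
    k target mini-batches, with the mini-batch OT costs as ground cost."""
    costs = minibatch_cost_matrix(xs, ys)
    rows, cols = linear_sum_assignment(costs)
    return costs[rows, cols].mean()


rng = np.random.default_rng(0)
k, m, d = 4, 8, 2  # number of mini-batches, mini-batch size, dimension
xs = [rng.normal(size=(m, d)) for _ in range(k)]
ys = [rng.normal(loc=2.0, size=(m, d)) for _ in range(k)]
print(f"m-OT: {m_ot(xs, ys):.3f}  BoMb-OT: {bomb_ot(xs, ys):.3f}")
```

The uniform 1/k² weighting used by m-OT is itself a feasible coupling between mini-batches. As a result, the BoMb-OT value computed by this sketch can never exceed the m-OT value.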

Deep Domain Adaptation on digit datasets (DeepDA/digits)

Code organization

cfg.py : this file contains the training arguments.

methods.py : this file implements the training process of deep DA.

models.py : this file contains the architectures of the generator and the classifier.

train_digits.py : main script for running deep DA.

utils.py : this file contains utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of training epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : use the Batch of Mini-batches (BoMb) scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches
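To illustrate what --ebomb and --breg control, here is a minimal sketch. It assumes that the entropic variant keeps the BoMb structure but computes the outer plan between mini-batches with Sinkhorn at regularization breg, rather than with exact OT. The k × k matrix of mini-batch costs is taken as given, `sinkhorn_plan` and `ebomb_loss` are hypothetical names, and the repository itself works with PyTorch tensors rather than NumPy arrays.

```python
import numpy as np


def sinkhorn_plan(a, b, cost, reg, n_iters=1000):
    """Entropic OT plan between histograms a and b (plain Sinkhorn iterations).
    For very small reg, exp(-cost / reg) can underflow; a log-domain
    implementation is safer in that regime."""
    kernel = np.exp(-cost / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def ebomb_loss(minibatch_costs, breg):
    """Entropic BoMb sketch: weight the k x k matrix of mini-batch OT costs by
    an entropic plan between uniform weights over the source and target
    mini-batches, computed with regularization breg."""
    k_src, k_tgt = minibatch_costs.shape
    a = np.full(k_src, 1.0 / k_src)
    b = np.full(k_tgt, 1.0 / k_tgt)
    plan = sinkhorn_plan(a, b, minibatch_costs, breg)
    return float((plan * minibatch_costs).sum()), plan


# Toy 2 x 2 matrix of mini-batch costs (e.g. produced by an inner OT solver).
costs = np.array([[1.0, 4.0], [4.0, 1.0]])
for breg in (0.1, 100.0):
    loss, plan = ebomb_loss(costs, breg)
    print(f"breg={breg}: loss={loss:.4f}")
```

As breg grows, the outer plan approaches the uniform 1/k² weighting, so the loss moves toward the averaging (m-OT) value, which is 2.5 for this toy matrix. As breg shrinks, the loss approaches the exact BoMb-OT value, which is 1.0 here.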

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Domain Adaptation on the Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create the datasets.

evaluate.py : this file is used to evaluate a model trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architectures of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling on the source dataset

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : use the Batch of Mini-batches (BoMb) scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches
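With --ot_type unbalanced, the parameter --tau sets how strongly the marginal constraints are enforced. The sketch below assumes the common entropic formulation, in which both marginal constraints are replaced by KL penalties weighted by tau and the problem is solved with generalized Sinkhorn scaling iterations (epsilon is the entropic regularization). It is a NumPy illustration, not the repository's solver, and `sinkhorn_unbalanced_plan` is a hypothetical name.

```python
import numpy as np


def sinkhorn_unbalanced_plan(a, b, cost, epsilon, tau, n_iters=2000):
    """Entropic unbalanced OT plan in which the hard marginal constraints are
    replaced by KL penalties of weight tau (generalized Sinkhorn scaling)."""
    kernel = np.exp(-cost / epsilon)
    exponent = tau / (tau + epsilon)  # tends to 1 (balanced Sinkhorn) as tau grows
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (kernel @ v)) ** exponent
        v = (b / (kernel.T @ u)) ** exponent
    return u[:, None] * kernel * v[None, :]


# Source mini-batch with one outlier (10.0) and a clean target mini-batch.
x = np.array([[0.0], [1.0], [10.0]])
y = np.array([[0.0], [1.0], [2.0]])
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
weights = np.full(3, 1.0 / 3)
for tau in (1.0, 1000.0):
    plan = sinkhorn_unbalanced_plan(weights, weights, cost, epsilon=1.0, tau=tau)
    print(f"tau={tau}: mass sent by each source point = {plan.sum(axis=1).round(3)}")
```

With a small tau, the outlier at 10.0 sends almost no mass, so it barely affects the loss. A large tau pushes the marginals back toward the uniform weights, as in balanced OT, and the outlier is transported as well.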

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative Models (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the generator architectures for the CelebA and CIFAR10 datasets, along with methods that compute the losses of the corresponding baselines.

experiments.py : this file contains functions for generating images.

fid_score.py : this file is used to compute the FID score.

gen_images.py : reads saved models and generates 10,000 images for computing the FID.

inception.py : this file contains the architecture of Inception-v3.

main_celeba.py, main_cifar.py : training scripts for the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of training epochs is this value multiplied by k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using the sliced approach

--bomb : use the Batch of Mini-batches (BoMb) scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches
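With --method sliced, OT between mini-batches is replaced by the sliced Wasserstein distance estimated from L random projections. Below is a minimal NumPy sketch of that estimator. It assumes equal-size mini-batches with uniform weights and directions drawn uniformly on the unit sphere. `sliced_wasserstein` is a hypothetical name, and the repository's PyTorch version may differ, for example in the choice of p.

```python
import numpy as np


def sliced_wasserstein(x, y, n_projections=50, p=2, seed=None):
    """Monte Carlo estimate of the sliced Wasserstein-p distance between two
    point clouds x and y of equal size n, each of shape (n, d)."""
    rng = np.random.default_rng(seed)
    # n_projections random directions drawn uniformly on the unit sphere.
    theta = rng.normal(size=(n_projections, x.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # In 1D, OT between equal-size uniform samples matches sorted values.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    return float(np.mean(np.abs(x_proj - y_proj) ** p) ** (1.0 / p))


rng = np.random.default_rng(0)
real = rng.normal(size=(64, 2))
fake = rng.normal(loc=1.0, size=(64, 2))
print(sliced_wasserstein(real, fake, n_projections=100, seed=1))
```

Each projection needs only a sort, costing O(n log n). This is what makes the sliced variant cheap compared with solving exact OT on every pair of mini-batches.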

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py
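In a gradient flow experiment, a set of particles moves toward a target distribution by following the gradient of a transport loss. The sketch below shows the idea with the plain mini-batch OT loss (k = 1, squared Euclidean cost, uniform weights). It is an illustration only and does not reproduce main.py. `minibatch_ot_flow` and its defaults are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def minibatch_ot_flow(x, y, m=16, steps=500, lr=0.5, seed=0):
    """Move particles x towards target samples y with gradient steps on the
    mini-batch OT loss (1/m) * sum_i ||x_i - y_sigma(i)||^2."""
    rng = np.random.default_rng(seed)
    x = x.copy()
    for _ in range(steps):
        ix = rng.choice(len(x), size=m, replace=False)
        iy = rng.choice(len(y), size=m, replace=False)
        cost = ((x[ix][:, None, :] - y[iy][None, :, :]) ** 2).sum(-1)
        rows, cols = linear_sum_assignment(cost)  # optimal matching sigma
        src, tgt = ix[rows], iy[cols]
        # Gradient of the mini-batch loss with respect to the matched particles.
        grad = (2.0 / m) * (x[src] - y[tgt])
        x[src] -= lr * grad
    return x


rng = np.random.default_rng(0)
particles = rng.normal(size=(64, 2))
target = rng.normal(loc=5.0, size=(64, 2))
moved = minibatch_ot_flow(particles, target)
print(particles.mean(axis=0), moved.mean(axis=0), target.mean(axis=0))
```

Swapping the loss inside the loop, for example for a BoMb or sliced loss, leaves the rest of the flow unchanged.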

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster : use k-means clustering to compress the images

--palette : show the color palette

--source : path to the source image

--target : path to the target image
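Mini-batch color transfer can be sketched in three steps. First, treat each image as a point cloud of RGB pixels with shape (H·W, 3). Then, for T steps, match m random source pixels to m random target pixels with exact OT. Finally, replace each source pixel by the average target color it was matched to. This is an assumed, simplified NumPy version, not the repository's main.py: it skips the --cluster compression and uses plain m-OT rather than BoMb. `minibatch_color_transfer` is a hypothetical name, and its `steps` argument plays the role of --T.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def minibatch_color_transfer(src, tgt, m=100, steps=1000, seed=0):
    """Recolor source pixels src (n, 3) towards target pixels tgt (n_t, 3).
    Each step matches m random source pixels to m random target pixels with
    exact OT. Each source pixel then takes the average of the target colors it
    was matched to (a barycentric projection of the averaged mini-batch plan)."""
    rng = np.random.default_rng(seed)
    acc = np.zeros(src.shape, dtype=float)
    counts = np.zeros(len(src))
    for _ in range(steps):
        i = rng.choice(len(src), size=m, replace=False)
        j = rng.choice(len(tgt), size=m, replace=False)
        cost = ((src[i][:, None, :] - tgt[j][None, :, :]) ** 2).sum(-1)
        rows, cols = linear_sum_assignment(cost)
        acc[i[rows]] += tgt[j[cols]]
        counts[i[rows]] += 1
    out = src.astype(float)
    seen = counts > 0  # pixels never sampled keep their original color
    out[seen] = acc[seen] / counts[seen, None]
    return out


rng = np.random.default_rng(0)
source_pixels = rng.uniform(0.0, 0.5, size=(500, 3))  # dark palette
target_pixels = rng.uniform(0.5, 1.0, size=(400, 3))  # bright palette
recolored = minibatch_color_transfer(source_pixels, target_pixels, m=50, steps=200)
print(recolored.mean(axis=0))
```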

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful to their authors for open-sourcing their code.
