A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Context-free grammar to Sublime-syntax file

Generate a sublime-syntax file from a non-left-recursive, follow-determined, context-free grammar

Haggai Nuchi 8 Nov 17, 2022
This is a Python program I wrote to simulate the solar system with 79 lines of code.

Solar System With Python This is a Python program I wrote to simulate the solar system with 79 lines of code. Required modules tkinter, math, time Why

Mehmet Aydoğmuş 1 Oct 26, 2021
Sudo type me a payload

payloadSecretary Sudo type me a payload Have you ever found yourself having to perform a test, and a client has provided you with a VM inside a VDI in

7 Jul 21, 2022
OB_Template is a vault template reference for using Obsidian.

Obsidian Template OB_Template is a vault template reference for using Obsidian. If you've tested out Obsidian. and worked through the "Obsidian Help"

323 Dec 27, 2022
CMPE 204 Modelling Project

CISC/CMPE 204 Modelling Project Welcome to the major project for CISC/CMPE 204 (Fall 2021)! Change this README.md file to summarize your project (few

totallyrin 2 May 16, 2022
Mixtaper - Web app to make mixtapes

Mixtaper A web app which allows you to input songs in the form of youtube links

suryansh 1 Feb 14, 2022
Project Faros is a reference implimentation of Red Hat OpenShift 4 on small footprint, bare-metal clusters.

Project Faros Project Faros is a reference implimentation of Red Hat OpenShift 4 on small footprint, bare-metal clusters. The project includes referen

project: Faros 9 Jul 18, 2022
Structural basis for solubility in protein expression systems

Structural basis for solubility in protein expression systems Large-scale protein production for biotechnology and biopharmaceutical applications rely

ProteinQure 16 Aug 18, 2022
Metal Gear Rising: Revengeance's DAT archive (un)packer

DOOMP Metal Gear Rising: Revengeance's DAT archive (un)packer

Christopher Holzmann Pérez 5 Sep 02, 2022
YourCity is a platform to match people to their prefect city.

YourCity YourCity is a city matching App that matches users to their ideal city. It is a fullstack React App made with a Redux state manager and a bac

Nico G Pierson 6 Sep 25, 2021
Aoc 2021 kedro playground with python

AOC 2021 Overview This is your new Kedro project, which was generated using Kedro 0.17.5. Take a look at the Kedro documentation to get started. Rules

1 Dec 20, 2021
Your E-Canteen that is convenient and accessible wherever you are in the campus

Food Web E-Canteen System Your E-Canteen that is convenient and accessible wherever you are in the campus. Table of Contents About The Project Contrib

Pudding 5 Jan 07, 2023
CPython extension implementing Shared Transactional Memory with native-looking interface

CPython extension implementing Shared Transactional Memory with native-looking interface

21 Jul 22, 2022
Advent of Code is an Advent calendar of small programming puzzles for a variety of skill sets and skill levels that can be solved in any programming language you like.

Advent Of Code 2021 - Python English Advent of Code is an Advent calendar of small programming puzzles for a variety of skill sets and skill levels th

Coral Izquierdo Muñiz 2 Jan 09, 2022
A Python wrapper API for operating and working with the Neo4j Graph Data Science (GDS) library

gdsclient NOTE: This is a work in progress and many GDS features are known to be missing or not working properly. This repo hosts the sources for gdsc

Neo4j 100 Dec 20, 2022
Script para generar automatización de registro de formularios IEEH

Formularios_IEEH Script para generar automatización de registro de formularios IEEH Corresponde a un conjunto de script en python que permiten la auto

vhevia11 1 Jan 06, 2022
Automatically deletes Capital One Eno virtual cards for when you've made a couple too many.

eno-delete Automatically deletes Capital One Eno virtual cards for when you've made a couple too many. Warning: Program will delete ALL virtual cards

h3x 3 Sep 29, 2022
People tracker on the Internet: OSINT analysis and research tool by Jose Pino

trape (stable) v2.0 People tracker on the Internet: Learn to track the world, to avoid being traced. Trape is an OSINT analysis and research tool, whi

Jose Pino 7.3k Dec 30, 2022
A flexible free and unlimited python tool to translate between different languages in a simple way using multiple translators.

deep-translator Translation for humans A flexible FREE and UNLIMITED tool to translate between different languages in a simple way using multiple tran

Nidhal Baccouri 806 Jan 04, 2023
An improved version of the common ˙pacman -S˙

BetterPacmanLook An improved version of the common pacman -S. Installation I know that this is probably one of the worst solutions and i will be worki

1 Nov 06, 2021