A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
SymbLang are my programming language! Insired by the brainf**k.

SymbLang . - output as Unicode. , - input. ; - clear data. & - character that the main line start with. @value: 0 - 9 - character that the function

1 Apr 04, 2022
Notes on the Deep Learning book from Ian Goodfellow, Yoshua Bengio and Aaron Courville (2016)

The Deep Learning Book - Goodfellow, I., Bengio, Y., and Courville, A. (2016) This content is part of a series following the chapter 2 on linear algeb

hadrienj 1.7k Jan 07, 2023
36 key ergo split keyboard, designed around the Seeeduino Xiao platform

Slice36 Minimalist Split Keyboard 36 key ergo split keyboard, designed around the Seeeduino Xiao platform. Inspired by the Corne, Ferris, Ben Vallack'

54 Dec 21, 2022
Collection of tools to be more productive in your work environment and to avoid certain repetitive tasks. 💛💙💚

Collection of tools to be more productive in your work environment and to avoid certain repetitive tasks. 💛💙💚

Raja Rakotonirina 2 Jan 10, 2022
A Python module for decorators, wrappers and monkey patching.

wrapt The aim of the wrapt module is to provide a transparent object proxy for Python, which can be used as the basis for the construction of function

Graham Dumpleton 1.8k Jan 06, 2023
Viewer for NFO files

NFO Viewer NFO Viewer is a simple viewer for NFO files, which are "ASCII" art in the CP437 codepage. The advantages of using NFO Viewer instead of a t

Osmo Salomaa 114 Dec 29, 2022
Python Programming Bootcamp

python-bootcamp Python Programming Bootcamp Begin: 27th August 2021 End: 8th September 2021 Registration deadline: 22nd August 2021 Fees: No course or

Rohitash Chandra 11 Oct 19, 2022
This repo is for scripts to run various clients at the merge f2f

merge-f2f This repo is for scripts to run various clients at the merge f2f. Tested with Lighthouse! Tested with Geth! General dependecies sudo apt-get

Parithosh Jayanthi 2 Apr 03, 2022
Improving the Transferability of Adversarial Examples with Resized-Diverse-Inputs, Diversity-Ensemble and Region Fitting

Improving the Transferability of Adversarial Examples with Resized-Diverse-Inputs, Diversity-Ensemble and Region Fitting

Junhua Zou 7 Oct 20, 2022
Box CRUD API With Python

Box CRUD API: Consider a store which has an inventory of boxes which are all cuboid(which have length breadth and height). Each Cuboid has been added

Akhil Bhalerao 3 Feb 17, 2022
Python script for converting obsidian md-file to html (recursively adds all link/images)

ObsidianToHtmlConverter I made a small python script for converting obsidian md-file to static (local) html (recursively adds all link/images) I made

47 Jan 03, 2023
Simulation simplifiée du fonctionnement du protocole RIP

ProjetRIPlay v2 Simulation simplifiée du fonctionnement du protocole RIP par Eric Buonocore le 18/01/2022 Sur la base de l'exercice 5 du sujet zéro du

Eric Buonocore 2 Feb 15, 2022
Euler 021 Py - Euler Problem 021 solved in Python

Euler_021_Py Euler Problem 021 solved in Python Let d(n) be defined as the sum o

Ariel Tynan 1 Jan 24, 2022
Web UI for your scripts with execution management

Script-server is a Web UI for scripts. As an administrator, you add your existing scripts into Script server and other users would be ab

Iaroslav Shepilov 1.1k Jan 09, 2023
→ Plantilla de registro para Python

🔧 Pasos Necesarios CMD 🖥️ SOCKETS pip install sockets 🎨 COLORAMA pip install colorama 💻 Código register-by-inputs from turtle import color # Impor

Panda.xyz 4 Mar 12, 2022
Shopify Backend Developer Intern Challenge - Summer 2022

Shopify Backend Developer Intern The task is build an inventory tracking web application for a logistics company. The detailed task details can be fou

Meet Gandhi 11 Oct 08, 2022
This is an example manipulation package of for a robot manipulator based on Drake with ROS2.

This is an example manipulation package of for a robot manipulator based on Drake with ROS2.

Sotaro Katayama 1 Oct 21, 2021
this is a basic python project that I made using python

this is a basic python project that I made using python. This project is only for practice because my python skills are still newbie.

Elvira Firmansyah 2 Dec 14, 2022
Python library for datamining glitch information from Gen 1 Pokémon GameBoy ROMs

g1utils This is a Python library for datamining information about various glitches (glitch Pokémon, glitch maps, etc.) from Gen 1 Pokémon ROMs. TODO A

1 Jan 13, 2022
Santa's kitchen helper for python

Santa's Kitchen Helper Introduction/Overview Contents UX User Stories Design Wireframes Color Scheme Typography Imagery Features Exisiting Features Fe

Paul Browne 4 May 31, 2022