A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was written by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

Scalable implementations of two mini-batch schemes, the batch of mini-batches (BoMb) scheme and the conventional averaging scheme, for several transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport. The schemes are applied to:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
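
As a rough illustration of how the two schemes differ, the sketch below compares them on 1-D toy data using only the standard library. It is not the repository's code. All function names are hypothetical, the averaging scheme is assumed to weight every pair of mini-batches by 1/k^2, and the outer problem is solved by brute force over permutations, which is exact only because both sides hold k uniformly weighted mini-batches and k is small.

```python
import itertools
import random


def ot_1d(xs, ys):
    """Exact OT cost between two equal-size 1-D samples with uniform weights.

    In one dimension with squared-distance cost, matching sorted points is optimal.
    """
    assert len(xs) == len(ys)
    return sum((a - b) ** 2 for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def cost_matrix(x_batches, y_batches):
    """k x k matrix whose (i, j) entry is the OT cost between mini-batches i and j."""
    return [[ot_1d(xb, yb) for yb in y_batches] for xb in x_batches]


def averaging_loss(C):
    """Averaging scheme (assumed form): every pair of mini-batches weighted 1/k^2."""
    k = len(C)
    return sum(sum(row) for row in C) / (k * k)


def bomb_loss(C):
    """Batch-of-mini-batches scheme: a second OT problem over the mini-batches.

    With uniform weights on k mini-batches per side, an optimal plan is a
    permutation, so enumerating permutations is exact for small k.
    """
    k = len(C)
    return min(
        sum(C[i][p[i]] for i in range(k)) / k
        for p in itertools.permutations(range(k))
    )


random.seed(0)
k, m = 3, 8  # number of mini-batches and mini-batch size
x_batches = [[random.gauss(0.0, 1.0) for _ in range(m)] for _ in range(k)]
y_batches = [[random.gauss(2.0, 1.0) for _ in range(m)] for _ in range(k)]
C = cost_matrix(x_batches, y_batches)
print("averaging:", averaging_loss(C))
print("BoMb     :", bomb_loss(C))
```

Because the uniform 1/k^2 weighting is itself a feasible plan for the outer problem, the BoMb value in this sketch can never exceed the averaging value.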

Deep Domain Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the generator and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : number of iterations between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
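
To give a feel for what the regularization coefficient passed to the Sinkhorn algorithm (--epsilon, --reg) controls, here is a minimal standard-library sketch of Sinkhorn iterations on a 2 x 2 toy problem. It is an illustration under simplifying assumptions, not the solver used in this repository; the function names and the fixed iteration count are hypothetical.

```python
import math


def sinkhorn(a, b, M, epsilon, n_iters=200):
    """Entropic OT plan between histograms a and b for cost matrix M."""
    n, m = len(a), len(b)
    # Gibbs kernel: smaller epsilon penalises costly entries more sharply.
    K = [[math.exp(-M[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


def transport_cost(P, M):
    """Total cost <P, M> of a transport plan."""
    return sum(P[i][j] * M[i][j] for i in range(len(M)) for j in range(len(M[0])))


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]  # unregularized OT cost is 0 (keep mass in place)
for eps in (1.0, 0.1):
    P = sinkhorn(a, b, M, eps)
    print("epsilon =", eps, "cost =", transport_cost(P, M))
```

In this toy case a larger coefficient spreads mass over more entries of the plan, so the transport cost moves away from the unregularized optimum; a smaller coefficient stays closer to it but makes the kernel entries tiny, which is where numerical care is needed in practice.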

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Domain Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create the datasets.

evaluate.py : this file is used to evaluate a model trained on the VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : number of iterations between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator for the CelebA and CIFAR10 datasets, together with helper functions that compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: this file loads saved models and generates 10000 images for computing the FID score.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual number of epochs run is this value multiplied by k (for example, --epochs 100 with --k 2 runs 200 epochs).

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections used by the sliced approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches
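
The role of --L in the sliced method can be sketched as follows: points are projected onto random directions, each 1-D problem is solved by sorting, and the results are averaged. This is a generic Monte Carlo sketch of the sliced approach with a squared-distance cost, written with the standard library only; the function names and the equal-size assumption are illustrative and do not come from this repository.

```python
import math
import random


def random_direction(dim, rng):
    """Draw a direction uniformly on the unit sphere by normalising a Gaussian."""
    g = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in g))
    return [c / norm for c in g]


def sliced_ot(xs, ys, n_projections, seed=0):
    """Monte Carlo estimate of the sliced OT cost between equal-size point sets."""
    assert len(xs) == len(ys)
    rng = random.Random(seed)
    dim = len(xs[0])
    total = 0.0
    for _ in range(n_projections):
        theta = random_direction(dim, rng)
        # Project both point sets onto theta; 1-D OT is solved by sorting.
        px = sorted(sum(t * c for t, c in zip(theta, x)) for x in xs)
        py = sorted(sum(t * c for t, c in zip(theta, y)) for y in ys)
        total += sum((p - q) ** 2 for p, q in zip(px, py)) / len(xs)
    return total / n_projections


rng = random.Random(1)
xs = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(50)]
ys = [[x[0] + 3.0, x[1], x[2]] for x in xs]  # xs shifted along the first axis
for L in (1, 10, 100):
    print("L =", L, "sliced cost =", sliced_ot(xs, ys, L))
```

More projections reduce the variance of the estimate at a cost that grows linearly in the number of projections.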

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CelebA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: use K-means clustering to compress the images

--palette: show color palette

--source: path to the source image

--target: path to the target image
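
To illustrate what compressing an image with K-means clustering means, the sketch below reduces a list of RGB pixels to a few representative colors using Lloyd's algorithm. It uses only the standard library and is an assumption about the general technique, not the code behind --cluster; the function name, the number of clusters, and the iteration count are hypothetical.

```python
import random


def kmeans_colors(pixels, n_clusters, n_iters=10, seed=0):
    """Lloyd's algorithm on RGB tuples; returns (centroids, labels)."""
    rng = random.Random(seed)
    centroids = rng.sample(pixels, n_clusters)  # initialise from the data
    labels = [0] * len(pixels)
    for _ in range(n_iters):
        # Assignment step: each pixel goes to its nearest centroid.
        labels = [
            min(
                range(n_clusters),
                key=lambda c: sum((p - q) ** 2 for p, q in zip(px, centroids[c])),
            )
            for px in pixels
        ]
        # Update step: each centroid becomes the mean of its members.
        for c in range(n_clusters):
            members = [px for px, lab in zip(pixels, labels) if lab == c]
            if members:
                centroids[c] = tuple(sum(ch) / len(members) for ch in zip(*members))
    return centroids, labels


pixels = [
    (0, 0, 0), (10, 10, 10), (5, 5, 5),              # dark pixels
    (250, 250, 250), (240, 240, 240), (245, 245, 245),  # bright pixels
]
centroids, labels = kmeans_colors(pixels, n_clusters=2)
print(sorted(centroids))
print(labels)
```

Working with a small set of cluster centers instead of every pixel keeps the transport problem small, which is presumably why such a compression step is offered as an option.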

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful to the authors for open-sourcing their code.
