A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
DownTime-Score is a Small project aimed to Monitor the performance and the availabillity of a variety of the Vital and Critical Moroccan Web Portals

DownTime-Score DownTime-Score is a Small project aimed to Monitor the performance and the availabillity of a variety of the Vital and Critical Morocca

adnane-tebbaa 5 Apr 30, 2022
NYCU(NCTU)-差勤-助教

NCTU-TA-fill 填寫 差勤-助教時數 有沒有覺得在差勤系統填助教時數有點浪費生命? 今天有個懶鬼浪費好多時間幫大家寫了code 只要填好的必要的資料,就可以讓電腦自動幫你完成差勤助教的時數填寫喔! https://pt-attendance.nctu.edu.tw/verify/userL

14 Dec 21, 2021
Automation in socks label validation

This is a project for socks card label validation where the socks card is validated comparing with the correct socks card whose coordinates are stored in the database. When the test socks card is com

1 Jan 19, 2022
EasyBuild is a software build and installation framework that allows you to manage (scientific) software on High Performance Computing (HPC) systems in an efficient way.

EasyBuild is a software build and installation framework that allows you to manage (scientific) software on High Performance Computing (HPC) systems in an efficient way.

EasyBuild community 87 Dec 27, 2022
Addon to give a keybind to automatically enable contact shadows on all lights in a scene

3-2-1 Contact(Shadow) An easy way to let you enable contact shadows on all your lights, because Blender doesn't enable it by default, and doesn't give

TDV Alinsa 3 Feb 02, 2022
With Christmas and New Year ahead, it is time for some festive coding. Here is a Christmas Card for you all!

Christmas Card With Christmas and New Year ahead, it is time for some festive coding! Here is a Christmas Card for you all! NOTE: I have not made this

CodeMaster7000 1 Dec 25, 2021
Gba-free-fonts - Free font resources for GBA game development

gba-free-fonts Free font resources for GBA game development This repo contains m

28 Dec 30, 2022
Like Docker, but for Squeak. You know, for kids.

Squeaker Like Docker, but for Smalltalk images. You know, for kids. It's a small program that helps in automated derivation of configured Smalltalk im

Tony Garnock-Jones 14 Sep 11, 2022
Convert text with ANSI color codes to HTML or to LaTeX.

Convert text with ANSI color codes to HTML or to LaTeX.

PyContribs 326 Dec 28, 2022
Block the annoying Token Grabbers on your discord

General We have seen that in the last time many discord servers are infected by fake discord nitro links we want to put an end to this and have develo

BadTiger Network 2 Jul 16, 2022
My attempt at this years Advent of Code!

Advent-of-code-2021 My attempt at this years Advent of Code! day 1: ** day 2: ** day 3: ** day 4: ** day 5: ** day 6: ** day 7: ** day 8: * day 9: day

1 Jul 06, 2022
Este projeto se trata de uma análise de campanhas de marketing de uma empresa que vende acessórios para veículos.

Marketing Campaigns Este projeto se trata de uma análise de campanhas de marketing de uma empresa que vende acessórios para veículos. 1. Problema A em

Bibiana Prevedello 1 Jan 12, 2022
A tool for RaceRoom Racing Experience which shows you launch data

R3E Launch Tool A tool for RaceRoom Racing Experience which shows you launch data. Usage Run the tool, change the Stop Speed to whatever you want, and

Yuval Rosen 2 Feb 01, 2022
A demo of a data science project using Kedro

iris Overview This is your new Kedro project, which was generated using Kedro 0.17.4. Take a look at the Kedro documentation to get started. Rules and

Khuyen Tran 14 Oct 14, 2022
Track testrail productivity in automated reporting to multiple teams

django_web_app_for_testrail testrail is a test case management tool which helps any organization to track all consumption and testing of manual and au

Vignesh 2 Nov 21, 2021
Materials and information for my PyCascades 2021 Presentation

Materials and information for PyCascades 2021 Presentation: Sparking Creativity in LED Art with CircuitPython

GeekMomProjects 19 May 04, 2022
Programmatic startup/shutdown of ASGI apps.

asgi-lifespan Programmatically send startup/shutdown lifespan events into ASGI applications. When used in combination with an ASGI-capable HTTP client

Florimond Manca 129 Dec 27, 2022
Blender Addon for Snapping a UV to a specific part of a Tilemap

UVGridSnapper A simple Blender Addon for easier texturing. A menu in the UV editor allows a square UV to be snapped to an Atlas texture, or Tilemap. P

2 Jul 17, 2022
To effectively detect the faulty wafers

wafer_fault_detection Aim of the project: In electronics, a wafer (also called a slice or substrate) is a thin slice of semiconductor, such as crystal

Arun Singh Babal 1 Nov 06, 2021
Amitkumar Mishra 2 Jan 14, 2022