TensorFlow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

Overview

HiT-GAN Official TensorFlow Implementation

HiT-GAN presents a Transformer-based generator that is trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance for high-resolution image synthesis. Please see our NeurIPS 2021 paper "Improved Transformer for High-Resolution GANs" for more details.

This implementation is based on TensorFlow 2.x. We use tf.keras layers for building the model and use tf.data for our input pipeline. The model is trained using a custom training loop with tf.distribute on multiple TPUs/GPUs.

Environment setup

We recommend training the model with distributed training on TPUs and evaluating it on GPUs. The code is compatible with TensorFlow 2.x. All prerequisites are listed in requirements.txt and can be installed with the following command:

pip install -r requirements.txt

ImageNet

Before first use, download ImageNet by following the tensorflow_datasets instructions in the official guide.

Train on ImageNet

To train the model on ImageNet with Cloud TPUs, first check out the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.

Once you have created a virtual machine with Cloud TPUs and downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
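
A quick sanity check before launching a long run can catch a missing or malformed variable early. The sketch below is not part of this repository: `check_env` and `REQUIRED_VARS` are hypothetical names, and the script assumes the variables above were exported so that child processes can see them.

```python
import os

# Variable names taken from the list above.
REQUIRED_VARS = ("TPU_NAME", "STORAGE_BUCKET", "DATA_DIR", "MODEL_DIR")


def check_env(env=None):
    """Return a list of problems found in the environment (empty if none)."""
    env = os.environ if env is None else env
    # Report every required variable that is unset or empty.
    problems = [f"{name} is not set" for name in REQUIRED_VARS if not env.get(name)]
    # The bucket is written as gs://<storage-bucket> above, so check the scheme.
    bucket = env.get("STORAGE_BUCKET", "")
    if bucket and not bucket.startswith("gs://"):
        problems.append("STORAGE_BUCKET should start with gs://")
    return problems


if __name__ == "__main__":
    for problem in check_env():
        print("warning:", problem)
```

Passing a plain dictionary instead of relying on `os.environ` makes the check easy to try without touching the real environment.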

The following command trains a model on ImageNet on TPUv2 4x4, using the default hyperparameters from our paper:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

To train the model on ImageNet with multiple GPUs, try the following command:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=False

Please set train_batch_size according to the number of GPUs used for training. Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (hence --use_ema_model=False), so training on GPUs leads to a slight drop in performance.
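
For readers unfamiliar with EMA models, the idea can be sketched in a few lines of plain Python. This is only an illustration of the general technique, not the code used in this repository: `ema_update` is a hypothetical helper, weights are represented as a flat list of floats, and the decay value of 0.999 is an assumed placeholder rather than the setting used in the paper.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Blend current weights into the running average, element by element.

    Each entry becomes decay * ema + (1 - decay) * current, so the averaged
    copy changes slowly and smooths out step-to-step fluctuations.
    """
    if len(ema_weights) != len(weights):
        raise ValueError("ema_weights and weights must have the same length")
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Toy usage: the average moves slowly toward the current weights.
ema = [0.0, 0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0, 2.0], decay=0.9)
```

The averaged copy is typically the one used for evaluation and sampling, which is consistent with the use_ema_model flag appearing in both the training and evaluation commands.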

Evaluate on ImageNet

Run the following command to evaluate the model on GPUs:

python run.py --mode=eval --dataset=imagenet2012 \
  --eval_batch_size=128 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --latent_dim=256 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

This command is intended for 8 P100 GPUs. Please set eval_batch_size according to the number of GPUs used for evaluation. Also note that train_steps and use_ema_model must be set according to the values used for training.
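
When adapting the batch size to a different number of GPUs, it can help to check how the batch divides across devices. The sketch below assumes the batch-size flag is a global batch that is split evenly across replicas, as is common with tf.distribute; this README does not confirm that, and `per_replica_batch_size` is a hypothetical helper rather than a function in this repository.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Return the number of examples each device would process per step."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {num_replicas} replicas"
        )
    return global_batch_size // num_replicas


# With the evaluation settings above: a batch of 128 spread over 8 GPUs.
examples_per_gpu = per_replica_batch_size(128, 8)
```

Under this assumption, keeping the per-device count fixed while changing the number of GPUs means scaling the global batch size proportionally.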

CelebA-HQ

Before first use, download CelebA-HQ by following the tensorflow_datasets instructions in the official guide.

Train on CelebA-HQ

The following command trains a model on CelebA-HQ on TPUv2 4x4, using the default hyperparameters from our paper for a resolution of 256:

python run.py --mode=train --dataset=celeb_a_hq/256 \
  --train_batch_size=256 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on CelebA-HQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=celeb_a_hq/256 \
  --eval_batch_size=128 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

FFHQ

Before first use, download the FFHQ tfrecords from the official site and put them into $DATA_DIR.

Train on FFHQ

The following command trains a model on FFHQ on TPUv2 4x4, using the default hyperparameters from our paper for a resolution of 256:

python run.py --mode=train --dataset=ffhq/256 \
  --train_batch_size=256 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on FFHQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=ffhq/256 \
  --eval_batch_size=128 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

Cite

@inproceedings{zhao2021improved,
  title = {Improved Transformer for High-Resolution {GANs}},
  author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}

Disclaimer

This is not an officially supported Google product.
