GAN example for Keras, because MNIST is too small and there should be an example on something more realistic.

Overview

Keras-GAN-Animeface-Character

Some results

Training for 22 epochs

Loss graph for 5000 mini-batches

1 mini-batch = 64 images. Dataset = 14490, hence 5000 mini-batches is approximately 22 epochs.
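
The epoch count follows directly from the numbers above. A quick check in Python (the variable names are only for illustration and do not come from the scripts):

```python
# Numbers taken from the text above.
batch_size = 64          # images per mini-batch
dataset_size = 14490     # images in the dataset
num_batches = 5000       # mini-batches trained

images_seen = batch_size * num_batches      # total images fed to the network
epochs = images_seen / dataset_size         # passes over the dataset

print(f"{images_seen} images seen = {epochs:.2f} epochs")
```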

Some outputs of the 5000th mini-batch

Some training images

Useful resources, before you go on

How to run this example

Setup

  • My environment: Python 3.6 + Keras 2.0.4 + TensorFlow 1.x
    • If you are on Keras 2.0.0, you need to update it; otherwise BatchNormalization() will fail with an error from the TensorFlow backend saying something like "you need to pass float to input".
  • Use virtualenv to initialize a similar environment (python and dependencies):
pip install virtualenv
virtualenv -p <PATH_TO_BIN_DIR>/python3.6 venv
source venv/bin/activate
pip install -r requirements.txt
  • I HATE writing programs that need lots of command-line parameters, so many of the parameters live in the scripts themselves. Adjust the scripts as you need. The main() function is at the bottom of each script, as people do in C/C++.
  • Most global parameters are defined in args.py.
    • They are defined as class variables, not instance variables, so you may have trouble running or training multiple instances of the GAN with different parameters (which is very unlikely to happen).
  • Download dataset from http://www.nurs.or.jp/~nagadomi/animeface-character-dataset/
    • Extract it to this directory so that the script can find ./animeface-character-dataset/thumb/
    • Any dataset should work in principle, but GANs are sensitive to hyperparameters and may not work on yours. I tuned the parameters for animeface-character-dataset.
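
The caveat about class variables in args.py can be illustrated with a small sketch. The `Args` and `InstanceArgs` classes and the `lr` attribute below are hypothetical stand-ins, not the actual contents of args.py:

```python
class Args:
    # Class variable: one value shared by everything that reads Args.lr.
    lr = 0.0002


class InstanceArgs:
    def __init__(self, lr=0.0002):
        # Instance variable: each object carries its own value.
        self.lr = lr


# Two "runs" that both refer to the same class share its settings.
first_run = Args
second_run = Args
second_run.lr = 0.001
shared_lr = first_run.lr          # the first run sees the changed value too

# Separate instances keep separate settings.
a = InstanceArgs(lr=0.0002)
b = InstanceArgs(lr=0.001)
```

With the class-variable style, two GANs configured in the same process would overwrite each other's settings; for a single training run that does not matter.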

Preprocessing

  • Run the preprocessing script. Resizing and scaling the input once, up front, saves training time compared with doing those tasks on the fly in the training loop.
    • ./data.py
    • When the images are loaded from PNG files, the RGB values are in [0, 255] (uint8 type). data.py collects the images, resizes them to 64x64 and scales the RGB values so that they are in the [-1.0, 1.0] range.
    • data.py samples only a subset of the dataset if configured to do so. The size of the subset is determined by dataset_sz, defined in args.py.
    • The images will be written to data.hdf5.
      • Made it small to verify the training is working.
      • You can increase it but you need to adjust the network sizes accordingly.
    • Again, which files to read is defined at the bottom of the script, not by sys.argv.
  • You need a large enough dataset. Otherwise the discriminator will sort of "memorize" the true data and reject all that's generated.
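
The scaling step described above can be sketched as follows. This is a minimal illustration with NumPy only: the function names are made up here, the exact formula used in data.py is an assumption, and the resizing and HDF5 writing are left out.

```python
import numpy as np


def scale_to_unit_range(img_uint8):
    """Map uint8 RGB values in [0, 255] to float32 values in [-1.0, 1.0]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0


def unscale_to_uint8(img_float):
    """Inverse mapping, e.g. for saving generated images as PNG."""
    restored = np.rint((img_float + 1.0) * 127.5)
    return np.clip(restored, 0, 255).astype(np.uint8)


# A random 64x64 RGB array stands in for a resized dataset image.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

scaled = scale_to_unit_range(image)
roundtrip = unscale_to_uint8(scaled)
```

The [-1.0, 1.0] range matches a tanh output on the generator, which is a common choice for GANs; whether gan.py uses tanh is not stated here.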

Training

  • Open gan.py and, if you wish, uncomment train_autoenc() at the bottom.
    • This is useful for seeing the generator network's capability to reproduce the input.
    • The auto-encoder will be trained on input images.
    • The output will be blurry, because the auto-encoder uses a mean-squared-error loss. (This is why GAN got invented in the first place!)
  • To run training, modify main() so that train_gan() is uncommented.
  • The script will dump reals.png and fakes.png every 10 epochs so that you can see how the training is going.
  • The training takes a while. For this example on the Anime Face dataset, it took about 10000 mini-batches to get good results.
    • If you still see only uniform color or "modern art" by mini-batch 2000, then the training is not working!
  • The script also dumps weights every 10 batches. Use them to save training time. Weights saved before the training diverged are preferred :) Uncomment load_weights() in train_gan().
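
The remark above that a mean-squared-error loss gives blurry output can be seen in a toy case. The sketch below is not taken from gan.py; it uses two made-up 1-D "images" with an edge at different positions and compares the average MSE of a sharp guess against that of their pixel-wise average.

```python
import numpy as np

# Two equally plausible "sharp" targets: an edge at two different positions.
target_a = np.array([0.0, 0.0, 1.0, 1.0, 1.0, 1.0])
target_b = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 1.0])


def expected_mse(prediction):
    """Average MSE of one prediction against both targets."""
    return 0.5 * (np.mean((prediction - target_a) ** 2)
                  + np.mean((prediction - target_b) ** 2))


blurry = 0.5 * (target_a + target_b)   # pixel-wise average: a soft edge

loss_sharp_a = expected_mse(target_a)
loss_sharp_b = expected_mse(target_b)
loss_blurry = expected_mse(blurry)

print(loss_sharp_a, loss_sharp_b, loss_blurry)
```

The soft-edged average scores better than either sharp guess, so a network trained only on MSE is pushed toward blurry outputs when several sharp answers are plausible.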

Training tips

What I experienced while training the GAN.

  • As described in GAN Hacks, the discriminator should be ahead of the generator so that the generator can be "guided" by the discriminator.
  • If you look at the loss graph at https://github.com/osh/KerasGAN, the generator loss there is in the range of 2 to 4, and that training worked well. The discriminator loss is low, around 0.1.
  • You'll need trial and error to get the hyperparameters right so that the training stays in the stable, balanced zone. That includes the learning rates (LR) of D and G, momentums, etc.
  • Convergence is quite sensitive to LR, beware!
  • If things go well, the discriminator loss for detecting real/fake (dloss0/dloss1) should be around 0.1 or less, which means the discriminator is good at telling whether the input is real or fake.
  • If the learning rate is too high, the discriminator will diverge: one of the losses will get high and will not fall. Training fails in this case.
  • Making the LR too small only slows the learning and does not prevent other issues such as oscillation. It only needs to be lower than a certain threshold, which is data dependent.
  • If adjusting the LR doesn't work, the discriminator may lack complexity. Add more layers, or change some other parameters. It could be anything :( Good luck!
  • On the other hand, the generator loss will be relatively higher than the discriminator loss. In this script, it oscillates in the range 0.1 to 4.
  • If you see any of the D losses staying > 15 (when batch size is 32), the training is screwed.
  • If G loss > 15, see if it escapes within 30 batches. If it stays there for too long, that isn't good, I think.
  • If you're seeing high G loss, it could mean the generator can't keep up with the discriminator. You might need to increase its LR. (It must still be slower than the discriminator, though.)
  • One final piece of the training I was missing was the parameter in BatchNormalization. I found out about it at this link: https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
    • Interestingly, the momentum parameter for BatchNorm is 0.1 in PyTorch, according to the API documents, while in Keras it is 0.99. I'm not sure if 0.1 in PyTorch actually means 1 - 0.1. I didn't look into the PyTorch backend implementation.
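
On the PyTorch question above: going by the update formulas given in the two libraries' BatchNorm documentation, the conventions are mirrored. Keras applies momentum to the old moving statistic, PyTorch applies it to the new batch statistic. Under that reading, PyTorch's default of 0.1 corresponds to 0.9 in Keras terms, not 0.99. The sketch below re-implements both rules in plain Python for the running mean only; it is not library code and the function names are made up.

```python
def keras_style_update(moving, batch_stat, momentum=0.99):
    """Keras convention: momentum weights the OLD moving statistic."""
    return moving * momentum + batch_stat * (1.0 - momentum)


def pytorch_style_update(running, batch_stat, momentum=0.1):
    """PyTorch convention: momentum weights the NEW batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


moving_k = 0.0
running_p = 0.0
for batch_mean in [1.0, 2.0, 3.0, 4.0]:
    # Keras momentum 0.9 and PyTorch momentum 0.1 describe the same average.
    moving_k = keras_style_update(moving_k, batch_mean, momentum=0.9)
    running_p = pytorch_style_update(running_p, batch_mean, momentum=0.1)

print(moving_k, running_p)
```

So when porting BatchNorm settings between the two, a Keras momentum of m would translate to a PyTorch momentum of 1 - m.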