High-fidelity performance metrics for generative models in PyTorch

Overview

This repository provides precise, efficient, and extensible implementations of the popular metrics for generative model evaluation, including:

  • Inception Score (ISC)
  • Fréchet Inception Distance (FID)
  • Kernel Inception Distance (KID)
  • Perceptual Path Length (PPL)

Precision: Unlike many other reimplementations, the values produced by torch-fidelity match reference implementations up to machine precision. This allows using torch-fidelity for reporting metrics in papers instead of scattered and slow reference implementations. Read more about precision

Efficiency: Feature sharing between different metrics saves recomputation time, and an additional caching level avoids recomputing features and statistics whenever possible. High efficiency allows using torch-fidelity in the training loop, for example at the end of every epoch. Read more about efficiency

Extensibility: Going beyond 2D image generation is easy due to high modularity and abstraction of the metrics from input data, models, and feature extractors. For example, one can swap out the InceptionV3 feature extractor for one that accepts 3D scan volumes, such as those used in MRI. Read more about extensibility
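
To make the feature sharing described under Efficiency concrete, here is a minimal, library-independent sketch in Python. It is not torch-fidelity's implementation: FeatureCache, toy_extractor, and the two toy metrics are hypothetical names that stand in for an expensive feature extractor and the metrics consuming its output.

```python
from typing import Callable, Dict, List


class FeatureCache:
    """Hypothetical helper: memoizes extracted features per input identifier."""

    def __init__(self, extractor: Callable[[List[float]], List[float]]) -> None:
        self._extractor = extractor
        self._cache: Dict[str, List[float]] = {}
        self.num_extractions = 0  # counts how often the expensive step runs

    def get(self, input_id: str, samples: List[float]) -> List[float]:
        # Run the extractor only the first time an input is requested.
        if input_id not in self._cache:
            self._cache[input_id] = self._extractor(samples)
            self.num_extractions += 1
        return self._cache[input_id]


def toy_extractor(samples: List[float]) -> List[float]:
    # Stand-in for an expensive network forward pass (assumption).
    return [2.0 * x for x in samples]


def toy_metric_mean(features: List[float]) -> float:
    return sum(features) / len(features)


def toy_metric_spread(features: List[float]) -> float:
    return max(features) - min(features)


cache = FeatureCache(toy_extractor)
samples = [1.0, 2.0, 3.0]

# Both metrics request features for the same input; extraction happens once.
mean_value = toy_metric_mean(cache.get("generated", samples))
spread_value = toy_metric_spread(cache.get("generated", samples))
print(mean_value, spread_value, cache.num_extractions)
```

The second metric reuses the features computed for the first, which is the general idea behind sharing features between metrics.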

TL;DR: fast and reliable GAN evaluation in PyTorch

Installation

pip install torch-fidelity

See also: Installing the latest GitHub code

Usage Examples with Command Line

Below are three examples of using torch-fidelity to evaluate metrics from the command line. See more examples in the documentation.

Simple

Inception Score of CIFAR-10 training split:

> fidelity --gpu 0 --isc --input1 cifar10-train

inception_score_mean: 11.23678
inception_score_std: 0.09514061

Medium

Inception Score of a directory of images stored in ~/images/:

> fidelity --gpu 0 --isc --input1 ~/images/

Pro

Efficient computation of ISC and PPL for input1, and of FID and KID between a generative model stored in ~/generator.onnx and the CIFAR-10 training split:

> fidelity \
  --gpu 0 \
  --isc \
  --fid \
  --kid \
  --ppl \
  --input1 ~/generator.onnx \
  --input1-model-z-type normal \
  --input1-model-z-size 128 \
  --input1-model-num-samples 50000 \
  --input2 cifar10-train

See also: Other usage examples
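
When the command line tool is driven from a script, the "Pro" invocation above can be assembled as an argument list. The following is a minimal sketch, assuming the fidelity executable is on PATH; build_fidelity_command is a hypothetical helper, and every flag is copied from the example above. The sketch only builds and prints the command.

```python
import os
import shlex
from typing import List


def build_fidelity_command(
    generator_path: str,
    reference: str = "cifar10-train",
    gpu: int = 0,
    z_size: int = 128,
    num_samples: int = 50000,
) -> List[str]:
    """Hypothetical helper: returns the argument list for the 'Pro' example."""
    return [
        "fidelity",
        "--gpu", str(gpu),
        "--isc", "--fid", "--kid", "--ppl",
        # Expand "~" here because no shell does it when a list is executed.
        "--input1", os.path.expanduser(generator_path),
        "--input1-model-z-type", "normal",
        "--input1-model-z-size", str(z_size),
        "--input1-model-num-samples", str(num_samples),
        "--input2", reference,
    ]


command = build_fidelity_command("~/generator.onnx")
print(shlex.join(command))

# To execute it (requires torch-fidelity to be installed):
#   import subprocess
#   subprocess.run(command, check=True)
```

Passing a list rather than a single string avoids shell quoting issues, including the trailing-whitespace problem that can break backslash line continuations.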

Quick Start with Python API

When tracking the performance of generative models as they train, evaluating metrics after every epoch can become prohibitively expensive because of long computation times. torch_fidelity tackles this problem by making full use of caching to avoid recomputing common features and per-metric statistics whenever possible. Computing all metrics for 50000 generated 32x32 images and cifar10-train takes only 2 min 26 s on an NVIDIA P100 GPU, compared to more than 10 min with the original codebases. Thus, computing metrics 20 times over the whole training cycle makes the overall training time only about one hour longer.
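
The overhead estimate can be checked with a few lines of arithmetic. This sketch uses only the timings quoted above and treats the ">10 min" figure for the original codebases as a lower bound; the variable names are illustrative.

```python
# Timings quoted in the text above.
EVALUATIONS = 20
TORCH_FIDELITY_SECONDS = 2 * 60 + 26   # 2 min 26 s per evaluation
REFERENCE_SECONDS = 10 * 60            # ">10 min", used as a lower bound

# Total evaluation overhead over the whole training cycle, in minutes.
torch_fidelity_total_min = EVALUATIONS * TORCH_FIDELITY_SECONDS / 60
reference_total_min = EVALUATIONS * REFERENCE_SECONDS / 60

print(f"torch-fidelity: {torch_fidelity_total_min:.1f} min")
print(f"original codebases: more than {reference_total_min:.0f} min")
```

Twenty evaluations add a little under 49 minutes with torch-fidelity, against more than 200 minutes with the original codebases.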

In the following example, assume an unconditional image generation setting with CIFAR-10 and a generative model, generator, that takes a 128-dimensional standard normal noise vector as input.

First, import the module:

import torch_fidelity

Add the following lines at the end of epoch evaluation:

wrapped_generator = torch_fidelity.GenerativeModelModuleWrapper(generator, 128, 'normal', 0)

metrics_dict = torch_fidelity.calculate_metrics(
    input1=wrapped_generator, 
    input2='cifar10-train', 
    cuda=True, 
    isc=True, 
    fid=True, 
    kid=True, 
    verbose=False,
)

The resulting dictionary with computed metrics can be logged directly to TensorBoard, wandb, or the console:

print(metrics_dict)

Output:

{
    'inception_score_mean': 11.23678, 
    'inception_score_std': 0.09514061, 
    'frechet_inception_distance': 18.12198,
    'kernel_inception_distance_mean': 0.01369556, 
    'kernel_inception_distance_std': 0.001310059
}
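
As a sketch of the logging step, the helper below writes each entry of such a dictionary to any object exposing add_scalar(tag, scalar_value, global_step), which is the call shape of torch.utils.tensorboard.SummaryWriter.add_scalar. The names log_metrics and RecordingWriter are hypothetical; RecordingWriter is a stand-in so the example runs without TensorBoard installed.

```python
from typing import Dict, List, Optional, Tuple


class RecordingWriter:
    """Hypothetical stand-in for a TensorBoard SummaryWriter."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, float, Optional[int]]] = []

    def add_scalar(self, tag: str, scalar_value: float,
                   global_step: Optional[int] = None) -> None:
        self.records.append((tag, float(scalar_value), global_step))


def log_metrics(writer, metrics_dict: Dict[str, float], epoch: int,
                prefix: str = "metrics/") -> None:
    """Log every metric under a common prefix, using the epoch as the step."""
    for name in sorted(metrics_dict):
        writer.add_scalar(prefix + name, metrics_dict[name], epoch)


# Values copied from the example output above.
metrics_dict = {
    'inception_score_mean': 11.23678,
    'inception_score_std': 0.09514061,
    'frechet_inception_distance': 18.12198,
    'kernel_inception_distance_mean': 0.01369556,
    'kernel_inception_distance_std': 0.001310059,
}

writer = RecordingWriter()
log_metrics(writer, metrics_dict, epoch=1)
for tag, value, step in writer.records:
    print(step, tag, value)
```

In a real training script, an actual SummaryWriter instance could be passed in place of RecordingWriter.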

See also: Full API reference

Example of Integration with the Training Loop

Refer to sngan_cifar10.py for a complete training example.

Figure: Evolution of fixed generator latents in the example.

A generator checkpoint resulting from training the example can be downloaded here.

Citation

Citation is recommended to reinforce the evaluation protocol in works relying on torch-fidelity. To ensure reproducibility when citing this repository, use the following BibTeX:

@misc{obukhov2020torchfidelity,
  author={Anton Obukhov and Maximilian Seitzer and Po-Wei Wu and Semen Zhydenko and Jonathan Kyl and Elvis Yu-Jing Lin},
  year=2020,
  title={High-fidelity performance metrics for generative models in PyTorch},
  url={https://github.com/toshas/torch-fidelity},
  publisher={Zenodo},
  version={v0.3.0},
  doi={10.5281/zenodo.4957738},
  note={Version: 0.3.0, DOI: 10.5281/zenodo.4957738}
}