TensorFlow implementation and notebooks for Implicit Maximum Likelihood Estimation

tf-imle

TensorFlow 2 and PyTorch implementations and Jupyter notebooks for Implicit Maximum Likelihood Estimation (I-MLE), proposed in the NeurIPS 2021 paper "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions".

I-MLE is also available as a PyTorch library: https://github.com/uclnlp/torch-imle

Introduction

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear programming (ILP) solvers, as well as complex discrete probability distributions, in standard deep learning architectures. The figure below illustrates the setting I-MLE was developed for. An upstream neural network maps some input to the input parameters of a discrete combinatorial optimization algorithm or a discrete probability distribution, depicted as the black box. In the forward pass, the discrete component is executed and its discrete output is fed into a downstream neural network. With I-MLE, it is possible to estimate gradients of the loss function with respect to the input parameters of the discrete component; these gradients are used during backpropagation to update the parameters of the upstream neural network.

[Figure: Illustration of the problem addressed by I-MLE]
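
To make the difficulty concrete: the output of a discrete component is piecewise constant in its input parameters, so naive differentiation through it yields zero almost everywhere. The following is a minimal NumPy sketch of that effect. The top-k selection used as the black box, the squared-error loss, and the names `top_k_mask` and `loss` are illustrative assumptions, not part of the repository.

```python
import numpy as np

def top_k_mask(theta, k):
    """Hypothetical black box: binary mask selecting the k largest entries of theta."""
    threshold = np.sort(theta)[-k]
    return (theta >= threshold).astype(float)

def loss(theta, target, k=2):
    """Squared error between the discrete output and a target mask."""
    z = top_k_mask(theta, k)
    return float(np.sum((z - target) ** 2))

theta = np.array([2.0, 1.0, 0.0, -1.0])
target = np.array([0.0, 1.0, 1.0, 0.0])

# Central finite differences of the loss with respect to each entry of theta.
eps = 1e-3
finite_diff = np.array([
    (loss(theta + eps * e, target) - loss(theta - eps * e, target)) / (2 * eps)
    for e in np.eye(len(theta))
])
print("loss:", loss(theta, target))
print("finite-difference gradient:", finite_diff)
```

The loss is non-zero, yet small changes to `theta` do not change the selected subset, so the finite-difference gradient carries no learning signal. This is the gap a dedicated gradient estimator has to fill.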

The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and possibly intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since we do not have access to such an empirical distribution in our setting, we have to design surrogate empirical distributions, which we term target distributions. Here we propose two families of target distributions which are widely applicable and work well in practice.
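
The first ingredient can be illustrated in its simplest form, a categorical distribution, where perturbing the logits with Gumbel(0, 1) noise and taking the argmax yields an exact sample from the softmax distribution. The sketch below is a plain NumPy illustration under that assumption; the helper names `softmax` and `perturb_and_map_samples` are hypothetical, and the repository's own code uses sum-of-gamma noise for the k-subset case instead.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

def perturb_and_map_samples(logits, n_samples, rng):
    """Perturb the logits with Gumbel(0, 1) noise and return the argmax (MAP state) per draw."""
    noise = rng.gumbel(loc=0.0, scale=1.0, size=(n_samples, len(logits)))
    return np.argmax(logits + noise, axis=1)

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0])
n_samples = 20000

samples = perturb_and_map_samples(logits, n_samples, rng)
empirical = np.bincount(samples, minlength=len(logits)) / n_samples

print("softmax  :", np.round(softmax(logits), 3))
print("empirical:", np.round(empirical, 3))
```

The empirical frequencies of the perturbed argmax should approach the softmax probabilities as the number of samples grows.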

Requirements:

TensorFlow 2 implementation:

  • tensorflow==2.3.0 or tensorflow-gpu==2.3.0
  • numpy==1.18.5
  • matplotlib==3.1.1
  • scikit-learn==0.24.1
  • tensorflow-probability==0.7.0

PyTorch implementation:

Example: I-MLE as a Layer

The following is an instance of I-MLE implemented as a layer. This is a class where the optimization problem is computing the k-subset configuration, the target distribution is based on perturbation-based implicit differentiation, and the perturb-and-MAP noise perturbations are drawn from the sum-of-gamma distribution.

import tensorflow as tf
from tensorflow.keras import backend as K


class IMLESubsetkLayer(tf.keras.layers.Layer):
    
    def __init__(self, k, _tau=10.0, _lambda=10.0):
        super(IMLESubsetkLayer, self).__init__()
        # average number of 1s in a solution to the optimization problem
        self.k = k
        # the temperature at which we want to sample
        self._tau = _tau
        # the perturbation strength (here we use a target distribution based on perturbation-based implicit differentiation)
        self._lambda = _lambda  
        # the samples we store for the backward pass
        self.samples = None 
        
    @tf.function
    def sample_sum_of_gamma(self, shape):
        
        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k, self.k/t), 
                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))   
        # now add the samples
        s = tf.reduce_sum(s, 0)
        # the log(m) term
        s = s - tf.math.log(10.0)
        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)
        s = self._tau * (s / self.k)

        return s
    
    @tf.function
    def sample_discrete_forward(self, logits): 
        self.samples = self.sample_sum_of_gamma(tf.shape(logits))
        gamma_perturbed_logits = logits + self.samples
        # gamma_perturbed_logits is the input to the combinatorial opt algorithm
        # the next two lines can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(gamma_perturbed_logits, self.k, sorted=True)[0][:,-1], -1)
        y = tf.cast(tf.greater_equal(gamma_perturbed_logits, threshold), tf.float32)
        
        return y
    
    @tf.function
    def sample_discrete_backward(self, logits):     
        gamma_perturbed_logits = logits + self.samples
        # gamma_perturbed_logits is the input to the combinatorial opt algorithm
        # the next two lines can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(gamma_perturbed_logits, self.k, sorted=True)[0][:,-1], -1)
        y = tf.cast(tf.greater_equal(gamma_perturbed_logits, threshold), tf.float32)
        return y
    
    @tf.custom_gradient
    def subset_k(self, logits, k):

        # sample discretely with perturb and map
        z_train = self.sample_discrete_forward(logits)
        # compute the top-k discrete values
        threshold = tf.expand_dims(tf.nn.top_k(logits, self.k, sorted=True)[0][:,-1], -1)
        z_test = tf.cast(tf.greater_equal(logits, threshold), tf.float32)
        # at training time we sample, at test time we take the argmax
        z_output = K.in_train_phase(z_train, z_test)
        
        def custom_grad(dy):

            # we perturb (implicit diff) and then reuse the stored sample for perturb-and-MAP
            map_dy = self.sample_discrete_backward(logits - (self._lambda*dy))
            # we now compute the gradients as the difference (I-MLE gradients)
            grad = tf.math.subtract(z_train, map_dy)
            # return the gradient            
            return grad, k

        return z_output, custom_grad
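
The gradient computation in the layer above can also be traced by hand outside TensorFlow. The following NumPy sketch mirrors the same steps under stated assumptions: `sample_sum_of_gamma` is a port of the layer's noise recipe (TensorFlow's rate parameter `beta` becomes a NumPy scale of `1/beta`), while `top_k_map` and `imle_subset_k` are hypothetical helper names introduced only for illustration. The worked example uses zero noise so that the result is easy to check.

```python
import numpy as np

def sample_sum_of_gamma(out_shape, k, tau, rng, n_terms=10):
    """Sum-of-gamma noise following the layer's recipe (NumPy port, illustrative only)."""
    out_shape = tuple(out_shape)
    terms = np.arange(1, n_terms + 1, dtype=float)
    # Gamma(shape=1/k, rate=k/t) per term t; NumPy expects scale = 1/rate = t/k.
    scale = terms.reshape((n_terms,) + (1,) * len(out_shape)) / k
    s = rng.gamma(shape=1.0 / k, scale=scale, size=(n_terms,) + out_shape)
    s = s.sum(axis=0) - np.log(n_terms)
    return tau * (s / k)

def top_k_map(scores, k):
    """MAP state of the k-subset problem: 1 for the k largest scores in each row."""
    threshold = np.sort(scores, axis=-1)[:, -k][:, None]
    return (scores >= threshold).astype(np.float32)

def imle_subset_k(logits, dy, noise, k, lam):
    """Forward sample z and the I-MLE gradient estimate for a downstream gradient dy."""
    z = top_k_map(logits + noise, k)
    # Perturb the logits against the downstream gradient, reusing the same noise.
    z_target = top_k_map(logits - lam * dy + noise, k)
    return z, z - z_target

logits = np.array([[2.0, 1.0, 0.0, -1.0]])
# Downstream gradient: the loss would drop if item 0 were removed and item 2 were added.
dy = np.array([[0.5, 0.0, -0.5, 0.0]])

# Zero noise keeps the example deterministic.
z, grad = imle_subset_k(logits, dy, noise=np.zeros_like(logits), k=2, lam=10.0)
print("forward sample:", z)
print("I-MLE gradient:", grad)

# With actual noise, the outputs vary from draw to draw.
rng = np.random.default_rng(0)
noise = sample_sum_of_gamma(logits.shape, k=2, tau=10.0, rng=rng)
z_noisy, grad_noisy = imle_subset_k(logits, dy, noise, k=2, lam=10.0)
```

In the deterministic case the estimate is positive for item 0 and negative for item 2, so a gradient-descent step on the logits lowers the score of item 0 and raises the score of item 2, in line with the downstream signal.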

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
Owner
NEC Laboratories Europe
Research software developed at NEC Laboratories Europe