Repository for "Improving evidential deep learning via multi-task learning," published at AAAI 2022

Overview

Improving evidential deep learning via multi-task learning

This is the repository for the AAAI 2022 paper “Improving evidential deep learning via multi-task learning” by Dongpin Oh and Bonggun Shin.

This repository contains the code to reproduce the multi-task evidential neural network (MT-ENet), which adds the Lipschitz MSE loss function to the evidential regression network (ENet) as an additional loss. The Lipschitz MSE loss can improve the accuracy of the ENet while preserving its uncertainty estimation capability, because it avoids gradient conflict with the NLL loss, the original loss function of the ENet.
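For reference, the sketch below writes out the per-sample NLL of the Normal-Inverse-Gamma (NIG) evidential model in the standard evidential regression formulation (Amini et al., 2020), where the network predicts (γ, ν, α, β). It assumes, without confirmation from this README, that `EvidentialMarginalLikelihood` computes this quantity; the function name `evidential_nll` is hypothetical, and it works on plain Python floats rather than PyTorch tensors.

```python
import math


def evidential_nll(gamma: float, nu: float, alpha: float, beta: float, y: float) -> float:
    """Negative log-likelihood of y under the Student-t marginal of NIG(gamma, nu, alpha, beta).

    Requires nu > 0, alpha > 0, beta > 0; ENets usually also keep alpha > 1
    so that the predicted variance is finite.
    """
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(nu * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


# The NLL grows as the predicted mean gamma moves away from the target y.
print(evidential_nll(0.0, 1.0, 1.0, 1.0, 0.0))  # 2*log(2), about 1.386
print(evidential_nll(0.0, 1.0, 1.0, 1.0, 3.0))  # larger than the line above
```

The loss depends on the error only through (y − γ)², so it is symmetric around the prediction, and the (α + 0.5) log(...) term is what couples the error to the predicted evidence.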


Setup

The packages required by this repository are listed in requirements.txt. Install them with:

pip install -r requirements.txt

Training the ENet with the Lipschitz-MSE loss: example

from mtevi.mtevi import EvidentialMarginalLikelihood, EvidenceRegularizer, modified_mse
...
net = EvidentialNetwork()  # evidential regression network
nll_loss = EvidentialMarginalLikelihood()  # original ENet loss (NLL loss)
reg = EvidenceRegularizer()  # evidence regularizer
mmse_loss = modified_mse  # Lipschitz MSE loss (the additional task)
...
for inputs, labels in dataloader:
    optimizer.zero_grad()  # `optimizer` is assumed to be created in the elided setup
    gamma, nu, alpha, beta = net(inputs)  # predicted evidential parameters
    loss = nll_loss(gamma, nu, alpha, beta, labels)
    loss = loss + reg(gamma, nu, alpha, beta, labels)
    loss = loss + mmse_loss(gamma, nu, alpha, beta, labels)
    loss.backward()
    optimizer.step()
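The four values returned by `net(inputs)` parameterize a Normal-Inverse-Gamma distribution. The sketch below assumes the network follows the common evidential regression recipe (softplus activations, with α shifted above 1) and shows how raw outputs can be constrained and how the prediction and the two uncertainty estimates are read from them. The regularizer |y − γ|(2ν + α) is the one from Amini et al. (2020); whether `EvidenceRegularizer` uses exactly this form (or adds a scaling factor) is an assumption. All function names here are hypothetical, and the code uses plain floats instead of tensors.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def evidential_head(raw_gamma, raw_nu, raw_alpha, raw_beta):
    """Map unconstrained outputs to valid NIG parameters (hypothetical helper)."""
    gamma = raw_gamma                   # predicted mean, unconstrained
    nu = softplus(raw_nu)               # nu > 0
    alpha = softplus(raw_alpha) + 1.0   # alpha > 1 keeps the variance finite
    beta = softplus(raw_beta)           # beta > 0
    return gamma, nu, alpha, beta


def predict_with_uncertainty(gamma, nu, alpha, beta):
    """Return (prediction, aleatoric variance, epistemic variance)."""
    aleatoric = beta / (alpha - 1.0)          # E[sigma^2]
    epistemic = beta / (nu * (alpha - 1.0))   # Var[mu]
    return gamma, aleatoric, epistemic


def evidence_regularizer(gamma, nu, alpha, y):
    # Penalizes total evidence (2*nu + alpha) in proportion to the absolute error.
    return abs(y - gamma) * (2.0 * nu + alpha)


print(predict_with_uncertainty(0.5, 4.0, 3.0, 2.0))  # (0.5, 1.0, 0.25)
print(evidence_regularizer(0.0, 4.0, 3.0, 1.0))      # 11.0
```

Larger ν and α mean more evidence and therefore smaller epistemic variance, which is why the regularizer penalizes evidence on samples with large errors.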

Quick start

  • Synthetic data experiment.
python synthetic_exp.py
  • UCI regression benchmark experiments.
python uci_exp_norm -p energy
  • Drug-target affinity (DTA) regression task on the KIBA and Davis datasets.
python train_evinet.py -o test --type davis -f 0 --evi # ENet
python train_evinet.py -o test --type davis -f 0  # MT-ENet
  • Gradient conflict experiment on the DTA benchmarks.
python check_conflict.py --type davis -f 0 # Conflict between the Lipschitz MSE (proposed) and NLL losses.
python check_conflict.py --type davis -f 0 --abl # Conflict between the plain MSE and NLL losses.
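This README does not state which metric check_conflict.py uses. A common definition (used, for example, in PCGrad by Yu et al., 2020) is that two gradients conflict when their cosine similarity is negative. The sketch below assumes that definition and operates on flattened gradient vectors given as plain Python lists; `gradient_cosine` and `is_conflicting` are hypothetical names.

```python
import math
from typing import Sequence


def gradient_cosine(g1: Sequence[float], g2: Sequence[float]) -> float:
    """Cosine similarity between two flattened gradient vectors."""
    if len(g1) != len(g2):
        raise ValueError("gradients must have the same length")
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    if norm1 == 0.0 or norm2 == 0.0:
        return 0.0  # a zero gradient neither helps nor conflicts
    return dot / (norm1 * norm2)


def is_conflicting(g_nll: Sequence[float], g_mse: Sequence[float]) -> bool:
    # Conflict: a gradient-descent step on one loss increases the other (to first order).
    return gradient_cosine(g_nll, g_mse) < 0.0


print(gradient_cosine([1.0, 0.0], [-1.0, 0.0]))  # -1.0
print(is_conflicting([1.0, 2.0], [2.0, 1.0]))    # False
```

In a PyTorch setting, the two vectors would typically be obtained by calling `torch.autograd.grad` separately for each loss over the shared network parameters and flattening the results.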

Characteristics of the Lipschitz MSE loss


  • The Lipschitz MSE loss function helps train the ENet to predict target values more accurately.
  • It regularizes its own gradient to prevent gradient conflict with the NLL loss (the original loss function) in cases where the NLL loss increases the predictive uncertainty of the ENet.
  • Please see our paper for details.
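The idea of bounding the MSE gradient can be illustrated with a Huber-style construction: the loss equals the ordinary squared error while |γ − y| stays within a bound U, and continues linearly beyond it, so the magnitude of its gradient with respect to γ never exceeds 2U. This is a simplified sketch, not the repository's `modified_mse`: it assumes a fixed, hypothetical `bound`, whereas the paper derives its bound from the ENet's outputs (see the paper for the exact form).

```python
import math


def lipschitz_mse(pred: float, target: float, bound: float) -> float:
    """Squared error whose gradient magnitude is capped at 2 * bound.

    Quadratic for |pred - target| <= bound, linear beyond it; the two
    pieces meet with matching value and slope at |pred - target| == bound.
    """
    delta = pred - target
    if abs(delta) <= bound:
        return delta * delta
    return 2.0 * bound * abs(delta) - bound * bound


def lipschitz_mse_grad(pred: float, target: float, bound: float) -> float:
    """Derivative of lipschitz_mse with respect to pred."""
    delta = pred - target
    if abs(delta) <= bound:
        return 2.0 * delta
    return 2.0 * bound * math.copysign(1.0, delta)


# Small errors: identical to MSE. Large errors: gradient clipped to +/- 2 * bound.
print(lipschitz_mse(0.0, 1.0, 2.0), lipschitz_mse_grad(0.0, 1.0, 2.0))    # 1.0 -2.0
print(lipschitz_mse(0.0, 10.0, 2.0), lipschitz_mse_grad(0.0, 10.0, 2.0))  # 36.0 -4.0
```

In PyTorch, a loss of the same shape is available as `torch.nn.functional.huber_loss(pred, target, delta=bound)`, which equals half of `lipschitz_mse` as written here.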