PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series in CUDA

Overview

Soft DTW Loss Function for PyTorch in CUDA

This is a PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series that supports batched computation, is CUDA-friendly, and is feasible to use as a final loss. I can confirm that you can train a (sequential) model with this as a final loss! The following image shows training logs of a TTS model using the Soft-DTW Loss Function.

There are some previous implementations:

  1. mblondel's soft-dtw
  2. lyprince's sdtw_pytorch
  3. Maghoumi's pytorch-softdtw-cuda

However, they either lack CUDA-friendly batch computation or do not consider the Jacobian w.r.t. the input matrix, which is necessary for use as a final loss in recent deep learning frameworks. The current implementation satisfies both conditions.
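For reference, the quantity being differentiated can be written out. The sketch below follows the soft-DTW definition from the paper named in the title. The squared Euclidean cost and the symbols D, R, E, n, and m are notation assumed here for illustration, not taken from this repository's code.

```latex
\begin{aligned}
D_{i,j} &= \lVert x_i - y_j \rVert_2^2, \\
\operatorname{min}^{\gamma}(a_1,\dots,a_k) &= -\gamma \log \sum_{l=1}^{k} e^{-a_l/\gamma}, \\
R_{i,j} &= D_{i,j} + \operatorname{min}^{\gamma}\bigl(R_{i-1,j-1},\, R_{i-1,j},\, R_{i,j-1}\bigr),
  \qquad R_{0,0}=0,\; R_{i,0}=R_{0,j}=\infty \ (i,j \ge 1), \\
\mathrm{sdtw}_{\gamma}(x,y) &= R_{n,m}, \\
\frac{\partial\, \mathrm{sdtw}_{\gamma}}{\partial x_i}
  &= \sum_{j=1}^{m} E_{i,j}\, \frac{\partial D_{i,j}}{\partial x_i}
   = 2\sum_{j=1}^{m} E_{i,j}\,(x_i - y_j),
  \qquad E_{i,j} = \frac{\partial R_{n,m}}{\partial D_{i,j}}.
\end{aligned}
```

The last line is the Jacobian step: the gradient that reaches x is the alignment matrix E pushed through the derivative of the pairwise cost.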

Usage

Same as Maghoumi's pytorch-softdtw-cuda:

import torch
from sdtw_cuda_loss import SoftDTW

# Create the sequences
batch_size, len_x, len_y, dims = 8, 15, 12, 5
x = torch.rand((batch_size, len_x, dims), requires_grad=True)
y = torch.rand((batch_size, len_y, dims))

# Create the "criterion" object
sdtw = SoftDTW(use_cuda=True, gamma=0.1)

# Compute the loss value
loss = sdtw(x, y)  # Just like any torch.nn.xyzLoss()

# Aggregate and call backward()
loss.mean().backward()

The difference is that here backward() computes the gradient w.r.t. the input target sequence x, which the previous work does not consider.

Note

The current implementation supports only use_cuda=True, but a CPU version can be implemented easily, as in Maghoumi's pytorch-softdtw-cuda.
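As a rough idea of what a CPU fallback could look like, here is a minimal sketch that computes soft-DTW with plain PyTorch operations and lets autograd produce the gradients. It is not this repository's code and not Maghoumi's CPU path: the function name soft_dtw_cpu is hypothetical, and the squared Euclidean cost is an assumption. It is slow (a Python double loop) and is meant only for small inputs or for checking values.

```python
import torch


def soft_dtw_cpu(x, y, gamma=0.1):
    """Hypothetical reference soft-DTW using plain autograd.

    x: (batch, len_x, dims), y: (batch, len_y, dims). Returns shape (batch,).
    """
    batch, n, _ = x.shape
    m = y.shape[1]
    # Pairwise squared Euclidean cost (assumed), shape (batch, len_x, len_y)
    cost = ((x.unsqueeze(2) - y.unsqueeze(1)) ** 2).sum(dim=-1)
    inf = x.new_full((batch,), float("inf"))
    # r[i][j] is the soft-DTW value between the prefixes x[:, :i] and y[:, :j]
    r = [[inf for _ in range(m + 1)] for _ in range(n + 1)]
    r[0][0] = x.new_zeros(batch)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prev = torch.stack([r[i - 1][j - 1], r[i - 1][j], r[i][j - 1]])
            # Soft minimum: -gamma * log(sum(exp(-a / gamma)))
            softmin = -gamma * torch.logsumexp(-prev / gamma, dim=0)
            r[i][j] = cost[:, i - 1, j - 1] + softmin
    return r[n][m]


# Small CPU example mirroring the shapes used in the Usage section
x = torch.rand((8, 15, 5), requires_grad=True)
y = torch.rand((8, 12, 5))
loss = soft_dtw_cpu(x, y, gamma=0.1)
loss.mean().backward()  # x.grad now has the same shape as x
```

Because every step is an ordinary differentiable tensor operation, the gradient w.r.t. x comes from autograd rather than from a hand-written backward kernel.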

Citation

@misc{lee2021soft_dtw_loss,
  author = {Lee, Keon},
  title = {Soft-DTW-Loss},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/keonlee9420/Soft-DTW-Loss}}
}
Comments
  • Does this support multi-GPU training?

    Thanks for sharing this implementation of Soft-DTW. I can use it in a single-GPU environment, but not in multi-GPU environments. Does it currently not support multi-GPU training?

    opened by mayfool 2
  • How to use the DTW loss to fit a curve?

    Hello, I tried to fit a curve (discrete points) using Soft-DTW-Loss as the loss function, but the loss does not converge to the exact result in the end. Is there something wrong with the way I am using it? The code is as follows:

    import torch
    import torch.nn as nn
    from sdtw_cuda_loss import SoftDTW

    if __name__ == "__main__":
        batch_size = 1
        len_x = 15
        len_predict = 10
        dims = 1

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        x = torch.unsqueeze(torch.linspace(1, 4, steps=len_x, requires_grad=True), dim=0)
        y = x ** 2
        y = y.view(1, len_x, 1)
        x = x.view(1, len_x, 1)

        # (batch, length, dims) ----> (1, 15, 2)
        truth_points = torch.cat((y, x), dim=2).cuda()

        # (1, 20)
        input = torch.unsqueeze(torch.linspace(1, 4, steps=len_predict * 2, requires_grad=True), dim=0).cuda()

        class testNN(torch.nn.Module):
            def __init__(self):
                super(testNN, self).__init__()
                self.layer = nn.Sequential(
                    nn.Linear(20, 50),
                    nn.ReLU(),
                    nn.Linear(50, 200),
                    nn.ReLU(),
                    nn.Linear(200, 50),
                    nn.ReLU(),
                    nn.Linear(50, 20),
                    nn.ReLU(),
                )

            def forward(self, x):
                x = self.layer(x)
                return x

        test = testNN()
        test = test.to(device)

        loss_function = SoftDTW(use_cuda=True, gamma=0.01, normalize=False)
        optimizer = torch.optim.Adam(test.parameters(), lr=0.01)

        for epoch in range(1000):
            predict = test(input)
            # (1, 20) reshaped to (1, 10, 2)
            predict = predict.reshape(1, len_predict, 2)
            loss = loss_function(predict, truth_points)
            optimizer.zero_grad()
            loss.mean().backward(retain_graph=True)
            optimizer.step()

            if epoch % 10 == 0:
                print("epoch : %d | loss : %f" % (epoch, loss.item()))
                plt_predict = predict.cpu().detach().numpy()
                # print(plt_predict)
                plt_predict = plt_predict.reshape(1, len_predict, 2)
                print(plt_predict[0, :, 0])
                print(plt_predict[0, :, 1])
    
    opened by visionlyx 0
Releases(v1.0.0)
Owner
Keon Lee
Expressive Speech Synthesis | Conversational AI | Open-domain Dialog | NLP | Generative Models | Empathic Computing | HCI